1
2
3
4
5 package template
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net/url"
13 "reflect"
14 "strings"
15 "sync"
16 "unicode"
17 "unicode/utf8"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 type FuncMap map[string]any
35
36
37
38
39
40 func builtins() FuncMap {
41 return FuncMap{
42 "and": and,
43 "call": call,
44 "html": HTMLEscaper,
45 "index": index,
46 "slice": slice,
47 "js": JSEscaper,
48 "len": length,
49 "not": not,
50 "or": or,
51 "print": fmt.Sprint,
52 "printf": fmt.Sprintf,
53 "println": fmt.Sprintln,
54 "urlquery": URLQueryEscaper,
55
56
57 "eq": eq,
58 "ge": ge,
59 "gt": gt,
60 "le": le,
61 "lt": lt,
62 "ne": ne,
63 }
64 }
65
66 var builtinFuncsOnce struct {
67 sync.Once
68 v map[string]reflect.Value
69 }
70
71
72
73 func builtinFuncs() map[string]reflect.Value {
74 builtinFuncsOnce.Do(func() {
75 builtinFuncsOnce.v = createValueFuncs(builtins())
76 })
77 return builtinFuncsOnce.v
78 }
79
80
81 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
82 m := make(map[string]reflect.Value)
83 addValueFuncs(m, funcMap)
84 return m
85 }
86
87
88 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
89 for name, fn := range in {
90 if !goodName(name) {
91 panic(fmt.Errorf("function name %q is not a valid identifier", name))
92 }
93 v := reflect.ValueOf(fn)
94 if v.Kind() != reflect.Func {
95 panic("value for " + name + " not a function")
96 }
97 if !goodFunc(v.Type()) {
98 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
99 }
100 out[name] = v
101 }
102 }
103
104
105
106 func addFuncs(out, in FuncMap) {
107 for name, fn := range in {
108 out[name] = fn
109 }
110 }
111
112
113 func goodFunc(typ reflect.Type) bool {
114
115 switch {
116 case typ.NumOut() == 1:
117 return true
118 case typ.NumOut() == 2 && typ.Out(1) == errorType:
119 return true
120 }
121 return false
122 }
123
124
// goodName reports whether name is a valid identifier: non-empty, made of
// underscores, letters and digits, and not starting with a digit.
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			continue
		}
		// Neither underscore nor letter: only a digit is allowed, and never first.
		if pos == 0 || !unicode.IsDigit(ch) {
			return false
		}
	}
	return true
}
140
141
142 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
143 if tmpl != nil && tmpl.common != nil {
144 tmpl.muFuncs.RLock()
145 defer tmpl.muFuncs.RUnlock()
146 if fn := tmpl.execFuncs[name]; fn.IsValid() {
147 return fn, false, true
148 }
149 }
150 if fn := builtinFuncs()[name]; fn.IsValid() {
151 return fn, true, true
152 }
153 return reflect.Value{}, false, false
154 }
155
156
157
158 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
159 if !value.IsValid() {
160 if !canBeNil(argType) {
161 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
162 }
163 value = reflect.Zero(argType)
164 }
165 if value.Type().AssignableTo(argType) {
166 return value, nil
167 }
168 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
169 value = value.Convert(argType)
170 return value, nil
171 }
172 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
173 }
174
// intLike reports whether typ is a signed or unsigned integer kind,
// including uintptr. package reflect declares these kinds contiguously,
// from Int through Uintptr, so a single range check covers exactly that set.
func intLike(typ reflect.Kind) bool {
	return reflect.Int <= typ && typ <= reflect.Uintptr
}
184
185
// indexArg converts index to an int usable for indexing or slicing a value
// whose bound is cap. The accepted range is 0 <= x <= cap. The upper bound
// is inclusive because slice expressions may use cap itself; callers that
// index a single element must additionally reject x == len.
//
// It returns an error if index is nil (invalid), is not an integer kind, or
// is out of range.
func indexArg(index reflect.Value, cap int) (int, error) {
	var x int64
	switch index.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		x = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// Values above the int64 range wrap to negative and are rejected below.
		x = int64(index.Uint())
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	// Compare in int64 before converting: on 32-bit platforms int(x) would
	// silently truncate large values (1<<32 becomes 0) and pass the check.
	if x < 0 || x > int64(cap) {
		return 0, fmt.Errorf("index out of range: %d", x)
	}
	return int(x), nil
}
203
204
205
206
207
208
209 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
210 item = indirectInterface(item)
211 if !item.IsValid() {
212 return reflect.Value{}, fmt.Errorf("index of untyped nil")
213 }
214 for _, index := range indexes {
215 index = indirectInterface(index)
216 var isNil bool
217 if item, isNil = indirect(item); isNil {
218 return reflect.Value{}, fmt.Errorf("index of nil pointer")
219 }
220 switch item.Kind() {
221 case reflect.Array, reflect.Slice, reflect.String:
222 x, err := indexArg(index, item.Len())
223 if err != nil {
224 return reflect.Value{}, err
225 }
226 item = item.Index(x)
227 case reflect.Map:
228 index, err := prepareArg(index, item.Type().Key())
229 if err != nil {
230 return reflect.Value{}, err
231 }
232 if x := item.MapIndex(index); x.IsValid() {
233 item = x
234 } else {
235 item = reflect.Zero(item.Type().Elem())
236 }
237 case reflect.Invalid:
238
239 panic("unreachable")
240 default:
241 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
242 }
243 }
244 return item, nil
245 }
246
247
248
249
250
251
252
253 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
254 item = indirectInterface(item)
255 if !item.IsValid() {
256 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
257 }
258 if len(indexes) > 3 {
259 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
260 }
261 var cap int
262 switch item.Kind() {
263 case reflect.String:
264 if len(indexes) == 3 {
265 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
266 }
267 cap = item.Len()
268 case reflect.Array, reflect.Slice:
269 cap = item.Cap()
270 default:
271 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
272 }
273
274 idx := [3]int{0, item.Len()}
275 for i, index := range indexes {
276 x, err := indexArg(index, cap)
277 if err != nil {
278 return reflect.Value{}, err
279 }
280 idx[i] = x
281 }
282
283 if idx[0] > idx[1] {
284 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
285 }
286 if len(indexes) < 3 {
287 return item.Slice(idx[0], idx[1]), nil
288 }
289
290 if idx[1] > idx[2] {
291 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
292 }
293 return item.Slice3(idx[0], idx[1], idx[2]), nil
294 }
295
296
297
298
299 func length(item reflect.Value) (int, error) {
300 item, isNil := indirect(item)
301 if isNil {
302 return 0, fmt.Errorf("len of nil pointer")
303 }
304 switch item.Kind() {
305 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
306 return item.Len(), nil
307 }
308 return 0, fmt.Errorf("len of type %s", item.Type())
309 }
310
311
312
313
314
315 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
316 fn = indirectInterface(fn)
317 if !fn.IsValid() {
318 return reflect.Value{}, fmt.Errorf("call of nil")
319 }
320 typ := fn.Type()
321 if typ.Kind() != reflect.Func {
322 return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
323 }
324 if !goodFunc(typ) {
325 return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
326 }
327 numIn := typ.NumIn()
328 var dddType reflect.Type
329 if typ.IsVariadic() {
330 if len(args) < numIn-1 {
331 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
332 }
333 dddType = typ.In(numIn - 1).Elem()
334 } else {
335 if len(args) != numIn {
336 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
337 }
338 }
339 argv := make([]reflect.Value, len(args))
340 for i, arg := range args {
341 arg = indirectInterface(arg)
342
343 argType := dddType
344 if !typ.IsVariadic() || i < numIn-1 {
345 argType = typ.In(i)
346 }
347
348 var err error
349 if argv[i], err = prepareArg(arg, argType); err != nil {
350 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
351 }
352 }
353 return safeCall(fn, argv)
354 }
355
356
357
// safeCall runs fun with args and converts a panic into an error: a panic
// value that is an error is returned as-is, anything else is formatted with
// %v. If fun returns a second, non-nil result it is returned as the error.
// The first result is returned as val (zero if fun panicked).
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		p := recover()
		if p == nil {
			return
		}
		if perr, isErr := p.(error); isErr {
			err = perr
			return
		}
		err = fmt.Errorf("%v", p)
	}()
	results := fun.Call(args)
	val = results[0]
	if len(results) == 2 && !results[1].IsNil() {
		err = results[1].Interface().(error)
	}
	return val, err
}
374
375
376
377 func truth(arg reflect.Value) bool {
378 t, _ := isTrue(indirectInterface(arg))
379 return t
380 }
381
382
383
// and is registered as the "and" builtin so that the name resolves, but its
// body is never expected to run: it always panics with "unreachable".
// NOTE(review): presumably the evaluator special-cases "and" (to
// short-circuit); confirm in the exec code.
func and(_ reflect.Value, _ ...reflect.Value) reflect.Value {
	panic("unreachable")
}

// or is registered as the "or" builtin so that the name resolves, but its
// body is never expected to run: it always panics with "unreachable".
// NOTE(review): presumably the evaluator special-cases "or" (to
// short-circuit); confirm in the exec code.
func or(_ reflect.Value, _ ...reflect.Value) reflect.Value {
	panic("unreachable")
}
393
394
395 func not(arg reflect.Value) bool {
396 return !truth(arg)
397 }
398
399
400
401
402
// Errors returned by the comparison builtins.
var (
	errBadComparisonType = errors.New("invalid type for comparison")
	errBadComparison     = errors.New("incompatible types for comparison")
	errNoComparison      = errors.New("missing argument for comparison")
)

// kind classifies a reflect.Value into the categories the comparison
// builtins know how to compare. The declaration order fixes the iota values.
type kind int

const (
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)

// basicKind maps v's reflect.Kind to a comparison kind. Any kind outside
// bool, the integer, float and complex kinds, and string (including an
// invalid Value) yields invalidKind with errBadComparisonType.
func basicKind(v reflect.Value) (kind, error) {
	var k kind
	switch v.Kind() {
	case reflect.Bool:
		k = boolKind
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		k = intKind
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		k = uintKind
	case reflect.Float32, reflect.Float64:
		k = floatKind
	case reflect.Complex64, reflect.Complex128:
		k = complexKind
	case reflect.String:
		k = stringKind
	default:
		return invalidKind, errBadComparisonType
	}
	return k, nil
}
438
439
440 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
441 arg1 = indirectInterface(arg1)
442 if arg1 != zero {
443 if t1 := arg1.Type(); !t1.Comparable() {
444 return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
445 }
446 }
447 if len(arg2) == 0 {
448 return false, errNoComparison
449 }
450 k1, _ := basicKind(arg1)
451 for _, arg := range arg2 {
452 arg = indirectInterface(arg)
453 k2, _ := basicKind(arg)
454 truth := false
455 if k1 != k2 {
456
457 switch {
458 case k1 == intKind && k2 == uintKind:
459 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
460 case k1 == uintKind && k2 == intKind:
461 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
462 default:
463 if arg1 != zero && arg != zero {
464 return false, errBadComparison
465 }
466 }
467 } else {
468 switch k1 {
469 case boolKind:
470 truth = arg1.Bool() == arg.Bool()
471 case complexKind:
472 truth = arg1.Complex() == arg.Complex()
473 case floatKind:
474 truth = arg1.Float() == arg.Float()
475 case intKind:
476 truth = arg1.Int() == arg.Int()
477 case stringKind:
478 truth = arg1.String() == arg.String()
479 case uintKind:
480 truth = arg1.Uint() == arg.Uint()
481 default:
482 if arg == zero || arg1 == zero {
483 truth = arg1 == arg
484 } else {
485 if t2 := arg.Type(); !t2.Comparable() {
486 return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
487 }
488 truth = arg1.Interface() == arg.Interface()
489 }
490 }
491 }
492 if truth {
493 return true, nil
494 }
495 }
496 return false, nil
497 }
498
499
500 func ne(arg1, arg2 reflect.Value) (bool, error) {
501
502 equal, err := eq(arg1, arg2)
503 return !equal, err
504 }
505
506
507 func lt(arg1, arg2 reflect.Value) (bool, error) {
508 arg1 = indirectInterface(arg1)
509 k1, err := basicKind(arg1)
510 if err != nil {
511 return false, err
512 }
513 arg2 = indirectInterface(arg2)
514 k2, err := basicKind(arg2)
515 if err != nil {
516 return false, err
517 }
518 truth := false
519 if k1 != k2 {
520
521 switch {
522 case k1 == intKind && k2 == uintKind:
523 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
524 case k1 == uintKind && k2 == intKind:
525 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
526 default:
527 return false, errBadComparison
528 }
529 } else {
530 switch k1 {
531 case boolKind, complexKind:
532 return false, errBadComparisonType
533 case floatKind:
534 truth = arg1.Float() < arg2.Float()
535 case intKind:
536 truth = arg1.Int() < arg2.Int()
537 case stringKind:
538 truth = arg1.String() < arg2.String()
539 case uintKind:
540 truth = arg1.Uint() < arg2.Uint()
541 default:
542 panic("invalid kind")
543 }
544 }
545 return truth, nil
546 }
547
548
549 func le(arg1, arg2 reflect.Value) (bool, error) {
550
551 lessThan, err := lt(arg1, arg2)
552 if lessThan || err != nil {
553 return lessThan, err
554 }
555 return eq(arg1, arg2)
556 }
557
558
559 func gt(arg1, arg2 reflect.Value) (bool, error) {
560
561 lessOrEqual, err := le(arg1, arg2)
562 if err != nil {
563 return false, err
564 }
565 return !lessOrEqual, nil
566 }
567
568
569 func ge(arg1, arg2 reflect.Value) (bool, error) {
570
571 lessThan, err := lt(arg1, arg2)
572 if err != nil {
573 return false, err
574 }
575 return !lessThan, nil
576 }
577
578
579
// Replacement text written by HTMLEscape for each special byte. Quotes use
// numeric character references.
var (
	htmlQuot = []byte("&#34;")  // replaces '"'
	htmlApos = []byte("&#39;")  // replaces '\''
	htmlAmp  = []byte("&amp;")  // replaces '&'
	htmlLt   = []byte("&lt;")   // replaces '<'
	htmlGt   = []byte("&gt;")   // replaces '>'
	htmlNull = []byte("\uFFFD") // replaces NUL with U+FFFD REPLACEMENT CHARACTER
)

// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
//
// The bytes '"', '\'', '&', '<' and '>' are replaced by character references
// and NUL bytes by U+FFFD; every other byte is copied unchanged. Write
// errors from w are ignored.
func HTMLEscape(w io.Writer, b []byte) {
	last := 0 // start of the pending run of bytes that need no escaping
	for i, c := range b {
		var esc []byte
		switch c {
		case '\000':
			esc = htmlNull
		case '"':
			esc = htmlQuot
		case '\'':
			esc = htmlApos
		case '&':
			esc = htmlAmp
		case '<':
			esc = htmlLt
		case '>':
			esc = htmlGt
		default:
			continue
		}
		w.Write(b[last:i])
		w.Write(esc)
		last = i + 1
	}
	w.Write(b[last:])
}

// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
func HTMLEscapeString(s string) string {
	// Avoid allocating when nothing needs escaping.
	if !strings.ContainsAny(s, "'\"&<>\000") {
		return s
	}
	var b bytes.Buffer
	HTMLEscape(&b, []byte(s))
	return b.String()
}
627
628
629
630 func HTMLEscaper(args ...any) string {
631 return HTMLEscapeString(evalArgs(args))
632 }
633
634
635
// Escape sequences written by JSEscape. '<', '>', '&' and '=' are written
// as \u escapes rather than backslash escapes.
var (
	jsLowUni = []byte(`\u00`)
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\u003C`)
	jsGt        = []byte(`\u003E`)
	jsAmp       = []byte(`\u0026`)
	jsEq        = []byte(`\u003D`)
)

// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
//
// Backslashes and quotes get backslash escapes; '<', '>', '&' and '=' become
// \u003C-style escapes; ASCII control characters become \u00XX. Non-ASCII
// runes are copied unchanged when unicode.IsPrint accepts them; otherwise
// they become \uXXXX escapes, written as a UTF-16 surrogate pair for runes
// above U+FFFF. Write errors from w are ignored.
func JSEscape(w io.Writer, b []byte) {
	last := 0 // start of the pending run of bytes copied verbatim
	for i := 0; i < len(b); i++ {
		c := b[i]
		if !jsIsSpecial(rune(c)) {
			// Fast path: the byte is emitted later as part of b[last:i].
			continue
		}
		w.Write(b[last:i])

		if c < utf8.RuneSelf {
			switch c {
			case '\\':
				w.Write(jsBackslash)
			case '\'':
				w.Write(jsApos)
			case '"':
				w.Write(jsQuot)
			case '<':
				w.Write(jsLt)
			case '>':
				w.Write(jsGt)
			case '&':
				w.Write(jsAmp)
			case '=':
				w.Write(jsEq)
			default:
				// The remaining ASCII specials are control characters (< ' ').
				w.Write(jsLowUni)
				hi, lo := c>>4, c&0x0f
				w.Write(hex[hi : hi+1])
				w.Write(hex[lo : lo+1])
			}
		} else {
			// Start of a multi-byte sequence. NOTE: an invalid UTF-8 byte
			// decodes to U+FFFD, which is printable, so it is copied as-is.
			r, size := utf8.DecodeRune(b[i:])
			switch {
			case unicode.IsPrint(r):
				w.Write(b[i : i+size])
			case r > 0xFFFF:
				// A JavaScript \u escape takes exactly four hex digits, so
				// runes outside the BMP must be written as a surrogate pair.
				r -= 0x10000
				fmt.Fprintf(w, "\\u%04X\\u%04X", 0xD800+(r>>10), 0xDC00+(r&0x3FF))
			default:
				fmt.Fprintf(w, "\\u%04X", r)
			}
			i += size - 1
		}
		last = i + 1
	}
	w.Write(b[last:])
}

// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
func JSEscapeString(s string) string {
	// Avoid allocating when nothing needs escaping.
	if strings.IndexFunc(s, jsIsSpecial) < 0 {
		return s
	}
	var b bytes.Buffer
	JSEscape(&b, []byte(s))
	return b.String()
}

// jsIsSpecial reports whether r needs attention from JSEscape: one of the
// escaped punctuation characters, an ASCII control character, or anything
// at or above utf8.RuneSelf (non-ASCII).
func jsIsSpecial(r rune) bool {
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	}
	return r < ' ' || utf8.RuneSelf <= r
}
718
719
720
721 func JSEscaper(args ...any) string {
722 return JSEscapeString(evalArgs(args))
723 }
724
725
726
727 func URLQueryEscaper(args ...any) string {
728 return url.QueryEscape(evalArgs(args))
729 }
730
731
732
733
734
735
736 func evalArgs(args []any) string {
737 ok := false
738 var s string
739
740 if len(args) == 1 {
741 s, ok = args[0].(string)
742 }
743 if !ok {
744 for i, arg := range args {
745 a, ok := printableValue(reflect.ValueOf(arg))
746 if ok {
747 args[i] = a
748 }
749 }
750 s = fmt.Sprint(args...)
751 }
752 return s
753 }
754
View as plain text