598 lines
14 KiB
Go
598 lines
14 KiB
Go
// Copyright 2011 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package template
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/url"
|
|
"reflect"
|
|
"strings"
|
|
"unicode"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
// FuncMap maps template function names to the Go functions that implement
// them. Every function must return either exactly one value, or two values of
// which the second has type error. If that error result is non-nil when the
// function is invoked during execution, execution stops and Execute returns
// that error.
type FuncMap map[string]interface{}
|
|
|
|
var builtins = FuncMap{
|
|
"and": and,
|
|
"call": call,
|
|
"html": HTMLEscaper,
|
|
"index": index,
|
|
"js": JSEscaper,
|
|
"len": length,
|
|
"not": not,
|
|
"or": or,
|
|
"print": fmt.Sprint,
|
|
"printf": fmt.Sprintf,
|
|
"println": fmt.Sprintln,
|
|
"urlquery": URLQueryEscaper,
|
|
|
|
// Comparisons
|
|
"eq": eq, // ==
|
|
"ge": ge, // >=
|
|
"gt": gt, // >
|
|
"le": le, // <=
|
|
"lt": lt, // <
|
|
"ne": ne, // !=
|
|
}
|
|
|
|
var builtinFuncs = createValueFuncs(builtins)
|
|
|
|
// createValueFuncs turns a FuncMap into a map[string]reflect.Value
|
|
func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
|
|
m := make(map[string]reflect.Value)
|
|
addValueFuncs(m, funcMap)
|
|
return m
|
|
}
|
|
|
|
// addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
|
|
func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
|
|
for name, fn := range in {
|
|
v := reflect.ValueOf(fn)
|
|
if v.Kind() != reflect.Func {
|
|
panic("value for " + name + " not a function")
|
|
}
|
|
if !goodFunc(v.Type()) {
|
|
panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
|
|
}
|
|
out[name] = v
|
|
}
|
|
}
|
|
|
|
// addFuncs adds to values the functions in funcs. It does no checking of the input -
|
|
// call addValueFuncs first.
|
|
func addFuncs(out, in FuncMap) {
|
|
for name, fn := range in {
|
|
out[name] = fn
|
|
}
|
|
}
|
|
|
|
// goodFunc checks that the function or method has the right result signature.
|
|
func goodFunc(typ reflect.Type) bool {
|
|
// We allow functions with 1 result or 2 results where the second is an error.
|
|
switch {
|
|
case typ.NumOut() == 1:
|
|
return true
|
|
case typ.NumOut() == 2 && typ.Out(1) == errorType:
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// findFunction looks for a function in the template, and global map.
|
|
func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
|
|
if tmpl != nil && tmpl.common != nil {
|
|
if fn := tmpl.execFuncs[name]; fn.IsValid() {
|
|
return fn, true
|
|
}
|
|
}
|
|
if fn := builtinFuncs[name]; fn.IsValid() {
|
|
return fn, true
|
|
}
|
|
return reflect.Value{}, false
|
|
}
|
|
|
|
// Indexing.
|
|
|
|
// index returns the result of indexing its first argument by the following
|
|
// arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
|
|
// indexed item must be a map, slice, or array.
|
|
func index(item interface{}, indices ...interface{}) (interface{}, error) {
|
|
v := reflect.ValueOf(item)
|
|
for _, i := range indices {
|
|
index := reflect.ValueOf(i)
|
|
var isNil bool
|
|
if v, isNil = indirect(v); isNil {
|
|
return nil, fmt.Errorf("index of nil pointer")
|
|
}
|
|
switch v.Kind() {
|
|
case reflect.Array, reflect.Slice, reflect.String:
|
|
var x int64
|
|
switch index.Kind() {
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
x = index.Int()
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
x = int64(index.Uint())
|
|
default:
|
|
return nil, fmt.Errorf("cannot index slice/array with type %s", index.Type())
|
|
}
|
|
if x < 0 || x >= int64(v.Len()) {
|
|
return nil, fmt.Errorf("index out of range: %d", x)
|
|
}
|
|
v = v.Index(int(x))
|
|
case reflect.Map:
|
|
if !index.IsValid() {
|
|
index = reflect.Zero(v.Type().Key())
|
|
}
|
|
if !index.Type().AssignableTo(v.Type().Key()) {
|
|
return nil, fmt.Errorf("%s is not index type for %s", index.Type(), v.Type())
|
|
}
|
|
if x := v.MapIndex(index); x.IsValid() {
|
|
v = x
|
|
} else {
|
|
v = reflect.Zero(v.Type().Elem())
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("can't index item of type %s", v.Type())
|
|
}
|
|
}
|
|
return v.Interface(), nil
|
|
}
|
|
|
|
// Length
|
|
|
|
// length returns the length of the item, with an error if it has no defined length.
|
|
func length(item interface{}) (int, error) {
|
|
v, isNil := indirect(reflect.ValueOf(item))
|
|
if isNil {
|
|
return 0, fmt.Errorf("len of nil pointer")
|
|
}
|
|
switch v.Kind() {
|
|
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
|
|
return v.Len(), nil
|
|
}
|
|
return 0, fmt.Errorf("len of type %s", v.Type())
|
|
}
|
|
|
|
// Function invocation
|
|
|
|
// call returns the result of evaluating the first argument as a function.
|
|
// The function must return 1 result, or 2 results, the second of which is an error.
|
|
func call(fn interface{}, args ...interface{}) (interface{}, error) {
|
|
v := reflect.ValueOf(fn)
|
|
typ := v.Type()
|
|
if typ.Kind() != reflect.Func {
|
|
return nil, fmt.Errorf("non-function of type %s", typ)
|
|
}
|
|
if !goodFunc(typ) {
|
|
return nil, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
|
|
}
|
|
numIn := typ.NumIn()
|
|
var dddType reflect.Type
|
|
if typ.IsVariadic() {
|
|
if len(args) < numIn-1 {
|
|
return nil, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
|
|
}
|
|
dddType = typ.In(numIn - 1).Elem()
|
|
} else {
|
|
if len(args) != numIn {
|
|
return nil, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
|
|
}
|
|
}
|
|
argv := make([]reflect.Value, len(args))
|
|
for i, arg := range args {
|
|
value := reflect.ValueOf(arg)
|
|
// Compute the expected type. Clumsy because of variadics.
|
|
var argType reflect.Type
|
|
if !typ.IsVariadic() || i < numIn-1 {
|
|
argType = typ.In(i)
|
|
} else {
|
|
argType = dddType
|
|
}
|
|
if !value.IsValid() && canBeNil(argType) {
|
|
value = reflect.Zero(argType)
|
|
}
|
|
if !value.Type().AssignableTo(argType) {
|
|
return nil, fmt.Errorf("arg %d has type %s; should be %s", i, value.Type(), argType)
|
|
}
|
|
argv[i] = value
|
|
}
|
|
result := v.Call(argv)
|
|
if len(result) == 2 && !result[1].IsNil() {
|
|
return result[0].Interface(), result[1].Interface().(error)
|
|
}
|
|
return result[0].Interface(), nil
|
|
}
|
|
|
|
// Boolean logic.
|
|
|
|
func truth(a interface{}) bool {
|
|
t, _ := isTrue(reflect.ValueOf(a))
|
|
return t
|
|
}
|
|
|
|
// and computes the Boolean AND of its arguments, returning
|
|
// the first false argument it encounters, or the last argument.
|
|
func and(arg0 interface{}, args ...interface{}) interface{} {
|
|
if !truth(arg0) {
|
|
return arg0
|
|
}
|
|
for i := range args {
|
|
arg0 = args[i]
|
|
if !truth(arg0) {
|
|
break
|
|
}
|
|
}
|
|
return arg0
|
|
}
|
|
|
|
// or computes the Boolean OR of its arguments, returning
|
|
// the first true argument it encounters, or the last argument.
|
|
func or(arg0 interface{}, args ...interface{}) interface{} {
|
|
if truth(arg0) {
|
|
return arg0
|
|
}
|
|
for i := range args {
|
|
arg0 = args[i]
|
|
if truth(arg0) {
|
|
break
|
|
}
|
|
}
|
|
return arg0
|
|
}
|
|
|
|
// not returns the Boolean negation of its argument.
|
|
func not(arg interface{}) (truth bool) {
|
|
truth, _ = isTrue(reflect.ValueOf(arg))
|
|
return !truth
|
|
}
|
|
|
|
// Comparison.
|
|
|
|
// Errors reported by the comparison functions.
//
// errBadComparisonType: an argument is not of a basic kind, or a bool or
// complex value is used with an ordering operator.
//
// errBadComparison: the two operands are of different basic kinds. Mixing
// signed and unsigned integers is allowed and does not produce this error.
//
// errNoComparison: eq is called with nothing to compare against.
var errBadComparisonType = errors.New("invalid type for comparison")
var errBadComparison = errors.New("incompatible types for comparison")
var errNoComparison = errors.New("missing argument for comparison")
|
|
|
|
// kind is the coarse classification of a value used by the comparison
// functions; basicKind maps a reflect.Value onto it.
type kind int

const (
	invalidKind kind = iota // not comparable; paired with errBadComparisonType
	boolKind                // reflect.Bool
	complexKind             // reflect.Complex64, reflect.Complex128
	intKind                 // signed integer kinds
	floatKind               // reflect.Float32, reflect.Float64
	integerKind             // not returned by basicKind
	stringKind              // reflect.String
	uintKind                // unsigned integer kinds, including reflect.Uintptr
)
|
|
|
|
func basicKind(v reflect.Value) (kind, error) {
|
|
switch v.Kind() {
|
|
case reflect.Bool:
|
|
return boolKind, nil
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
return intKind, nil
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
|
return uintKind, nil
|
|
case reflect.Float32, reflect.Float64:
|
|
return floatKind, nil
|
|
case reflect.Complex64, reflect.Complex128:
|
|
return complexKind, nil
|
|
case reflect.String:
|
|
return stringKind, nil
|
|
}
|
|
return invalidKind, errBadComparisonType
|
|
}
|
|
|
|
// eq evaluates the comparison a == b || a == c || ...
|
|
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
|
|
v1 := reflect.ValueOf(arg1)
|
|
k1, err := basicKind(v1)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if len(arg2) == 0 {
|
|
return false, errNoComparison
|
|
}
|
|
for _, arg := range arg2 {
|
|
v2 := reflect.ValueOf(arg)
|
|
k2, err := basicKind(v2)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
truth := false
|
|
if k1 != k2 {
|
|
// Special case: Can compare integer values regardless of type's sign.
|
|
switch {
|
|
case k1 == intKind && k2 == uintKind:
|
|
truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
|
|
case k1 == uintKind && k2 == intKind:
|
|
truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
|
|
default:
|
|
return false, errBadComparison
|
|
}
|
|
} else {
|
|
switch k1 {
|
|
case boolKind:
|
|
truth = v1.Bool() == v2.Bool()
|
|
case complexKind:
|
|
truth = v1.Complex() == v2.Complex()
|
|
case floatKind:
|
|
truth = v1.Float() == v2.Float()
|
|
case intKind:
|
|
truth = v1.Int() == v2.Int()
|
|
case stringKind:
|
|
truth = v1.String() == v2.String()
|
|
case uintKind:
|
|
truth = v1.Uint() == v2.Uint()
|
|
default:
|
|
panic("invalid kind")
|
|
}
|
|
}
|
|
if truth {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// ne evaluates the comparison a != b.
|
|
func ne(arg1, arg2 interface{}) (bool, error) {
|
|
// != is the inverse of ==.
|
|
equal, err := eq(arg1, arg2)
|
|
return !equal, err
|
|
}
|
|
|
|
// lt evaluates the comparison a < b.
|
|
func lt(arg1, arg2 interface{}) (bool, error) {
|
|
v1 := reflect.ValueOf(arg1)
|
|
k1, err := basicKind(v1)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
v2 := reflect.ValueOf(arg2)
|
|
k2, err := basicKind(v2)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
truth := false
|
|
if k1 != k2 {
|
|
// Special case: Can compare integer values regardless of type's sign.
|
|
switch {
|
|
case k1 == intKind && k2 == uintKind:
|
|
truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
|
|
case k1 == uintKind && k2 == intKind:
|
|
truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
|
|
default:
|
|
return false, errBadComparison
|
|
}
|
|
} else {
|
|
switch k1 {
|
|
case boolKind, complexKind:
|
|
return false, errBadComparisonType
|
|
case floatKind:
|
|
truth = v1.Float() < v2.Float()
|
|
case intKind:
|
|
truth = v1.Int() < v2.Int()
|
|
case stringKind:
|
|
truth = v1.String() < v2.String()
|
|
case uintKind:
|
|
truth = v1.Uint() < v2.Uint()
|
|
default:
|
|
panic("invalid kind")
|
|
}
|
|
}
|
|
return truth, nil
|
|
}
|
|
|
|
// le evaluates the comparison <= b.
|
|
func le(arg1, arg2 interface{}) (bool, error) {
|
|
// <= is < or ==.
|
|
lessThan, err := lt(arg1, arg2)
|
|
if lessThan || err != nil {
|
|
return lessThan, err
|
|
}
|
|
return eq(arg1, arg2)
|
|
}
|
|
|
|
// gt evaluates the comparison a > b.
|
|
func gt(arg1, arg2 interface{}) (bool, error) {
|
|
// > is the inverse of <=.
|
|
lessOrEqual, err := le(arg1, arg2)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return !lessOrEqual, nil
|
|
}
|
|
|
|
// ge evaluates the comparison a >= b.
|
|
func ge(arg1, arg2 interface{}) (bool, error) {
|
|
// >= is the inverse of <.
|
|
lessThan, err := lt(arg1, arg2)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return !lessThan, nil
|
|
}
|
|
|
|
// HTML escaping.
|
|
|
|
// Replacement text written by HTMLEscape for each special character.
// The quotes use numeric character references: they are shorter than the
// named entities, and &apos; was not defined in HTML before HTML5.
var (
	htmlQuot = []byte("&#34;") // shorter than "&quot;"
	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
)
|
|
|
|
// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
|
|
func HTMLEscape(w io.Writer, b []byte) {
|
|
last := 0
|
|
for i, c := range b {
|
|
var html []byte
|
|
switch c {
|
|
case '"':
|
|
html = htmlQuot
|
|
case '\'':
|
|
html = htmlApos
|
|
case '&':
|
|
html = htmlAmp
|
|
case '<':
|
|
html = htmlLt
|
|
case '>':
|
|
html = htmlGt
|
|
default:
|
|
continue
|
|
}
|
|
w.Write(b[last:i])
|
|
w.Write(html)
|
|
last = i + 1
|
|
}
|
|
w.Write(b[last:])
|
|
}
|
|
|
|
// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
|
|
func HTMLEscapeString(s string) string {
|
|
// Avoid allocation if we can.
|
|
if strings.IndexAny(s, `'"&<>`) < 0 {
|
|
return s
|
|
}
|
|
var b bytes.Buffer
|
|
HTMLEscape(&b, []byte(s))
|
|
return b.String()
|
|
}
|
|
|
|
// HTMLEscaper returns the escaped HTML equivalent of the textual
|
|
// representation of its arguments.
|
|
func HTMLEscaper(args ...interface{}) string {
|
|
return HTMLEscapeString(evalArgs(args))
|
|
}
|
|
|
|
// JavaScript escaping.
|
|
|
|
// Byte sequences written by JSEscape. jsLowUni prefixes the two hex digits
// (taken from hex) of an escaped ASCII control character. The rest replace
// backslash, the quotes, and the angle brackets.
var (
	jsLowUni = []byte("\\u00")
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte("\\\\")
	jsApos      = []byte("\\'")
	jsQuot      = []byte("\\\"")
	jsLt        = []byte("\\x3C")
	jsGt        = []byte("\\x3E")
)
|
|
|
|
// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
|
|
func JSEscape(w io.Writer, b []byte) {
|
|
last := 0
|
|
for i := 0; i < len(b); i++ {
|
|
c := b[i]
|
|
|
|
if !jsIsSpecial(rune(c)) {
|
|
// fast path: nothing to do
|
|
continue
|
|
}
|
|
w.Write(b[last:i])
|
|
|
|
if c < utf8.RuneSelf {
|
|
// Quotes, slashes and angle brackets get quoted.
|
|
// Control characters get written as \u00XX.
|
|
switch c {
|
|
case '\\':
|
|
w.Write(jsBackslash)
|
|
case '\'':
|
|
w.Write(jsApos)
|
|
case '"':
|
|
w.Write(jsQuot)
|
|
case '<':
|
|
w.Write(jsLt)
|
|
case '>':
|
|
w.Write(jsGt)
|
|
default:
|
|
w.Write(jsLowUni)
|
|
t, b := c>>4, c&0x0f
|
|
w.Write(hex[t : t+1])
|
|
w.Write(hex[b : b+1])
|
|
}
|
|
} else {
|
|
// Unicode rune.
|
|
r, size := utf8.DecodeRune(b[i:])
|
|
if unicode.IsPrint(r) {
|
|
w.Write(b[i : i+size])
|
|
} else {
|
|
fmt.Fprintf(w, "\\u%04X", r)
|
|
}
|
|
i += size - 1
|
|
}
|
|
last = i + 1
|
|
}
|
|
w.Write(b[last:])
|
|
}
|
|
|
|
// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
|
|
func JSEscapeString(s string) string {
|
|
// Avoid allocation if we can.
|
|
if strings.IndexFunc(s, jsIsSpecial) < 0 {
|
|
return s
|
|
}
|
|
var b bytes.Buffer
|
|
JSEscape(&b, []byte(s))
|
|
return b.String()
|
|
}
|
|
|
|
// jsIsSpecial reports whether r needs attention from JSEscape. That covers
// backslash, the single and double quote, '<' and '>', ASCII control
// characters below space, and every non-ASCII rune (r >= utf8.RuneSelf).
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	switch r {
	case '\\', '\'', '"', '<', '>':
		return true
	default:
		return false
	}
}
|
|
|
|
// JSEscaper returns the escaped JavaScript equivalent of the textual
|
|
// representation of its arguments.
|
|
func JSEscaper(args ...interface{}) string {
|
|
return JSEscapeString(evalArgs(args))
|
|
}
|
|
|
|
// URLQueryEscaper returns the escaped value of the textual representation of
|
|
// its arguments in a form suitable for embedding in a URL query.
|
|
func URLQueryEscaper(args ...interface{}) string {
|
|
return url.QueryEscape(evalArgs(args))
|
|
}
|
|
|
|
// evalArgs formats the list of arguments into a string. It is therefore equivalent to
|
|
// fmt.Sprint(args...)
|
|
// except that each argument is indirected (if a pointer), as required,
|
|
// using the same rules as the default string evaluation during template
|
|
// execution.
|
|
func evalArgs(args []interface{}) string {
|
|
ok := false
|
|
var s string
|
|
// Fast path for simple common case.
|
|
if len(args) == 1 {
|
|
s, ok = args[0].(string)
|
|
}
|
|
if !ok {
|
|
for i, arg := range args {
|
|
a, ok := printableValue(reflect.ValueOf(arg))
|
|
if ok {
|
|
args[i] = a
|
|
} // else left fmt do its thing
|
|
}
|
|
s = fmt.Sprint(args...)
|
|
}
|
|
return s
|
|
}
|