384 lines
8.1 KiB
Go
384 lines
8.1 KiB
Go
|
package gen
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"io"
|
||
|
)
|
||
|
|
||
|
// String fragments shared by the code generators.
const (
	// errcheck is a Go source snippet that returns when err is non-nil.
	errcheck = "\nif err != nil { return }"

	// lenAsUint32 is a format string that wraps its single argument
	// as uint32(len(arg)).
	lenAsUint32 = "uint32(len(%s))"

	// Format verbs: a plain string, a decimal integer and a
	// double-quoted string.
	literalFmt = "%s"
	intFmt     = "%d"
	quotedFmt  = `"%s"`

	// NOTE(review): these look like name fragments of msgp reader/writer
	// methods and types; confirm against the call sites.
	mapHeader   = "MapHeader"
	arrayHeader = "ArrayHeader"
	mapKey      = "MapKeyPtr"
	stringTyp   = "String"
	u32         = "uint32"
)
|
||
|
|
||
|
// Method is a bitfield representing something that the
// generator knows how to print. Several methods can be
// combined by OR-ing their flags together.
type Method uint8

// isset reports whether every bit that is set in f is
// also set in m.
func (m Method) isset(f Method) bool {
	// No bit of f may be missing from m.
	return f&^m == 0
}
|
||
|
|
||
|
// String implements fmt.Stringer
|
||
|
func (m Method) String() string {
|
||
|
switch m {
|
||
|
case 0, invalidmeth:
|
||
|
return "<invalid method>"
|
||
|
case Decode:
|
||
|
return "decode"
|
||
|
case Encode:
|
||
|
return "encode"
|
||
|
case Marshal:
|
||
|
return "marshal"
|
||
|
case Unmarshal:
|
||
|
return "unmarshal"
|
||
|
case Size:
|
||
|
return "size"
|
||
|
case Test:
|
||
|
return "test"
|
||
|
default:
|
||
|
// return e.g. "decode+encode+test"
|
||
|
modes := [...]Method{Decode, Encode, Marshal, Unmarshal, Size, Test}
|
||
|
any := false
|
||
|
nm := ""
|
||
|
for _, mm := range modes {
|
||
|
if m.isset(mm) {
|
||
|
if any {
|
||
|
nm += "+" + mm.String()
|
||
|
} else {
|
||
|
nm += mm.String()
|
||
|
any = true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nm
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func strtoMeth(s string) Method {
|
||
|
switch s {
|
||
|
case "encode":
|
||
|
return Encode
|
||
|
case "decode":
|
||
|
return Decode
|
||
|
case "marshal":
|
||
|
return Marshal
|
||
|
case "unmarshal":
|
||
|
return Unmarshal
|
||
|
case "size":
|
||
|
return Size
|
||
|
case "test":
|
||
|
return Test
|
||
|
default:
|
||
|
return 0
|
||
|
}
|
||
|
}
|
||
|
|
||
|
const (
|
||
|
Decode Method = 1 << iota // msgp.Decodable
|
||
|
Encode // msgp.Encodable
|
||
|
Marshal // msgp.Marshaler
|
||
|
Unmarshal // msgp.Unmarshaler
|
||
|
Size // msgp.Sizer
|
||
|
Test // generate tests
|
||
|
invalidmeth // this isn't a method
|
||
|
encodetest = Encode | Decode | Test // tests for Encodable and Decodable
|
||
|
marshaltest = Marshal | Unmarshal | Test // tests for Marshaler and Unmarshaler
|
||
|
)
|
||
|
|
||
|
type Printer struct {
|
||
|
gens []generator
|
||
|
}
|
||
|
|
||
|
func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
|
||
|
if m.isset(Test) && tests == nil {
|
||
|
panic("cannot print tests with 'nil' tests argument!")
|
||
|
}
|
||
|
gens := make([]generator, 0, 7)
|
||
|
if m.isset(Decode) {
|
||
|
gens = append(gens, decode(out))
|
||
|
}
|
||
|
if m.isset(Encode) {
|
||
|
gens = append(gens, encode(out))
|
||
|
}
|
||
|
if m.isset(Marshal) {
|
||
|
gens = append(gens, marshal(out))
|
||
|
}
|
||
|
if m.isset(Unmarshal) {
|
||
|
gens = append(gens, unmarshal(out))
|
||
|
}
|
||
|
if m.isset(Size) {
|
||
|
gens = append(gens, sizes(out))
|
||
|
}
|
||
|
if m.isset(marshaltest) {
|
||
|
gens = append(gens, mtest(tests))
|
||
|
}
|
||
|
if m.isset(encodetest) {
|
||
|
gens = append(gens, etest(tests))
|
||
|
}
|
||
|
if len(gens) == 0 {
|
||
|
panic("NewPrinter called with invalid method flags")
|
||
|
}
|
||
|
return &Printer{gens: gens}
|
||
|
}
|
||
|
|
||
|
// TransformPass is a pass that transforms individual
|
||
|
// elements. (Note that if the returned is different from
|
||
|
// the argument, it should not point to the same objects.)
|
||
|
type TransformPass func(Elem) Elem
|
||
|
|
||
|
// IgnoreTypename is a pass that just ignores
|
||
|
// types of a given name.
|
||
|
func IgnoreTypename(name string) TransformPass {
|
||
|
return func(e Elem) Elem {
|
||
|
if e.TypeName() == name {
|
||
|
return nil
|
||
|
}
|
||
|
return e
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ApplyDirective applies a directive to a named pass
|
||
|
// and all of its dependents.
|
||
|
func (p *Printer) ApplyDirective(pass Method, t TransformPass) {
|
||
|
for _, g := range p.gens {
|
||
|
if g.Method().isset(pass) {
|
||
|
g.Add(t)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Print prints an Elem.
|
||
|
func (p *Printer) Print(e Elem) error {
|
||
|
for _, g := range p.gens {
|
||
|
// Elem.SetVarname() is called before the Print() step in parse.FileSet.PrintTo().
|
||
|
// Elem.SetVarname() generates identifiers as it walks the Elem. This can cause
|
||
|
// collisions between idents created during SetVarname and idents created during Print,
|
||
|
// hence the separate prefixes.
|
||
|
resetIdent("zb")
|
||
|
err := g.Execute(e)
|
||
|
resetIdent("za")
|
||
|
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// generator is the interface through
|
||
|
// which code is generated.
|
||
|
type generator interface {
|
||
|
Method() Method
|
||
|
Add(p TransformPass)
|
||
|
Execute(Elem) error // execute writes the method for the provided object.
|
||
|
}
|
||
|
|
||
|
type passes []TransformPass
|
||
|
|
||
|
func (p *passes) Add(t TransformPass) {
|
||
|
*p = append(*p, t)
|
||
|
}
|
||
|
|
||
|
func (p *passes) applyall(e Elem) Elem {
|
||
|
for _, t := range *p {
|
||
|
e = t(e)
|
||
|
if e == nil {
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
type traversal interface {
|
||
|
gMap(*Map)
|
||
|
gSlice(*Slice)
|
||
|
gArray(*Array)
|
||
|
gPtr(*Ptr)
|
||
|
gBase(*BaseElem)
|
||
|
gStruct(*Struct)
|
||
|
}
|
||
|
|
||
|
// type-switch dispatch to the correct
|
||
|
// method given the type of 'e'
|
||
|
func next(t traversal, e Elem) {
|
||
|
switch e := e.(type) {
|
||
|
case *Map:
|
||
|
t.gMap(e)
|
||
|
case *Struct:
|
||
|
t.gStruct(e)
|
||
|
case *Slice:
|
||
|
t.gSlice(e)
|
||
|
case *Array:
|
||
|
t.gArray(e)
|
||
|
case *Ptr:
|
||
|
t.gPtr(e)
|
||
|
case *BaseElem:
|
||
|
t.gBase(e)
|
||
|
default:
|
||
|
panic("bad element type")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// possibly-immutable method receiver
|
||
|
func imutMethodReceiver(p Elem) string {
|
||
|
switch e := p.(type) {
|
||
|
case *Struct:
|
||
|
// TODO(HACK): actually do real math here.
|
||
|
if len(e.Fields) <= 3 {
|
||
|
for i := range e.Fields {
|
||
|
if be, ok := e.Fields[i].FieldElem.(*BaseElem); !ok || (be.Value == IDENT || be.Value == Bytes) {
|
||
|
goto nope
|
||
|
}
|
||
|
}
|
||
|
return p.TypeName()
|
||
|
}
|
||
|
nope:
|
||
|
return "*" + p.TypeName()
|
||
|
|
||
|
// gets dereferenced automatically
|
||
|
case *Array:
|
||
|
return "*" + p.TypeName()
|
||
|
|
||
|
// everything else can be
|
||
|
// by-value.
|
||
|
default:
|
||
|
return p.TypeName()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// if necessary, wraps a type
|
||
|
// so that its method receiver
|
||
|
// is of the write type.
|
||
|
func methodReceiver(p Elem) string {
|
||
|
switch p.(type) {
|
||
|
|
||
|
// structs and arrays are
|
||
|
// dereferenced automatically,
|
||
|
// so no need to alter varname
|
||
|
case *Struct, *Array:
|
||
|
return "*" + p.TypeName()
|
||
|
// set variable name to
|
||
|
// *varname
|
||
|
default:
|
||
|
p.SetVarname("(*" + p.Varname() + ")")
|
||
|
return "*" + p.TypeName()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func unsetReceiver(p Elem) {
|
||
|
switch p.(type) {
|
||
|
case *Struct, *Array:
|
||
|
default:
|
||
|
p.SetVarname("z")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// printer is the shared output helper for generators.
//
// Output goes to w. err remembers the first write error;
// once it is set, printf and print do nothing.
type printer struct {
	w   io.Writer
	err error
}
|
||
|
|
||
|
// writes "var {{name}} {{typ}};"
|
||
|
func (p *printer) declare(name string, typ string) {
|
||
|
p.printf("\nvar %s %s", name, typ)
|
||
|
}
|
||
|
|
||
|
// does:
|
||
|
//
|
||
|
// if m != nil && size > 0 {
|
||
|
// m = make(type, size)
|
||
|
// } else if len(m) > 0 {
|
||
|
// for key, _ := range m { delete(m, key) }
|
||
|
// }
|
||
|
//
|
||
|
func (p *printer) resizeMap(size string, m *Map) {
|
||
|
vn := m.Varname()
|
||
|
if !p.ok() {
|
||
|
return
|
||
|
}
|
||
|
p.printf("\nif %s == nil && %s > 0 {", vn, size)
|
||
|
p.printf("\n%s = make(%s, %s)", vn, m.TypeName(), size)
|
||
|
p.printf("\n} else if len(%s) > 0 {", vn)
|
||
|
p.clearMap(vn)
|
||
|
p.closeblock()
|
||
|
}
|
||
|
|
||
|
// assign key to value based on varnames
|
||
|
func (p *printer) mapAssign(m *Map) {
|
||
|
if !p.ok() {
|
||
|
return
|
||
|
}
|
||
|
p.printf("\n%s[%s] = %s", m.Varname(), m.Keyidx, m.Validx)
|
||
|
}
|
||
|
|
||
|
// clear map keys
|
||
|
func (p *printer) clearMap(name string) {
|
||
|
p.printf("\nfor key, _ := range %[1]s { delete(%[1]s, key) }", name)
|
||
|
}
|
||
|
|
||
|
func (p *printer) resizeSlice(size string, s *Slice) {
|
||
|
p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", s.Varname(), size, s.TypeName())
|
||
|
}
|
||
|
|
||
|
func (p *printer) arrayCheck(want string, got string) {
|
||
|
p.printf("\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }", got, want)
|
||
|
}
|
||
|
|
||
|
// closeblock writes a newline followed by a closing brace.
func (p *printer) closeblock() {
	p.print("\n}")
}
|
||
|
|
||
|
// does:
|
||
|
//
|
||
|
// for idx := range iter {
|
||
|
// {{generate inner}}
|
||
|
// }
|
||
|
//
|
||
|
func (p *printer) rangeBlock(idx string, iter string, t traversal, inner Elem) {
|
||
|
p.printf("\n for %s := range %s {", idx, iter)
|
||
|
next(t, inner)
|
||
|
p.closeblock()
|
||
|
}
|
||
|
|
||
|
func (p *printer) nakedReturn() {
|
||
|
if p.ok() {
|
||
|
p.print("\nreturn\n}\n")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (p *printer) comment(s string) {
|
||
|
p.print("\n// " + s)
|
||
|
}
|
||
|
|
||
|
func (p *printer) printf(format string, args ...interface{}) {
|
||
|
if p.err == nil {
|
||
|
_, p.err = fmt.Fprintf(p.w, format, args...)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (p *printer) print(format string) {
|
||
|
if p.err == nil {
|
||
|
_, p.err = io.WriteString(p.w, format)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (p *printer) initPtr(pt *Ptr) {
|
||
|
if pt.Needsinit() {
|
||
|
vname := pt.Varname()
|
||
|
p.printf("\nif %s == nil { %s = new(%s); }", vname, vname, pt.Value.TypeName())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ok reports whether no write error has been recorded so far.
func (p *printer) ok() bool {
	return p.err == nil
}
|
||
|
|
||
|
func tobaseConvert(b *BaseElem) string {
|
||
|
return b.ToBase() + "(" + b.Varname() + ")"
|
||
|
}
|