384 lines
8.1 KiB
Go

package gen
import (
"fmt"
"io"
)
// Code fragments and fmt format strings used when emitting generated code.
const (
	// errcheck is appended after a call whose error must short-circuit.
	errcheck = "\nif err != nil { return }"

	// Format strings (each expects the listed fmt verb's argument).
	lenAsUint32 = "uint32(len(%s))"
	literalFmt  = "%s"
	intFmt      = "%d"
	quotedFmt   = "\"%s\""
)

// Identifier fragments and type names; presumably spliced into msgp
// method names / declarations at the call sites (not visible here).
const (
	mapHeader   = "MapHeader"
	arrayHeader = "ArrayHeader"
	mapKey      = "MapKeyPtr"
	stringTyp   = "String"
	u32         = "uint32"
)
// Method is a bitfield representing something that the generator knows
// how to print. Each flag occupies its own bit, so several kinds can be
// requested at once by OR-ing flags together.
type Method uint8

// isset reports whether every bit of f is also set in m.
// An empty f is trivially contained in any m.
func (m Method) isset(f Method) bool {
	missing := f &^ m // bits requested by f that m lacks
	return missing == 0
}
// String implements fmt.Stringer.
//
// A single flag renders as its lowercase name ("decode", "encode",
// "marshal", "unmarshal", "size", "test"). A combination renders as the
// names of its known flags joined by "+" in declaration order, e.g.
// "decode+encode+test". Zero, invalidmeth, and any value with none of the
// known flags set render as "<invalid method>".
func (m Method) String() string {
	switch m {
	case 0, invalidmeth:
		return "<invalid method>"
	case Decode:
		return "decode"
	case Encode:
		return "encode"
	case Marshal:
		return "marshal"
	case Unmarshal:
		return "unmarshal"
	case Size:
		return "size"
	case Test:
		return "test"
	default:
		// return e.g. "decode+encode+test"
		modes := [...]Method{Decode, Encode, Marshal, Unmarshal, Size, Test}
		// "name" replaces the old boolean "any", which shadowed the
		// predeclared identifier; a non-empty name already means
		// "at least one flag was written".
		var name string
		for _, mm := range modes {
			if !m.isset(mm) {
				continue
			}
			if name != "" {
				name += "+"
			}
			name += mm.String()
		}
		if name == "" {
			// Only unknown bits (possibly alongside invalidmeth) are set;
			// never return an empty string for such a value.
			return "<invalid method>"
		}
		return name
	}
}
// strtoMeth maps the lowercase name of a single Method flag (the same
// spelling Method.String produces for that flag) back to the flag.
// Unrecognized names yield 0.
func strtoMeth(s string) Method {
	byName := map[string]Method{
		"encode":    Encode,
		"decode":    Decode,
		"marshal":   Marshal,
		"unmarshal": Unmarshal,
		"size":      Size,
		"test":      Test,
	}
	// A missing key yields the zero Method, i.e. 0.
	return byName[s]
}
// Method flags. Each kind of generated code has its own bit; the
// *test combinations select the test generators in NewPrinter.
const (
	Decode      Method = 1 << 0 // msgp.Decodable
	Encode      Method = 1 << 1 // msgp.Encodable
	Marshal     Method = 1 << 2 // msgp.Marshaler
	Unmarshal   Method = 1 << 3 // msgp.Unmarshaler
	Size        Method = 1 << 4 // msgp.Sizer
	Test        Method = 1 << 5 // generate tests
	invalidmeth Method = 1 << 6 // this isn't a method

	encodetest  = Encode | Decode | Test     // tests for Encodable and Decodable
	marshaltest = Marshal | Unmarshal | Test // tests for Marshaler and Unmarshaler
)
// Printer holds the generators selected by a Method bitfield and feeds
// each Elem passed to Print through all of them, in order.
type Printer struct {
	gens []generator
}

// NewPrinter returns a Printer whose generators write generated methods
// to out and generated tests to tests.
//
// Generators are added in a fixed order: decode, encode, marshal,
// unmarshal, size, then the marshal tests (when every bit of marshaltest
// is set) and the encode tests (when every bit of encodetest is set).
//
// It panics if m requests Test while tests is nil, or if m selects no
// generator at all.
func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
	if tests == nil && m.isset(Test) {
		panic("cannot print tests with 'nil' tests argument!")
	}

	p := &Printer{gens: make([]generator, 0, 7)}

	// add constructs and appends a generator only when all of flags are
	// set in m; construction is deferred so unused generators are never
	// built.
	add := func(flags Method, build func() generator) {
		if m.isset(flags) {
			p.gens = append(p.gens, build())
		}
	}

	add(Decode, func() generator { return decode(out) })
	add(Encode, func() generator { return encode(out) })
	add(Marshal, func() generator { return marshal(out) })
	add(Unmarshal, func() generator { return unmarshal(out) })
	add(Size, func() generator { return sizes(out) })
	add(marshaltest, func() generator { return mtest(tests) })
	add(encodetest, func() generator { return etest(tests) })

	if len(p.gens) == 0 {
		panic("NewPrinter called with invalid method flags")
	}
	return p
}
// TransformPass is a pass that transforms individual elements. Returning
// nil stops any remaining passes for that element (see passes.applyall).
// (Note that if the returned value is different from the argument, it
// should not point to the same objects.)
type TransformPass func(Elem) Elem

// IgnoreTypename returns a TransformPass that drops (returns nil for)
// every element whose TypeName() equals name, and returns every other
// element unchanged.
func IgnoreTypename(name string) TransformPass {
	return func(e Elem) Elem {
		if e.TypeName() != name {
			return e
		}
		return nil
	}
}
// ApplyDirective registers the transform t with every generator whose
// Method contains all the bits of pass, i.e. the named pass and any
// generator that depends on it (such as a test generator whose flags
// include pass).
func (p *Printer) ApplyDirective(pass Method, t TransformPass) {
	for _, gen := range p.gens {
		if !gen.Method().isset(pass) {
			continue
		}
		gen.Add(t)
	}
}
// Print runs e through every generator in order and returns the first
// error a generator reports; generators after a failing one are not run.
func (p *Printer) Print(e Elem) error {
	for i := range p.gens {
		// Elem.SetVarname() is called before the Print() step in
		// parse.FileSet.PrintTo(), and it generates identifiers as it
		// walks the Elem. Identifiers created there and identifiers
		// created while printing could collide, so each phase uses its
		// own prefix: "zb" while a generator runs, "za" otherwise.
		resetIdent("zb")
		execErr := p.gens[i].Execute(e)
		resetIdent("za")

		if execErr != nil {
			return execErr
		}
	}
	return nil
}
// generator is the interface through
// which code is generated.
type generator interface {
Method() Method
Add(p TransformPass)
Execute(Elem) error // execute writes the method for the provided object.
}
// passes is an ordered list of TransformPasses.
type passes []TransformPass

// Add appends t to the end of the list.
func (p *passes) Add(t TransformPass) {
	list := *p
	*p = append(list, t)
}

// applyall feeds e through every pass in order, each pass receiving the
// previous pass's result. As soon as a pass returns nil, the remaining
// passes are skipped and nil is returned; otherwise the last result is
// returned (e itself when the list is empty).
func (p *passes) applyall(e Elem) Elem {
	for _, pass := range *p {
		if e = pass(e); e == nil {
			return nil
		}
	}
	return e
}
type traversal interface {
gMap(*Map)
gSlice(*Slice)
gArray(*Array)
gPtr(*Ptr)
gBase(*BaseElem)
gStruct(*Struct)
}
// next dispatches e to the traversal method for its concrete type.
// Any other dynamic type, including a nil interface, panics with
// "bad element type".
func next(t traversal, e Elem) {
	switch el := e.(type) {
	case *Struct:
		t.gStruct(el)
	case *Map:
		t.gMap(el)
	case *Slice:
		t.gSlice(el)
	case *Array:
		t.gArray(el)
	case *Ptr:
		t.gPtr(el)
	case *BaseElem:
		t.gBase(el)
	default:
		panic("bad element type")
	}
}
// imutMethodReceiver returns the receiver type for a possibly-immutable
// method on p (judging by the name, methods that only read the receiver —
// confirm at call sites):
//
//   - *Struct: the value type when it has at most 3 fields and every
//     field is a *BaseElem whose Value is neither IDENT nor Bytes;
//     otherwise the pointer type. TODO(HACK): the 3-field cutoff stands
//     in for real size math.
//   - *Array: always the pointer type.
//   - anything else: the value type.
func imutMethodReceiver(p Elem) string {
	switch e := p.(type) {
	case *Struct:
		byValue := len(e.Fields) <= 3
		for i := 0; byValue && i < len(e.Fields); i++ {
			be, isBase := e.Fields[i].FieldElem.(*BaseElem)
			byValue = isBase && be.Value != IDENT && be.Value != Bytes
		}
		if byValue {
			return p.TypeName()
		}
		return "*" + p.TypeName()
	case *Array:
		return "*" + p.TypeName()
	default:
		return p.TypeName()
	}
}
// methodReceiver returns the pointer receiver type "*T" for p and, if
// necessary, rewrites p's varname so generated code uses the right
// expression for the receiver.
//
// Structs and arrays are accessed through a pointer automatically
// (field selection and indexing dereference implicitly), so their
// varname is left alone. Every other Elem has its varname wrapped as
// "(*name)".
func methodReceiver(p Elem) string {
	switch p.(type) {
	case *Struct, *Array:
		// nothing to rewrite
	default:
		p.SetVarname("(*" + p.Varname() + ")")
	}
	return "*" + p.TypeName()
}
// unsetReceiver sets the varname of every Elem other than *Struct and
// *Array to "z". These are the same types whose varname methodReceiver
// rewrites; "z" is presumably the default receiver name, so this undoes
// that rewrite. Confirm against the callers.
func unsetReceiver(p Elem) {
	switch p.(type) {
	case *Struct, *Array:
		return
	}
	p.SetVarname("z")
}
// shared utility for generators
type printer struct {
w io.Writer
err error
}
// writes "var {{name}} {{typ}};"
func (p *printer) declare(name string, typ string) {
p.printf("\nvar %s %s", name, typ)
}
// resizeMap emits code that prepares the map variable m to receive size
// entries:
//
//	if m == nil && size > 0 {
//		m = make(type, size)
//	} else if len(m) > 0 {
//		// delete every existing key (emitted by clearMap)
//	}
//
// Nothing is emitted once the printer has recorded a write error.
func (p *printer) resizeMap(size string, m *Map) {
	vn := m.Varname()
	if !p.ok() {
		return
	}
	// Allocate only when there is no map yet and something will be
	// stored; otherwise reuse the existing map after emptying it.
	p.printf("\nif %s == nil && %s > 0 {", vn, size)
	p.printf("\n%s = make(%s, %s)", vn, m.TypeName(), size)
	p.printf("\n} else if len(%s) > 0 {", vn)
	p.clearMap(vn)
	p.closeblock()
}
// mapAssign emits "\n{{map}}[{{key}}] = {{value}}", using m's varname and
// its Keyidx/Validx identifiers. Nothing is emitted after a write error.
func (p *printer) mapAssign(m *Map) {
	if p.ok() {
		p.printf("\n%s[%s] = %s", m.Varname(), m.Keyidx, m.Validx)
	}
}
// clearMap emits a loop that deletes every key from the map named name:
//
//	for key := range name { delete(name, key) }
//
// The redundant blank value variable ("for key, _ := range") is omitted,
// because gofmt -s rewrites it and staticcheck reports it (S1005) in the
// generated code. The built-in clear() is not used so the generated code
// still works with Go releases older than 1.21.
func (p *printer) clearMap(name string) {
	p.printf("\nfor key := range %[1]s { delete(%[1]s, key) }", name)
}
// resizeSlice emits code that makes slice s hold exactly size elements.
// It reslices the existing backing array when its capacity suffices and
// allocates a new slice otherwise:
//
//	if cap(s) >= int(size) { s = (s)[:size] } else { s = make(T, size) }
func (p *printer) resizeSlice(size string, s *Slice) {
	vn := s.Varname()
	typ := s.TypeName()
	p.printf("\nif cap(%[1]s) >= int(%[2]s) { %[1]s = (%[1]s)[:%[2]s] } else { %[1]s = make(%[3]s, %[2]s) }", vn, size, typ)
}
// arrayCheck emits a guard that sets err to a msgp.ArrayError and returns
// when the actual length got differs from the expected length want:
//
//	if got != want { err = msgp.ArrayError{Wanted: want, Got: got}; return }
func (p *printer) arrayCheck(want string, got string) {
	const guard = "\nif %[1]s != %[2]s { err = msgp.ArrayError{Wanted: %[2]s, Got: %[1]s}; return }"
	p.printf(guard, got, want)
}
// closeblock emits a newline followed by a closing brace.
func (p *printer) closeblock() {
	const closing = "\n}"
	p.print(closing)
}
// rangeBlock emits a range loop whose body is generated by dispatching
// inner to t:
//
//	for idx := range iter {
//		{{generate inner}}
//	}
func (p *printer) rangeBlock(idx string, iter string, t traversal, inner Elem) {
	header := "\n for " + idx + " := range " + iter + " {"
	p.print(header)
	next(t, inner)
	p.closeblock()
}
// nakedReturn emits a bare return and closes the enclosing function body,
// unless a write error has already been recorded.
func (p *printer) nakedReturn() {
	if !p.ok() {
		return
	}
	p.print("\nreturn\n}\n")
}
// comment emits s as a line comment on a new line.
func (p *printer) comment(s string) {
	p.printf("\n// %s", s)
}
// printf formats args according to format and writes the result to p.w.
// It does nothing after the first write error; otherwise it records the
// error returned by the write.
func (p *printer) printf(format string, args ...interface{}) {
	if p.err != nil {
		return
	}
	_, p.err = fmt.Fprintf(p.w, format, args...)
}
// print writes s to p.w verbatim (no formatting is applied). It does
// nothing after the first write error; otherwise it records the error
// returned by the write.
func (p *printer) print(s string) {
	if p.err != nil {
		return
	}
	_, p.err = io.WriteString(p.w, s)
}
// initPtr emits a nil check that allocates the pointee of pt when
// pt.Needsinit() reports that initialization is required:
//
//	if name == nil { name = new(T); }
func (p *printer) initPtr(pt *Ptr) {
	if !pt.Needsinit() {
		return
	}
	vn := pt.Varname()
	p.printf("\nif %[1]s == nil { %[1]s = new(%[2]s); }", vn, pt.Value.TypeName())
}
// ok reports whether no write error has been recorded so far.
func (p *printer) ok() bool {
	failed := p.err != nil
	return !failed
}
// tobaseConvert returns a conversion expression of the form
// "{{b.ToBase()}}({{b.Varname()}})".
func tobaseConvert(b *BaseElem) string {
	return fmt.Sprintf("%s(%s)", b.ToBase(), b.Varname())
}