510 lines
10 KiB
Go

package msgpack
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"reflect"
"time"
"github.com/vmihailenco/msgpack/codes"
)
const bytesAllocLimit = 1024 * 1024 // 1mb
type bufReader interface {
io.Reader
io.ByteScanner
}
func newBufReader(r io.Reader) bufReader {
if br, ok := r.(bufReader); ok {
return br
}
return bufio.NewReader(r)
}
func makeBuffer() []byte {
return make([]byte, 0, 64)
}
// Unmarshal decodes the MessagePack-encoded data and stores the result
// in the value pointed to by v.
func Unmarshal(data []byte, v ...interface{}) error {
return NewDecoder(bytes.NewReader(data)).Decode(v...)
}
type Decoder struct {
r io.Reader
s io.ByteScanner
buf []byte
extLen int
rec []byte // accumulates read data if not nil
decodeMapFunc func(*Decoder) (interface{}, error)
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may read data from r
// beyond the MessagePack values requested. Buffering can be disabled
// by passing a reader that implements io.ByteScanner interface.
func NewDecoder(r io.Reader) *Decoder {
d := &Decoder{
decodeMapFunc: decodeMap,
buf: makeBuffer(),
}
d.resetReader(r)
return d
}
func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) {
d.decodeMapFunc = fn
}
func (d *Decoder) Reset(r io.Reader) error {
d.resetReader(r)
return nil
}
func (d *Decoder) resetReader(r io.Reader) {
reader := newBufReader(r)
d.r = reader
d.s = reader
}
func (d *Decoder) Decode(v ...interface{}) error {
for _, vv := range v {
if err := d.decode(vv); err != nil {
return err
}
}
return nil
}
func (d *Decoder) decode(dst interface{}) error {
var err error
switch v := dst.(type) {
case *string:
if v != nil {
*v, err = d.DecodeString()
return err
}
case *[]byte:
if v != nil {
return d.decodeBytesPtr(v)
}
case *int:
if v != nil {
*v, err = d.DecodeInt()
return err
}
case *int8:
if v != nil {
*v, err = d.DecodeInt8()
return err
}
case *int16:
if v != nil {
*v, err = d.DecodeInt16()
return err
}
case *int32:
if v != nil {
*v, err = d.DecodeInt32()
return err
}
case *int64:
if v != nil {
*v, err = d.DecodeInt64()
return err
}
case *uint:
if v != nil {
*v, err = d.DecodeUint()
return err
}
case *uint8:
if v != nil {
*v, err = d.DecodeUint8()
return err
}
case *uint16:
if v != nil {
*v, err = d.DecodeUint16()
return err
}
case *uint32:
if v != nil {
*v, err = d.DecodeUint32()
return err
}
case *uint64:
if v != nil {
*v, err = d.DecodeUint64()
return err
}
case *bool:
if v != nil {
*v, err = d.DecodeBool()
return err
}
case *float32:
if v != nil {
*v, err = d.DecodeFloat32()
return err
}
case *float64:
if v != nil {
*v, err = d.DecodeFloat64()
return err
}
case *[]string:
return d.decodeStringSlicePtr(v)
case *map[string]string:
return d.decodeMapStringStringPtr(v)
case *map[string]interface{}:
return d.decodeMapStringInterfacePtr(v)
case *time.Duration:
if v != nil {
vv, err := d.DecodeInt64()
*v = time.Duration(vv)
return err
}
case *time.Time:
if v != nil {
*v, err = d.DecodeTime()
return err
}
}
v := reflect.ValueOf(dst)
if !v.IsValid() {
return errors.New("msgpack: Decode(nil)")
}
if v.Kind() != reflect.Ptr {
return fmt.Errorf("msgpack: Decode(nonsettable %T)", dst)
}
v = v.Elem()
if !v.IsValid() {
return fmt.Errorf("msgpack: Decode(nonsettable %T)", dst)
}
return d.DecodeValue(v)
}
func (d *Decoder) DecodeValue(v reflect.Value) error {
decode := getDecoder(v.Type())
return decode(d, v)
}
func (d *Decoder) DecodeNil() error {
c, err := d.readCode()
if err != nil {
return err
}
if c != codes.Nil {
return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
}
return nil
}
func (d *Decoder) decodeNilValue(v reflect.Value) error {
err := d.DecodeNil()
if v.IsNil() {
return err
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
v.Set(reflect.Zero(v.Type()))
return err
}
func (d *Decoder) DecodeBool() (bool, error) {
c, err := d.readCode()
if err != nil {
return false, err
}
return d.bool(c)
}
func (d *Decoder) bool(c codes.Code) (bool, error) {
if c == codes.False {
return false, nil
}
if c == codes.True {
return true, nil
}
return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
}
// DecodeInterface decodes value into interface. Possible value types are:
// - nil,
// - bool,
// - int8, int16, int32, int64,
// - uint8, uint16, uint32, uint64,
// - float32 and float64,
// - string,
// - []byte,
// - slices of any of the above,
// - maps of any of the above.
func (d *Decoder) DecodeInterface() (interface{}, error) {
c, err := d.readCode()
if err != nil {
return nil, err
}
if codes.IsFixedNum(c) {
return int8(c), nil
}
if codes.IsFixedMap(c) {
d.s.UnreadByte()
return d.DecodeMap()
}
if codes.IsFixedArray(c) {
return d.decodeSlice(c)
}
if codes.IsFixedString(c) {
return d.string(c)
}
switch c {
case codes.Nil:
return nil, nil
case codes.False, codes.True:
return d.bool(c)
case codes.Float:
return d.float32(c)
case codes.Double:
return d.float64(c)
case codes.Uint8:
return d.uint8()
case codes.Uint16:
return d.uint16()
case codes.Uint32:
return d.uint32()
case codes.Uint64:
return d.uint64()
case codes.Int8:
return d.int8()
case codes.Int16:
return d.int16()
case codes.Int32:
return d.int32()
case codes.Int64:
return d.int64()
case codes.Bin8, codes.Bin16, codes.Bin32:
return d.bytes(c, nil)
case codes.Str8, codes.Str16, codes.Str32:
return d.string(c)
case codes.Array16, codes.Array32:
return d.decodeSlice(c)
case codes.Map16, codes.Map32:
d.s.UnreadByte()
return d.DecodeMap()
case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
codes.Ext8, codes.Ext16, codes.Ext32:
return d.extInterface(c)
}
return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
}
// DecodeInterfaceLoose is like DecodeInterface except that:
// - int8, int16, and int32 are converted to int64,
// - uint8, uint16, and uint32 are converted to uint64,
// - float32 is converted to float64.
func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
c, err := d.readCode()
if err != nil {
return nil, err
}
if codes.IsFixedNum(c) {
return int64(c), nil
}
if codes.IsFixedMap(c) {
d.s.UnreadByte()
return d.DecodeMap()
}
if codes.IsFixedArray(c) {
return d.decodeSlice(c)
}
if codes.IsFixedString(c) {
return d.string(c)
}
switch c {
case codes.Nil:
return nil, nil
case codes.False, codes.True:
return d.bool(c)
case codes.Float, codes.Double:
return d.float64(c)
case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64:
return d.uint(c)
case codes.Int8, codes.Int16, codes.Int32, codes.Int64:
return d.int(c)
case codes.Bin8, codes.Bin16, codes.Bin32:
return d.bytes(c, nil)
case codes.Str8, codes.Str16, codes.Str32:
return d.string(c)
case codes.Array16, codes.Array32:
return d.decodeSlice(c)
case codes.Map16, codes.Map32:
d.s.UnreadByte()
return d.DecodeMap()
case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
codes.Ext8, codes.Ext16, codes.Ext32:
return d.extInterface(c)
}
return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
}
// Skip skips next value.
func (d *Decoder) Skip() error {
c, err := d.readCode()
if err != nil {
return err
}
if codes.IsFixedNum(c) {
return nil
} else if codes.IsFixedMap(c) {
return d.skipMap(c)
} else if codes.IsFixedArray(c) {
return d.skipSlice(c)
} else if codes.IsFixedString(c) {
return d.skipBytes(c)
}
switch c {
case codes.Nil, codes.False, codes.True:
return nil
case codes.Uint8, codes.Int8:
return d.skipN(1)
case codes.Uint16, codes.Int16:
return d.skipN(2)
case codes.Uint32, codes.Int32, codes.Float:
return d.skipN(4)
case codes.Uint64, codes.Int64, codes.Double:
return d.skipN(8)
case codes.Bin8, codes.Bin16, codes.Bin32:
return d.skipBytes(c)
case codes.Str8, codes.Str16, codes.Str32:
return d.skipBytes(c)
case codes.Array16, codes.Array32:
return d.skipSlice(c)
case codes.Map16, codes.Map32:
return d.skipMap(c)
case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
codes.Ext8, codes.Ext16, codes.Ext32:
return d.skipExt(c)
}
return fmt.Errorf("msgpack: unknown code %x", c)
}
// PeekCode returns the next MessagePack code without advancing the reader.
// Subpackage msgpack/codes contains list of available codes.
func (d *Decoder) PeekCode() (codes.Code, error) {
c, err := d.s.ReadByte()
if err != nil {
return 0, err
}
return codes.Code(c), d.s.UnreadByte()
}
func (d *Decoder) hasNilCode() bool {
code, err := d.PeekCode()
return err == nil && code == codes.Nil
}
func (d *Decoder) readCode() (codes.Code, error) {
d.extLen = 0
c, err := d.s.ReadByte()
if err != nil {
return 0, err
}
if d.rec != nil {
d.rec = append(d.rec, c)
}
return codes.Code(c), nil
}
func (d *Decoder) readFull(b []byte) error {
_, err := io.ReadFull(d.r, b)
if err != nil {
return err
}
if d.rec != nil {
d.rec = append(d.rec, b...)
}
return nil
}
func (d *Decoder) readN(n int) ([]byte, error) {
buf, err := readN(d.r, d.buf, n)
if err != nil {
return nil, err
}
d.buf = buf
if d.rec != nil {
d.rec = append(d.rec, buf...)
}
return buf, nil
}
func readN(r io.Reader, b []byte, n int) ([]byte, error) {
if b == nil {
if n == 0 {
return make([]byte, 0), nil
}
if n <= bytesAllocLimit {
b = make([]byte, n)
} else {
b = make([]byte, bytesAllocLimit)
}
}
if n <= cap(b) {
b = b[:n]
_, err := io.ReadFull(r, b)
return b, err
}
b = b[:cap(b)]
var pos int
for {
alloc := n - len(b)
if alloc > bytesAllocLimit {
alloc = bytesAllocLimit
}
b = append(b, make([]byte, alloc)...)
_, err := io.ReadFull(r, b[pos:])
if err != nil {
return nil, err
}
if len(b) == n {
break
}
pos = len(b)
}
return b, nil
}
func min(a, b int) int {
if a <= b {
return a
}
return b
}