2022-02-06 07:06:32 +00:00

305 lines
8.4 KiB
Go

// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved.
package jrpc2
import (
	"bytes"
	"encoding/json"
	"sort"

	"github.com/creachadair/jrpc2/code"
)
// ErrInvalidVersion is the error ParseRequests reports when at least one of
// the requests in its input has a missing or incorrect version marker.
var ErrInvalidVersion error = &Error{
	Code:    code.InvalidRequest,
	Message: "incorrect version marker",
}
// ParseRequests decodes msg, which holds either one JSON request or a JSON
// array of requests, into a slice of requests.
//
// If msg cannot be split into messages at all, the parse error is returned
// with a nil slice. If msg is otherwise well formed but one or more requests
// carry a version marker other than Version, every request is still returned
// and the error is ErrInvalidVersion.
func ParseRequests(msg []byte) ([]*Request, error) {
	var msgs jmessages
	if perr := msgs.parseJSON(msg); perr != nil {
		return nil, perr
	}

	var verr error
	reqs := make([]*Request, 0, len(msgs))
	for _, m := range msgs {
		if m.V != Version {
			// Keep going: the caller still gets all the parsed requests.
			verr = ErrInvalidVersion
		}
		reqs = append(reqs, &Request{
			id:     fixID(m.ID),
			method: m.M,
			params: m.P,
		})
	}
	return reqs, verr
}
// jmessages is either a single protocol message or an array of protocol
// messages. This handles the decoding of batch requests in JSON-RPC 2.0.
//
// When encoded, a slice holding exactly one message that is not marked as
// part of a batch is written as a bare object; anything else is written as a
// JSON array.
type jmessages []*jmessage
// toJSON encodes j as JSON. A lone message whose batch flag is unset is
// encoded as a bare object; every other case, including an empty slice, is
// encoded as an array of the individual message encodings. The first error
// reported while encoding a message is returned.
func (j jmessages) toJSON() ([]byte, error) {
	if len(j) == 1 && !j[0].batch {
		return j[0].toJSON()
	}

	var buf bytes.Buffer
	buf.WriteByte('[')
	for idx := range j {
		if idx != 0 {
			buf.WriteByte(',')
		}
		enc, err := j[idx].toJSON()
		if err != nil {
			return nil, err
		}
		buf.Write(enc)
	}
	buf.WriteByte(']')
	return buf.Bytes(), nil
}
// parseJSON replaces the contents of j with the messages decoded from data.
//
// N.B. This is deliberately not named UnmarshalJSON: json.Unmarshal checks
// validity before calling that method, and here we want to choose the error
// that is reported.
//
// The only failure reported directly is input that is not a valid JSON value
// (or, for input starting with '[', not a valid array); that yields
// errInvalidRequest. Problems inside an individual message are recorded on
// that message and checked when it is used.
func (j *jmessages) parseJSON(data []byte) error {
	*j = (*j)[:0] // discard any previous contents, keeping the storage

	var raws []json.RawMessage
	isBatch := firstByte(data) == '['
	if isBatch {
		if err := json.Unmarshal(data, &raws); err != nil {
			return errInvalidRequest
		}
	} else {
		var single json.RawMessage
		if err := json.Unmarshal(data, &single); err != nil {
			return errInvalidRequest
		}
		raws = append(raws, single)
	}

	// Each raw message is known to be intact JSON. Parse failures are stored
	// in the message itself, so the result of parseJSON is not consulted.
	for _, raw := range raws {
		m := new(jmessage)
		m.parseJSON(raw)
		m.batch = isBatch // set after parsing, which resets the message
		*j = append(*j, m)
	}
	return nil
}
// jmessage is the transmission format of a protocol message. The same type
// represents requests, notifications, and responses; which fields are set
// determines which kind of message it is. The batch and err fields are
// bookkeeping and are not written by toJSON.
type jmessage struct {
V string // value of the "jsonrpc" key; must be Version
ID json.RawMessage // value of the "id" key; may be nil
// Fields belonging to request or notification objects
M string // value of the "method" key
P json.RawMessage // value of the "params" key; may be nil
// Fields belonging to response or error objects
E *Error // value of the "error" key; set on error
R json.RawMessage // value of the "result" key; set on success
// N.B.: In a valid protocol message, M and P are mutually exclusive with E
// and R. Specifically, if M != "" then E and R must both be unset. This is
// checked during parsing.
batch bool // this message was part of a batch
err error // if not nil, this message is invalid and err is why
}
// isValidID reports whether v is an acceptable JSON encoding of a request ID:
// absent, "null", a string, or a number.
//
// Precondition: v is a valid JSON value, or empty. Only the first byte is
// inspected, which is enough to classify a valid value.
func isValidID(v json.RawMessage) bool {
	if len(v) == 0 || isNull(v) {
		return true // unset and "null" are both fine
	}

	// N.B. Fractional numbers are not rejected here, although the spec says
	// numeric IDs should not have fractional parts.
	c := v[0]
	isString := c == '"'
	isNumber := c == '-' || ('0' <= c && c <= '9')
	return isString || isNumber
}
// fail marks j as invalid by setting j.err to an *Error with the given code
// and message. Any error recorded earlier is replaced.
func (j *jmessage) fail(code code.Code, msg string) {
	j.err = &Error{
		Code:    code,
		Message: msg,
	}
}
// toJSON encodes j as a JSON-RPC 2.0 object.
//
// The "id" key is written only if j.ID is non-empty. After that, exactly one
// group of fields is written, chosen in this order: the method (with params,
// if any) when j.M is set; otherwise the result when j.R is non-empty;
// otherwise the error when j.E is set. If none applies, only the version and
// optional ID are written.
func (j *jmessage) toJSON() ([]byte, error) {
	var buf bytes.Buffer
	buf.WriteString(`{"jsonrpc":"2.0"`)
	if len(j.ID) > 0 {
		buf.WriteString(`,"id":`)
		buf.Write(j.ID)
	}

	if j.M != "" {
		name, err := json.Marshal(j.M)
		if err != nil {
			return nil, err
		}
		buf.WriteString(`,"method":`)
		buf.Write(name)
		if len(j.P) > 0 {
			buf.WriteString(`,"params":`)
			buf.Write(j.P)
		}
	} else if len(j.R) > 0 {
		buf.WriteString(`,"result":`)
		buf.Write(j.R)
	} else if j.E != nil {
		errBits, err := json.Marshal(j.E)
		if err != nil {
			return nil, err
		}
		buf.WriteString(`,"error":`)
		buf.Write(errBits)
	}

	buf.WriteByte('}')
	return buf.Bytes(), nil
}
// parseJSON decodes data as one JSON-RPC message object and stores the
// decoded fields in j.
//
// If data cannot be unmarshaled as a JSON object, j.err is set to a parse
// error and that error is returned; the other fields of j are not touched.
// In every other case j is reset, the return value is nil, and any problem
// with an individual field is recorded in j.err, so that fields which did
// decode (such as the ID) stay available to the caller.
//
// Keys are visited in sorted order rather than in map iteration order. When
// several fields are invalid, the error left in j.err is therefore the same
// on every call, and the extra field names attached to an "extra fields"
// error are listed in sorted order.
func (j *jmessage) parseJSON(data []byte) error {
	// Unmarshal into a map so we can check for extra keys. The json.Decoder
	// has DisallowUnknownFields, but fails decoding eagerly for fields that do
	// not map to known tags. We want to fully parse the object so we can
	// propagate the "id" in error responses, if it is set. So we have to decode
	// and check the fields ourselves.
	var obj map[string]json.RawMessage
	if err := json.Unmarshal(data, &obj); err != nil {
		j.fail(code.ParseError, "request is not a JSON object")
		return j.err
	}

	// Go does not define the order in which a map is ranged over, so fix the
	// order up front to make the reported error reproducible.
	keys := make([]string, 0, len(obj))
	for key := range obj {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	*j = jmessage{}    // reset content
	var extra []string // extra field names, in sorted order
	for _, key := range keys {
		val := obj[key]
		switch key {
		case "jsonrpc":
			if json.Unmarshal(val, &j.V) != nil {
				j.fail(code.ParseError, "invalid version key")
			}
		case "id":
			if isValidID(val) {
				j.ID = val
			} else {
				j.fail(code.InvalidRequest, "invalid request ID")
			}
		case "method":
			if json.Unmarshal(val, &j.M) != nil {
				j.fail(code.ParseError, "invalid method name")
			}
		case "params":
			// As a special case, reduce "null" to nil in the parameters.
			// Otherwise, require per spec that val is an array or object.
			if !isNull(val) {
				j.P = val
			}
			if fb := firstByte(j.P); fb != 0 && fb != '[' && fb != '{' {
				j.fail(code.InvalidRequest, "parameters must be array or object")
			}
		case "error":
			if json.Unmarshal(val, &j.E) != nil {
				j.fail(code.ParseError, "invalid error value")
			}
		case "result":
			j.R = val
		default:
			extra = append(extra, key)
		}
	}

	// Report an error if request/response fields overlap.
	if j.M != "" && (j.E != nil || j.R != nil) {
		j.fail(code.InvalidRequest, "mixed request and reply fields")
	}

	// Report an error for extraneous fields, unless something more specific
	// has already been recorded.
	if j.err == nil && len(extra) != 0 {
		j.err = Errorf(code.InvalidRequest, "extra fields in request").WithData(extra)
	}
	return nil
}
// isRequestOrNotification reports whether j has a method name and neither an
// error nor a result, which is what distinguishes a request or notification
// from a response.
func (j *jmessage) isRequestOrNotification() bool {
	if j.M == "" {
		return false
	}
	return j.E == nil && j.R == nil
}
// isNotification reports whether j is a notification, meaning a request or
// notification whose ID is unset once "null" has been filtered out by fixID.
func (j *jmessage) isNotification() bool {
	if !j.isRequestOrNotification() {
		return false
	}
	return fixID(j.ID) == nil
}
// fixID returns nil if id is the JSON value "null", and id unchanged
// otherwise. Treating "null" as an unset ID supports interoperation with
// JSON-RPC v1, where "null" is used as the ID of a notification.
func fixID(id json.RawMessage) json.RawMessage {
	if isNull(id) {
		return nil
	}
	return id
}
// sender is the part of channel.Channel that is needed to send messages.
type sender interface {
	// Send transmits one encoded message.
	Send([]byte) error
}
// receiver is the part of channel.Channel that is needed to receive messages.
type receiver interface {
	// Recv returns the next encoded message.
	Recv() ([]byte, error)
}
// encode converts rsps to JSON and passes the result to ch.Send. It returns
// the number of encoded bytes together with the error from Send. If encoding
// fails, nothing is sent and the result is 0 and the encoding error.
func encode(ch sender, rsps jmessages) (int, error) {
	data, err := rsps.toJSON()
	if err == nil {
		return len(data), ch.Send(data)
	}
	return 0, err
}
// isNull reports whether msg is exactly the four bytes of the JSON "null"
// value, with no surrounding whitespace.
func isNull(msg json.RawMessage) bool {
	return string(msg) == "null"
}
// firstByte returns the first byte of data that remains after surrounding
// whitespace is trimmed, or 0 if nothing remains.
func firstByte(data []byte) byte {
	if trimmed := bytes.TrimSpace(data); len(trimmed) > 0 {
		return trimmed[0]
	}
	return 0
}
// strictFielder is an optional interface a type can implement to ask that
// unknown fields be rejected when it is unmarshaled from JSON. For a type
// that does not implement it, unknown fields are ignored.
type strictFielder interface{ DisallowUnknownFields() }
// StrictFields wraps v so that unmarshaling JSON into the returned value
// rejects fields that v does not recognize, instead of ignoring them.
//
// For example:
//
//	var obj RequestType
//	err := req.UnmarshalParams(jrpc2.StrictFields(&obj))
func StrictFields(v interface{}) interface{} {
	return &strict{v: v}
}

// strict is the wrapper returned by StrictFields; v is the wrapped target.
type strict struct {
	v interface{}
}

// UnmarshalJSON decodes data into the wrapped value using a decoder that has
// DisallowUnknownFields enabled.
func (s *strict) UnmarshalJSON(data []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(data))
	decoder.DisallowUnknownFields()
	return decoder.Decode(s.v)
}