package jrpc2 import ( "bytes" "context" "encoding/json" "fmt" "strings" "github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/code" ) // An Assigner assigns a Handler to handle the specified method name, or nil if // no method is available to handle the request. type Assigner interface { // Assign returns the handler for the named method, or nil. Assign(ctx context.Context, method string) Handler // Names returns a slice of all known method names for the assigner. The // resulting slice is ordered lexicographically and contains no duplicates. Names() []string } // A Handler handles a single request. type Handler interface { // Handle invokes the method with the specified request. The response value // must be JSON-marshalable or nil. In case of error, the handler can // return a value of type *jrpc2.Error to control the response code sent // back to the caller; otherwise the server will wrap the resulting value. // // The context passed to the handler by a *jrpc2.Server includes two extra // values that the handler may extract. // // To obtain a server metrics value, write: // // sm := jrpc2.ServerMetrics(ctx) // // To obtain the inbound request message, write: // // req := jrpc2.InboundRequest(ctx) // // The inbound request is the same value passed to the Handle method -- the // latter is primarily useful in handlers generated by handler.New, which do // not receive this value directly. Handle(context.Context, *Request) (interface{}, error) } // A Request is a request message from a client to a server. type Request struct { id json.RawMessage // the request ID, nil for notifications method string // the name of the method being requested params json.RawMessage // method parameters } // IsNotification reports whether the request is a notification, and thus does // not require a value response. func (r *Request) IsNotification() bool { return r.id == nil } // ID returns the request identifier for r, or "" if r is a notification. 
func (r *Request) ID() string { return string(r.id) }

// Method reports the method name for the request.
func (r *Request) Method() string { return r.method }

// HasParams reports whether the request has non-empty parameters.
func (r *Request) HasParams() bool { return len(r.params) != 0 }

// UnmarshalParams decodes the request parameters of r into v. If r has empty
// parameters, it returns nil without modifying v. If the parameters cannot be
// decoded into v, it returns an InvalidParams error.
//
// By default, unknown object keys are ignored when unmarshaling into a v of
// struct type. This can be overridden either by giving the type of v a custom
// implementation of json.Unmarshaler, or implementing a DisallowUnknownFields
// method. The jrpc2.StrictFields helper function adapts existing values to
// this interface.
//
// If v has type *json.RawMessage, decoding cannot fail.
func (r *Request) UnmarshalParams(v interface{}) error {
	if len(r.params) == 0 {
		return nil
	}

	// A raw message receives a private copy of the encoded parameters, so the
	// caller cannot alias the request's own buffer.
	if raw, ok := v.(*json.RawMessage); ok {
		*raw = json.RawMessage(string(r.params)) // copy
		return nil
	}

	var err error
	if _, isStrict := v.(strictFielder); isStrict {
		dec := json.NewDecoder(bytes.NewReader(r.params))
		dec.DisallowUnknownFields()
		err = dec.Decode(v)
	} else {
		err = json.Unmarshal(r.params, v)
	}

	// Report every decoding failure with the InvalidParams code, as
	// documented, whether or not strict field checking was requested.
	if err != nil {
		return Errorf(code.InvalidParams, "invalid parameters: %v", err.Error())
	}
	return nil
}

// ParamString returns the encoded request parameters of r as a string.
// If r has no parameters, it returns "".
func (r *Request) ParamString() string { return string(r.params) }

// ErrInvalidVersion is returned by ParseRequests if one or more of the
// requests in the input has a missing or invalid version marker.
var ErrInvalidVersion = Errorf(code.InvalidRequest, "incorrect version marker")

// ParseRequests parses a single request or a batch of requests from JSON.
// The result parameters are either nil or have concrete type json.RawMessage.
// // If any of the requests is missing or has an invalid JSON-RPC version, it // returns ErrInvalidVersion along with the parsed results. Otherwise, no // validation apart from basic structure is performed on the results. func ParseRequests(msg []byte) ([]*Request, error) { var req jmessages if err := req.parseJSON(msg); err != nil { return nil, err } var err error out := make([]*Request, len(req)) for i, req := range req { if req.V != Version { err = ErrInvalidVersion } out[i] = &Request{ id: fixID(req.ID), method: req.M, params: req.P, } } return out, err } // A Response is a response message from a server to a client. type Response struct { id string err *Error result json.RawMessage // Waiters synchronize on reading from ch. The first successful reader from // ch completes the request and is responsible for updating rsp and then // closing ch. The client owns writing to ch, and is responsible to ensure // that at most one write is ever performed. ch chan *jmessage cancel func() } // ID returns the request identifier for r. func (r *Response) ID() string { return r.id } // SetID sets the request identifier for r. This is for use in proxies. func (r *Response) SetID(id string) { r.id = id } // Error returns a non-nil *Error if the response contains an error. func (r *Response) Error() *Error { return r.err } // UnmarshalResult decodes the result message into v. If the request failed, // UnmarshalResult returns the *Error value that would also be returned by // r.Error(), and v is unmodified. // // By default, unknown object keys are ignored when unmarshaling into a v of // struct type. This can be overridden either by giving the type of v a custom // implementation of json.Unmarshaler, or implementing a DisallowUnknownFields // method. The jrpc2.StrictFields helper function adapts existing values to // this interface. 
func (r *Response) UnmarshalResult(v interface{}) error { if r.err != nil { return r.err } switch t := v.(type) { case *json.RawMessage: *t = json.RawMessage(string(r.result)) // copy return nil case strictFielder: dec := json.NewDecoder(bytes.NewReader(r.result)) dec.DisallowUnknownFields() return dec.Decode(v) } return json.Unmarshal(r.result, v) } // ResultString returns the encoded result message of r as a string. // If r has no result, for example if r is an error response, it returns "". func (r *Response) ResultString() string { return string(r.result) } // MarshalJSON converts the response to equivalent JSON. func (r *Response) MarshalJSON() ([]byte, error) { return json.Marshal(&jmessage{ V: Version, ID: json.RawMessage(r.id), R: r.result, E: r.err, }) } // wait blocks until r is complete. It is safe to call this multiple times and // from concurrent goroutines. func (r *Response) wait() { raw, ok := <-r.ch if ok { // N.B. We intentionally DO NOT have the sender close the channel, to // prevent a data race between callers of Wait. The channel is closed // by the first waiter to get a real value (ok == true). // // The first waiter must update the response value, THEN close the // channel and cancel the context. This order ensures that subsequent // waiters all get the same response, and do not race on accessing it. r.err = raw.E r.result = raw.R close(r.ch) r.cancel() // release the context observer // Safety check: The response IDs should match. Do this after delivery so // a failure does not orphan resources. if id := string(fixID(raw.ID)); id != r.id { panic(fmt.Sprintf("Mismatched response ID %q expecting %q", id, r.id)) } } } // jmessages is either a single protocol message or an array of protocol // messages. This handles the decoding of batch requests in JSON-RPC 2.0. type jmessages []*jmessage func (j jmessages) toJSON() ([]byte, error) { if len(j) == 1 && !j[0].batch { return json.Marshal(j[0]) } return json.Marshal([]*jmessage(j)) } // N.B. 
Not UnmarshalJSON, because json.Unmarshal checks for validity early and // here we want to control the error that is returned. func (j *jmessages) parseJSON(data []byte) error { *j = (*j)[:0] // reset state // When parsing requests, validation checks are deferred: The only immediate // mode of failure for unmarshaling is if the request is not a valid object // or array. var msgs []json.RawMessage var batch bool if len(data) == 0 || data[0] != '[' { msgs = append(msgs, nil) if err := json.Unmarshal(data, &msgs[0]); err != nil { return Errorf(code.ParseError, "invalid request message") } } else if err := json.Unmarshal(data, &msgs); err != nil { return Errorf(code.ParseError, "invalid request batch") } else { batch = true } // Now parse the individual request messages, but do not fail on errors. We // know that the messages are intact, but validity is checked at usage. for _, raw := range msgs { req := new(jmessage) req.parseJSON(raw) req.batch = batch *j = append(*j, req) } return nil } // jmessage is the transmission format of a protocol message. type jmessage struct { V string `json:"jsonrpc"` // must be Version ID json.RawMessage `json:"id,omitempty"` // may be nil // Fields belonging to request or notification objects M string `json:"method,omitempty"` P json.RawMessage `json:"params,omitempty"` // may be nil // Fields belonging to response or error objects E *Error `json:"error,omitempty"` // set on error R json.RawMessage `json:"result,omitempty"` // set on success // N.B.: In a valid protocol message, M and P are mutually exclusive with E // and R. Specifically, if M != "" then E and R must both be unset. This is // checked during parsing. batch bool // this message was part of a batch err error // if not nil, this message is invalid and err is why } func (j *jmessage) fail(code code.Code, msg string) error { j.err = Errorf(code, msg) return j.err } func (j *jmessage) parseJSON(data []byte) error { // Unmarshal into a map so we can check for extra keys. 
The json.Decoder // has DisallowUnknownFields, but fails decoding eagerly for fields that do // not map to known tags. We want to fully parse the object so we can // propagate the "id" in error responses, if it is set. So we have to decode // and check the fields ourselves. var obj map[string]json.RawMessage if err := json.Unmarshal(data, &obj); err != nil { return j.fail(code.ParseError, "request is not a JSON object") } *j = jmessage{} // reset content var extra []string // extra field names for key, val := range obj { switch key { case "jsonrpc": if json.Unmarshal(val, &j.V) != nil { j.fail(code.ParseError, "invalid version key") } case "id": j.ID = val case "method": if json.Unmarshal(val, &j.M) != nil { j.fail(code.ParseError, "invalid method name") } case "params": // As a special case, reduce "null" to nil in the parameters. // Otherwise, require per spec that val is an array or object. if !isNull(val) { j.P = val } if len(j.P) != 0 && j.P[0] != '[' && j.P[0] != '{' { j.fail(code.InvalidRequest, "parameters must be array or object") } case "error": if json.Unmarshal(val, &j.E) != nil { j.fail(code.ParseError, "invalid error value") } case "result": j.R = val default: extra = append(extra, key) } } // Report an error if request/response fields overlap. if j.M != "" && (j.E != nil || j.R != nil) { j.fail(code.InvalidRequest, "mixed request and reply fields") } // Report an error for extraneous fields. if j.err == nil && len(extra) != 0 { j.err = DataErrorf(code.InvalidRequest, extra, "extra fields in request") } return nil } // isRequestOrNotification reports whether j is a request or notification. 
func (j *jmessage) isRequestOrNotification() bool {
	hasReply := j.E != nil || j.R != nil
	return !hasReply && j.M != ""
}

// isNotification reports whether j is a notification, meaning a request or
// notification whose ID is unset or null.
func (j *jmessage) isNotification() bool {
	if !j.isRequestOrNotification() {
		return false
	}
	return fixID(j.ID) == nil
}

// jerror is the JSON encoding of an error object: a numeric code, an optional
// message, and optional raw data.
type jerror struct {
	C int32           `json:"code"`
	M string          `json:"message,omitempty"`
	D json.RawMessage `json:"data,omitempty"`
}

// fixID filters id, treating "null" as a synonym for an unset ID. This
// supports interoperation with JSON-RPC v1 where "null" is used as an ID for
// notifications.
func fixID(id json.RawMessage) json.RawMessage {
	if isNull(id) {
		return nil
	}
	return id
}

// encode marshals rsps as JSON and forwards it to the channel. It returns the
// length of the encoded message together with the result of the send, or
// zero and the marshaling error if encoding fails.
func encode(ch channel.Sender, rsps jmessages) (int, error) {
	bits, err := rsps.toJSON()
	if err != nil {
		return 0, err
	}
	sendErr := ch.Send(bits)
	return len(bits), sendErr
}

// Network guesses a network type for the specified address. The assignment of
// a network type uses the following heuristics:
//
// If s does not have the form [host]:port, the network is assigned as "unix".
// The network "unix" is also assigned if port == "", port contains characters
// other than ASCII letters, digits, and "-", or if host contains a "/".
//
// Otherwise, the network is assigned as "tcp". Note that this function does
// not verify whether the address is lexically valid.
func Network(s string) string {
	sep := strings.LastIndex(s, ":")
	if sep < 0 {
		return "unix"
	}

	host, port := s[:sep], s[sep+1:]
	switch {
	case port == "", !isServiceName(port):
		return "unix"
	case strings.Contains(host, "/"):
		return "unix"
	}
	return "tcp"
}

// isServiceName reports whether s looks like a legal service name from the
// services(5) file. The grammar of such names is not well-defined, but for our
// purposes it includes letters, digits, and "-".
func isServiceName(s string) bool { for i := range s { b := s[i] if b >= '0' && b <= '9' || b >= 'A' && b <= 'Z' || b >= 'a' && b <= 'z' || b == '-' { continue } return false } return true } // isNull reports whether msg is exactly the JSON "null" value. func isNull(msg json.RawMessage) bool { return len(msg) == 4 && msg[0] == 'n' && msg[1] == 'u' && msg[2] == 'l' && msg[3] == 'l' } // filterError filters an *Error value to distinguish context errors from other // error types. If err is not a context error, it is returned unchanged. func filterError(e *Error) error { switch e.code { case code.Cancelled: return context.Canceled case code.DeadlineExceeded: return context.DeadlineExceeded } return e } // strictFielder is an optional interface that can be implemented by a type to // reject unknown fields when unmarshaling from JSON. If a type does not // implement this interface, unknown fields are ignored. type strictFielder interface { DisallowUnknownFields() } // StrictFields wraps a value v to implement the DisallowUnknownFields method, // requiring unknown fields to be rejected when unmarshaling from JSON. // // For example: // // var obj RequestType // err := req.UnmarshalParams(jrpc2.StrictFields(&obj))` // func StrictFields(v interface{}) interface{} { return &strict{v: v} } type strict struct{ v interface{} } func (strict) DisallowUnknownFields() {}