// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. package jrpc2 import ( "context" "encoding/json" "errors" "fmt" "io" "strconv" "sync" "github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/code" ) // A Client is a JSON-RPC 2.0 client. The client sends requests and receives // responses on a channel.Channel provided by the caller. type Client struct { done *sync.WaitGroup // done when the reader is finished at shutdown time log func(string, ...interface{}) // write debug logs here enctx encoder snote func(*jmessage) scall func(context.Context, *jmessage) []byte chook func(*Client, *Response) cbctx context.Context // terminates when the client is closed cbcancel func() // cancels cbctx mu sync.Mutex // protects the fields below ch channel.Channel // channel to the server err error // error from a previous operation pending map[string]*Response // requests pending completion, by ID nextID int64 // next unused request ID } // NewClient returns a new client that communicates with the server via ch. func NewClient(ch channel.Channel, opts *ClientOptions) *Client { cbctx, cbcancel := context.WithCancel(context.Background()) c := &Client{ done: new(sync.WaitGroup), log: opts.logFunc(), enctx: opts.encodeContext(), snote: opts.handleNotification(), scall: opts.handleCallback(), chook: opts.handleCancel(), cbctx: cbctx, cbcancel: cbcancel, // Lock-protected fields ch: ch, pending: make(map[string]*Response), nextID: 1, // Note that we start the ID counter at 1 here to avoid issues with a // server implementation that treats 0 as equivalent to null. } // The main client loop reads responses from the server and delivers them // back to pending requests by their ID. Outbound requests do not queue; // they are sent synchronously in the Send method. c.done.Add(1) go func() { defer c.done.Done() for c.accept(ch) == nil { } }() return c } // accept receives the next batch of responses from the server. This may // either be a list or a single object, the decoder for jmessages knows how to // handle both. The caller must not hold c.mu. func (c *Client) accept(ch receiver) error { var in jmessages bits, err := ch.Recv() if err == nil { err = in.parseJSON(bits) } if err != nil { if !isUninteresting(err) { c.log("Decoding error: %v", err) } c.mu.Lock() c.stop(err) c.mu.Unlock() return err } c.log("Received %d responses", len(in)) c.done.Add(1) go func() { defer c.done.Done() c.mu.Lock() defer c.mu.Unlock() for _, rsp := range in { c.deliver(rsp) } }() return nil } // handleRequest handles a callback or notification from the server. The // caller must hold c.mu. This function does not block for the handler. // Precondition: msg is a request or notification, not a response or error. func (c *Client) handleRequest(msg *jmessage) { if msg.isNotification() { if c.snote == nil { c.log("Discarding notification: %v", msg) } else { c.snote(msg) } } else if c.scall == nil { c.log("Discarding callback request: %v", msg) } else if c.ch == nil { c.log("Client channel is closed; discarding callback: %v", msg) } else { // Run the callback handler in its own goroutine. The context will be // cancelled automatically when the client is closed. ctx := context.WithValue(c.cbctx, clientKey{}, c) c.done.Add(1) go func() { defer c.done.Done() bits := c.scall(ctx, msg) c.mu.Lock() defer c.mu.Unlock() if c.err != nil { c.log("Discarding callback response: %v", c.err) } else if err := c.ch.Send(bits); err != nil { c.log("Sending reply for callback %v failed: %v", msg, err) } }() } } // For each response, find the request pending on its ID and deliver it. The // caller must hold c.mu. Unknown response IDs are logged and discarded. As // we are under the lock, we do not wait for the pending receiver to pick up // the response; we just drop it in their channel. The channel is buffered so // we don't need to rendezvous. func (c *Client) deliver(rsp *jmessage) { if rsp.isRequestOrNotification() { c.handleRequest(rsp) return } id := string(fixID(rsp.ID)) if p := c.pending[id]; p == nil { c.log("Discarding response for unknown ID %q", id) } else if !c.versionOK(rsp.V) { delete(c.pending, id) p.ch <- &jmessage{ ID: rsp.ID, E: &Error{ Code: code.InvalidRequest, Message: fmt.Sprintf("incorrect version marker %q", rsp.V), }, } c.log("Invalid response for ID %q", id) } else { // Remove the pending request from the set and deliver its response. // Determining whether it's an error is the caller's responsibility. delete(c.pending, id) p.ch <- rsp c.log("Completed request for ID %q", id) } } // req constructs a fresh request for the specified method and parameters. // This does not transmit the request to the server; use c.send to do so. func (c *Client) req(ctx context.Context, method string, params interface{}) (*jmessage, error) { bits, err := c.marshalParams(ctx, method, params) if err != nil { return nil, err } c.mu.Lock() defer c.mu.Unlock() id := json.RawMessage(strconv.FormatInt(c.nextID, 10)) c.nextID++ return &jmessage{ ID: id, M: method, P: bits, }, nil } // note constructs a notification request for the specified method and parameters. func (c *Client) note(ctx context.Context, method string, params interface{}) (*jmessage, error) { bits, err := c.marshalParams(ctx, method, params) if err != nil { return nil, err } return &jmessage{M: method, P: bits}, nil } // send transmits the specified requests to the server and returns a slice of // pending responses awaiting a reply from the server. // // The resulting slice will contain one entry for each input request that // expects a response (that is, all those that are not notifications). If all // the requests are notifications, the slice will be empty. // // This method blocks until the entire batch of requests has been transmitted. func (c *Client) send(ctx context.Context, reqs jmessages) ([]*Response, error) { if len(reqs) == 0 { return nil, errors.New("empty request batch") } // Marshal and prepare responses outside the lock. This may wind up being // wasted work if there is already a failure, but in that case we're already // on a closing path. b, err := reqs.toJSON() if err != nil { return nil, Errorf(code.InternalError, "marshaling request failed: %v", err) } var pends []*Response var pctxs []context.Context for _, req := range reqs { if id := string(req.ID); id != "" { pctx, p := newPending(ctx, id) pends = append(pends, p) pctxs = append(pctxs, pctx) } } c.mu.Lock() defer c.mu.Unlock() if c.err != nil { return nil, c.err } c.log("Outgoing batch: %s", string(b)) if err := c.ch.Send(b); err != nil { return nil, err } // Now that we have sent them, record the requests for which we are awaiting // replies. We do this after transmission so that an error in sending does // not leave us with zombies that will never be fulfilled. for i, p := range pends { c.pending[p.id] = p go c.waitComplete(pctxs[i], p.id, p) } return pends, nil } // waitComplete waits for completion of the context governing p. When the // context ends, check whether the request is still in the pending set for the // client: If so, a reply has not yet been delivered. Otherwise, the // cancellation is a no-op ("too late"). func (c *Client) waitComplete(pctx context.Context, id string, p *Response) { <-pctx.Done() cleanup := func() {} c.mu.Lock() defer func() { c.mu.Unlock() cleanup() // N.B. outside the lock }() if _, ok := c.pending[id]; !ok { return } err := pctx.Err() c.log("Context ended for id %q, err=%v", id, err) delete(c.pending, id) var jerr *Error if c.err != nil && !isUninteresting(c.err) { jerr = &Error{Code: code.InternalError, Message: c.err.Error()} } else if err != nil { jerr = &Error{Code: code.FromError(err), Message: err.Error()} } p.ch <- &jmessage{ ID: json.RawMessage(id), E: jerr, } // If there is a cancellation hook, give it a chance to run. if c.chook != nil { cleanup = func() { p.wait() // ensure the response has settled c.log("Calling OnCancel for id %q", id) c.chook(c, p) } } } // Call initiates a single request and blocks until the response returns. // A successful call reports a nil error and a non-nil response. Errors from // the server have concrete type *jrpc2.Error. // // rsp, err := c.Call(ctx, method, params) // if e, ok := err.(*jrpc2.Error); ok { // log.Fatalf("Error from server: %v", err) // } else if err != nil { // log.Fatalf("Call failed: %v", err) // } // handleValidResponse(rsp) // func (c *Client) Call(ctx context.Context, method string, params interface{}) (*Response, error) { req, err := c.req(ctx, method, params) if err != nil { return nil, err } rsp, err := c.send(ctx, jmessages{req}) if err != nil { return nil, err } rsp[0].wait() if err := rsp[0].Error(); err != nil { return nil, filterError(err) } return rsp[0], nil } // CallResult invokes Call with the given method and params. If it succeeds, // the result is decoded into result. This is a convenient shorthand for Call // followed by UnmarshalResult. It will panic if result == nil. func (c *Client) CallResult(ctx context.Context, method string, params, result interface{}) error { rsp, err := c.Call(ctx, method, params) if err != nil { return err } return rsp.UnmarshalResult(result) } // Batch initiates a batch of concurrent requests, and blocks until all the // responses return. The responses are returned in the same order as the // original specs, omitting notifications. // // Any error reported by Batch represents an error in encoding or sending the // batch to the server. Errors reported by the server in response to requests // must be recovered from the responses. func (c *Client) Batch(ctx context.Context, specs []Spec) ([]*Response, error) { reqs := make(jmessages, len(specs)) for i, spec := range specs { var req *jmessage var err error if spec.Notify { req, err = c.note(ctx, spec.Method, spec.Params) } else { req, err = c.req(ctx, spec.Method, spec.Params) } if err != nil { return nil, err } reqs[i] = req } rsps, err := c.send(ctx, reqs) if err != nil { return nil, err } for _, rsp := range rsps { rsp.wait() } return rsps, nil } // A Spec combines a method name and parameter value as part of a Batch. If // the Notify field is true, the request is sent as a notification. type Spec struct { Method string Params interface{} Notify bool } // Notify transmits a notification to the specified method and parameters. It // blocks until the notification has been sent. func (c *Client) Notify(ctx context.Context, method string, params interface{}) error { req, err := c.note(ctx, method, params) if err != nil { return err } _, err = c.send(ctx, jmessages{req}) return err } // Close shuts down the client, terminating any pending in-flight requests. func (c *Client) Close() error { c.mu.Lock() c.stop(errClientStopped) c.mu.Unlock() c.done.Wait() // Don't remark on a closed channel or EOF as a noteworthy failure. if isUninteresting(c.err) { return nil } return c.err } func isUninteresting(err error) bool { return err == io.EOF || channel.IsErrClosing(err) || err == errClientStopped } // stop closes down the reader for c and records err as its final state. The // caller must hold c.mu. If multiple callers invoke stop, only the first will // successfully record its error status. func (c *Client) stop(err error) { if c.ch == nil { return // nothing is running } c.ch.Close() // Unblock and fail any pending callbacks. c.cbcancel() // Unblock and fail any pending requests. for _, p := range c.pending { p.cancel() } c.err = err c.ch = nil } func (c *Client) versionOK(v string) bool { return v == Version } // marshalParams validates and marshals params to JSON for a request. The // value of params must be either nil or encodable as a JSON object or array. func (c *Client) marshalParams(ctx context.Context, method string, params interface{}) (json.RawMessage, error) { if params == nil { return c.enctx(ctx, method, nil) // no parameters, that is OK } pbits, err := json.Marshal(params) if err != nil { return nil, err } if fb := firstByte(pbits); fb != '[' && fb != '{' && !isNull(pbits) { // JSON-RPC requires that if parameters are provided at all, they are // an array or an object. return nil, &Error{Code: code.InvalidRequest, Message: "invalid parameters: array or object required"} } bits, err := c.enctx(ctx, method, pbits) if err != nil { return nil, err } return bits, err } func newPending(ctx context.Context, id string) (context.Context, *Response) { // Buffer the channel so the response reader does not need to rendezvous // with the recipient. pctx, cancel := context.WithCancel(ctx) return pctx, &Response{ ch: make(chan *jmessage, 1), id: id, cancel: cancel, } }