// Copyright (C) 2017 Michael J. Fromberger. All Rights Reserved. package jrpc2 import ( "context" "encoding/json" "errors" "io" "strconv" "strings" "sync" "time" "github.com/creachadair/jrpc2/channel" "github.com/creachadair/jrpc2/code" "github.com/creachadair/jrpc2/metrics" "golang.org/x/sync/semaphore" ) // A Server is a JSON-RPC 2.0 server. The server receives requests and sends // responses on a channel.Channel provided by the caller, and dispatches // requests to user-defined Handlers. type Server struct { wg sync.WaitGroup // ready when workers are done at shutdown time mux Assigner // associates method names with handlers sem *semaphore.Weighted // bounds concurrent execution (default 1) // Configurable settings allowP bool // allow server notifications to the client log func(string, ...interface{}) // write debug logs here rpcLog RPCLogger // log RPC requests and responses here newctx func() context.Context // create a new base request context dectx decoder // decode context from request metrics *metrics.M // metrics collected during execution start time.Time // when Start was called builtin bool // whether built-in rpc.* methods are enabled mu *sync.Mutex // protects the fields below nbar sync.WaitGroup // notification barrier (see the dispatch method) err error // error from a previous operation work chan struct{} // for signaling message availability inq *queue // inbound requests awaiting processing ch channel.Channel // the channel to the client // For each request ID currently in-flight, this map carries a cancel // function attached to the context that was sent to the handler. used map[string]context.CancelFunc // For each push-call ID currently in flight, this map carries the response // waiting for its reply. call map[string]*Response callID int64 } // NewServer returns a new unstarted server that will dispatch incoming // JSON-RPC requests according to mux. To start serving, call Start. // // N.B. It is only safe to modify mux after the server has been started if mux // itself is safe for concurrent use by multiple goroutines. // // This function will panic if mux == nil. func NewServer(mux Assigner, opts *ServerOptions) *Server { if mux == nil { panic("nil assigner") } s := &Server{ mux: mux, sem: semaphore.NewWeighted(opts.concurrency()), allowP: opts.allowPush(), log: opts.logFunc(), rpcLog: opts.rpcLog(), newctx: opts.newContext(), dectx: opts.decodeContext(), mu: new(sync.Mutex), metrics: opts.metrics(), start: opts.startTime(), builtin: opts.allowBuiltin(), inq: newQueue(), used: make(map[string]context.CancelFunc), call: make(map[string]*Response), callID: 1, } return s } // Start enables processing of requests from c and returns. Start does not // block while the server runs. This function will panic if the server is // already running. It returns s to allow chaining with construction. func (s *Server) Start(c channel.Channel) *Server { s.mu.Lock() defer s.mu.Unlock() if s.ch != nil { panic("server is already running") } // Set up the queues and condition variable used by the workers. s.ch = c if s.start.IsZero() { s.start = time.Now().In(time.UTC) } s.metrics.Count("rpc.serversActive", 1) // Reset all the I/O structures and start up the workers. s.err = nil // Reset the signal channel. s.work = make(chan struct{}, 1) // s.wg waits for the maintenance goroutines for receiving input and // processing the request queue. In addition, each request in flight adds a // goroutine to s.wg. At server shutdown, s.wg completes when the // maintenance goroutines and all pending requests are finished. s.wg.Add(2) // Accept requests from the client and enqueue them for processing. go func() { defer s.wg.Done(); s.read(c) }() // Remove requests from the queue and dispatch them to handlers. go func() { defer s.wg.Done(); s.serve() }() return s } // serve processes requests from the queue and dispatches them to handlers. // The responses are written back by the handler goroutines. // // The flow of an inbound request is: // // serve -- main serving loop // * nextRequest -- process the next request batch // * dispatch // * assign -- assign handlers to requests // | ... // | // * invoke -- invoke handlers // | \ handler -- handle an individual request // | ... // * deliver -- send responses to the client // func (s *Server) serve() { for { next, err := s.nextRequest() if err != nil { s.log("Error reading from client: %v", err) return } s.wg.Add(1) go func() { defer s.wg.Done() next() }() } } func (s *Server) signal() { select { case s.work <- struct{}{}: default: } } // nextRequest blocks until a request batch is available and returns a function // that dispatches it to the appropriate handlers. The result is only an error // if the connection failed; errors reported by the handler are reported to the // caller and not returned here. // // The caller must invoke the returned function to complete the request. func (s *Server) nextRequest() (func() error, error) { s.mu.Lock() defer s.mu.Unlock() for s.ch != nil && s.inq.isEmpty() { s.mu.Unlock() <-s.work s.mu.Lock() } if s.ch == nil && s.inq.isEmpty() { return nil, s.err } ch := s.ch // capture next := s.inq.pop() s.log("Dequeued request batch of length %d (qlen=%d)", len(next), s.inq.size()) // Construct a dispatcher to run the handlers outside the lock. return s.dispatch(next, ch), nil } // waitForBarrier blocks until all notification handlers that have been issued // have completed, then adds n to the barrier. // // The caller must hold s.mu, but the lock is released during the wait to avert // a deadlock with handlers calling back into the server. See #27. // s.nbar counts the number of notifications that have been issued and are not // yet complete. func (s *Server) waitForBarrier(n int) { s.mu.Unlock() defer s.mu.Lock() s.nbar.Wait() s.nbar.Add(n) } // dispatch constructs a function that invokes each of the specified tasks. // The caller must hold s.mu when calling dispatch, but the returned function // should be executed outside the lock to wait for the handlers to return. // // dispatch blocks until any notification received prior to this batch has // completed, to ensure that notifications are processed in a partial order // that respects order of receipt. Notifications within a batch are handled // concurrently. func (s *Server) dispatch(next jmessages, ch sender) func() error { // Resolve all the task handlers or record errors. start := time.Now() tasks := s.checkAndAssign(next) // Ensure all notifications already issued have completed; see #24. todo, notes := tasks.numToDo() s.waitForBarrier(notes) return func() error { var wg sync.WaitGroup for _, t := range tasks { if t.err != nil { continue // nothing to do here; this task has already failed } todo-- if todo == 0 { t.val, t.err = s.invoke(t.ctx, t.m, t.hreq) if t.hreq.IsNotification() { s.nbar.Done() } break } t := t wg.Add(1) go func() { defer wg.Done() t.val, t.err = s.invoke(t.ctx, t.m, t.hreq) if t.hreq.IsNotification() { s.nbar.Done() } }() } // Wait for all the handlers to return, then deliver any responses. wg.Wait() return s.deliver(tasks.responses(s.rpcLog), ch, time.Since(start)) } } // deliver cleans up completed responses and arranges their replies (if any) to // be sent back to the client. func (s *Server) deliver(rsps jmessages, ch sender, elapsed time.Duration) error { if len(rsps) == 0 { return nil } s.log("Completed %d requests [%v elapsed]", len(rsps), elapsed) s.mu.Lock() defer s.mu.Unlock() // Cancel the contexts of all the inflight requests that were executed. // The extra check is necessary, to prevent a duplicate request from // cancelling its valid predecessor in that ID. for _, rsp := range rsps { if rsp.err == nil { s.cancel(string(rsp.ID)) } } nw, err := encode(ch, rsps) s.metrics.CountAndSetMax("rpc.bytesWritten", int64(nw)) return err } // checkAndAssign resolves all the task handlers for the given batch, or // records errors for them as appropriate. The caller must hold s.mu. func (s *Server) checkAndAssign(next jmessages) tasks { var ts tasks var ids []string dup := make(map[string]*task) // :: id ⇒ first task in batch with id // Phase 1: Filter out responses from push calls and check for duplicate // request ID.s for _, req := range next { fid := fixID(req.ID) id := string(fid) if !req.isRequestOrNotification() && s.call[id] != nil { // This is a result or error for a pending push-call. // // N.B. It is important to check for this before checking for // duplicate request IDs, since the ID spaces could overlap. rsp := s.call[id] delete(s.call, id) rsp.ch <- req continue // don't send a reply for this } else if req.err != nil { // keep the existing error } else if !s.versionOK(req.V) { req.err = ErrInvalidVersion } t := &task{ hreq: &Request{id: fid, method: req.M, params: req.P}, batch: req.batch, err: req.err, } if old := dup[id]; old != nil { // A previous task already used this ID, fail both. old.err = errDuplicateID.WithData(id) t.err = old.err } else if id != "" && s.used[id] != nil { // A task from a previous batch already used this ID, fail this one. t.err = errDuplicateID.WithData(id) } else if id != "" { // This is the first task with this ID in the batch. dup[id] = t } ts = append(ts, t) ids = append(ids, id) } // Phase 2: Assign method handlers and set up contexts. for i, t := range ts { id := ids[i] if t.err != nil { // deferred validation error } else if t.hreq.method == "" { t.err = errEmptyMethod } else if s.setContext(t, id) { t.m = s.assign(t.ctx, t.hreq.method) if t.m == nil { t.err = errNoSuchMethod.WithData(t.hreq.method) } } if t.err != nil { s.log("Request check error for %q (params %q): %v", t.hreq.method, string(t.hreq.params), t.err) s.metrics.Count("rpc.errors", 1) } } return ts } // setContext constructs and attaches a request context to t, and reports // whether this succeeded. func (s *Server) setContext(t *task, id string) bool { base, params, err := s.dectx(s.newctx(), t.hreq.method, t.hreq.params) t.hreq.params = params if err != nil { t.err = Errorf(code.InternalError, "invalid request context: %v", err) return false } t.ctx = context.WithValue(base, inboundRequestKey{}, t.hreq) // Store the cancellation for a request that needs a reply, so that we can // respond to cancellation requests. if id != "" { ctx, cancel := context.WithCancel(t.ctx) s.used[id] = cancel t.ctx = ctx } return true } // invoke invokes the handler m for the specified request type, and marshals // the return value into JSON if there is one. func (s *Server) invoke(base context.Context, h Handler, req *Request) (json.RawMessage, error) { ctx := context.WithValue(base, serverKey{}, s) if err := s.sem.Acquire(ctx, 1); err != nil { return nil, err } defer s.sem.Release(1) s.rpcLog.LogRequest(ctx, req) v, err := h.Handle(ctx, req) if err != nil { if req.IsNotification() { s.log("Discarding error from notification to %q: %v", req.Method(), err) return nil, nil // a notification } return nil, err // a call reporting an error } return json.Marshal(v) } // ServerInfo returns an atomic snapshot of the current server info for s. func (s *Server) ServerInfo() *ServerInfo { info := &ServerInfo{ Methods: []string{"*"}, StartTime: s.start, Counter: make(map[string]int64), MaxValue: make(map[string]int64), Label: make(map[string]interface{}), } if n, ok := s.mux.(Namer); ok { info.Methods = n.Names() } s.metrics.Snapshot(metrics.Snapshot{ Counter: info.Counter, MaxValue: info.MaxValue, Label: info.Label, }) return info } // ErrPushUnsupported is returned by the Notify and Call methods if server // pushes are not enabled. var ErrPushUnsupported = errors.New("server push is not enabled") // Notify posts a single server-side notification to the client. // // This is a non-standard extension of JSON-RPC, and may not be supported by // all clients. Unless s was constructed with the AllowPush option set true, // this method will always report an error (ErrPushUnsupported) without sending // anything. If Notify is called after the client connection is closed, it // returns ErrConnClosed. func (s *Server) Notify(ctx context.Context, method string, params interface{}) error { if !s.allowP { return ErrPushUnsupported } _, err := s.pushReq(ctx, false /* no ID */, method, params) return err } // Callback posts a single server-side call to the client. It blocks until a // reply is received, ctx ends, or the client connection terminates. A // successful callback reports a nil error and a non-nil response. Errors // returned by the client have concrete type *jrpc2.Error. // // This is a non-standard extension of JSON-RPC, and may not be supported by // all clients. If you are not sure whether the client supports push calls, you // should set a deadline on ctx, otherwise the callback may block forever for a // client response that will never arrive. // // Unless s was constructed with the AllowPush option set true, this method // will always report an error (ErrPushUnsupported) without sending // anything. If Callback is called after the client connection is closed, it // returns ErrConnClosed. func (s *Server) Callback(ctx context.Context, method string, params interface{}) (*Response, error) { if !s.allowP { return nil, ErrPushUnsupported } rsp, err := s.pushReq(ctx, true /* set ID */, method, params) if err != nil { return nil, err } rsp.wait() if err := rsp.Error(); err != nil { return nil, filterError(err) } return rsp, nil } // waitCallback blocks until pctx ends, and then if p is still waiting for a // response, deliver an error to the caller. func (s *Server) waitCallback(pctx context.Context, id string, p *Response) { <-pctx.Done() s.mu.Lock() defer s.mu.Unlock() if _, ok := s.call[id]; !ok { return } delete(s.call, id) err := pctx.Err() s.log("Context ended for callback id %q, err=%v", id, err) p.ch <- &jmessage{ ID: json.RawMessage(id), E: &Error{Code: code.FromError(err), Message: err.Error()}, } } func (s *Server) pushReq(ctx context.Context, wantID bool, method string, params interface{}) (rsp *Response, _ error) { var bits []byte if params != nil { v, err := json.Marshal(params) if err != nil { return nil, err } bits = v } s.mu.Lock() defer s.mu.Unlock() if s.ch == nil { return nil, ErrConnClosed } kind := "notification" var jid json.RawMessage if wantID { kind = "call" id := strconv.FormatInt(s.callID, 10) s.callID++ cbctx, cancel := context.WithCancel(ctx) jid = json.RawMessage(id) rsp = &Response{ ch: make(chan *jmessage, 1), id: id, cancel: cancel, } s.call[id] = rsp go s.waitCallback(cbctx, id, rsp) } s.log("Posting server %s %q %s", kind, method, string(bits)) nw, err := encode(s.ch, jmessages{{ ID: jid, M: method, P: bits, }}) s.metrics.CountAndSetMax("rpc.bytesWritten", int64(nw)) s.metrics.Count("rpc."+kind+"sPushed", 1) return rsp, err } // Metrics returns the server metrics collector for s. If s does not define a // collector, this method returns nil, which is ready for use but discards all // metrics. func (s *Server) Metrics() *metrics.M { return s.metrics } // Stop shuts down the server. It is safe to call this method multiple times or // from concurrent goroutines; it will only take effect once. func (s *Server) Stop() { s.mu.Lock() defer s.mu.Unlock() s.stop(errServerStopped) } // ServerStatus describes the status of a stopped server. // // A server is said to have succeeded if it stopped because the client channel // closed or because its Stop method was called. On success, Err == nil, and // the flag fields indicate the reason why the server exited. // Otherwise, Err != nil is the error value that caused the server to exit. type ServerStatus struct { Err error // the error that caused the server to stop (nil on success) // On success, these flags explain the reason why the server stopped. // At most one of these fields will be true. Stopped bool // server exited because Stop was called Closed bool // server exited because the client channel closed } // Success reports whether the server exited without error. func (s ServerStatus) Success() bool { return s.Err == nil } // WaitStatus blocks until the server terminates, and returns the resulting // status. After WaitStatus returns, whether or not there was an error, it is // safe to call s.Start again to restart the server with a fresh channel. func (s *Server) WaitStatus() ServerStatus { s.wg.Wait() // Postcondition check. if !s.inq.isEmpty() { panic("s.inq is not empty at shutdown") } stat := ServerStatus{Err: s.err} if s.err == io.EOF || channel.IsErrClosing(s.err) { stat.Err = nil stat.Closed = true } else if s.err == errServerStopped { stat.Err = nil stat.Stopped = true } return stat } // Wait blocks until the server terminates and returns the resulting error. // It is equivalent to s.WaitStatus().Err. func (s *Server) Wait() error { return s.WaitStatus().Err } // stop shuts down the connection and records err as its final state. The // caller must hold s.mu. If multiple callers invoke stop, only the first will // successfully record its error status. func (s *Server) stop(err error) { if s.ch == nil { return // nothing is running } s.log("Server signaled to stop with err=%v", err) s.ch.Close() // Remove any pending requests from the queue, but retain notifications. // The server will process pending notifications before giving up. // // TODO(@creachadair): We need better tests for this behaviour. var keep jmessages s.inq.each(func(cur jmessages) { for _, req := range cur { if req.isNotification() { keep = append(keep, req) s.log("Retaining notification %p", req) } else { s.cancel(string(req.ID)) } } }) s.inq.reset() for _, elt := range keep { s.inq.push(jmessages{elt}) } close(s.work) // Cancel any in-flight requests that made it out of the queue, and // terminate any pending callback invocations. for _, rsp := range s.call { rsp.cancel() // the waiter will clean up the map } for id, cancel := range s.used { cancel() delete(s.used, id) } // Postcondition check. if len(s.used) != 0 { panic("s.used is not empty at shutdown") } s.err = err s.ch = nil s.metrics.Count("rpc.serversActive", -1) } // read is the main receiver loop, decoding requests from the client and adding // them to the queue. Decoding errors and message-format problems are handled // and reported back to the client directly, so that any message that survives // into the request queue is structurally valid. func (s *Server) read(ch receiver) { for { // If the message is not sensible, report an error; otherwise enqueue it // for processing. Errors in individual requests are handled later. var in jmessages var derr error bits, err := ch.Recv() s.metrics.CountAndSetMax("rpc.bytesRead", int64(len(bits))) if err == nil || (err == io.EOF && len(bits) != 0) { err = nil derr = in.parseJSON(bits) s.metrics.Count("rpc.requests", int64(len(in))) } s.mu.Lock() if err != nil { // receive failure; shut down s.stop(err) s.mu.Unlock() return } else if derr != nil { // parse failure; report and continue s.pushError(derr) } else if len(in) == 0 { s.pushError(errEmptyBatch) } else { s.log("Received request batch of size %d (qlen=%d)", len(in), s.inq.size()) s.inq.push(in) if s.inq.size() == 1 { // the queue was empty s.signal() } } s.mu.Unlock() } } // ServerInfo is the concrete type of responses from the rpc.serverInfo method. type ServerInfo struct { // The list of method names exported by this server. Methods []string `json:"methods,omitempty"` // Metric values defined by the evaluation of methods. Counter map[string]int64 `json:"counters,omitempty"` MaxValue map[string]int64 `json:"maxValue,omitempty"` Label map[string]interface{} `json:"labels,omitempty"` // When the server started. StartTime time.Time `json:"startTime,omitempty"` } // assign returns a Handler to handle the specified name, or nil. // The caller must hold s.mu. func (s *Server) assign(ctx context.Context, name string) Handler { if s.builtin && strings.HasPrefix(name, "rpc.") { switch name { case rpcServerInfo: return methodFunc(s.handleRPCServerInfo) default: return nil // reserved } } return s.mux.Assign(ctx, name) } // pushError reports an error for the given request ID directly back to the // client, bypassing the normal request handling mechanism. The caller must // hold s.mu when calling this method. func (s *Server) pushError(err error) { s.log("Invalid request: %v", err) var jerr *Error if e, ok := err.(*Error); ok { jerr = e } else { jerr = &Error{Code: code.FromError(err), Message: err.Error()} } nw, err := encode(s.ch, jmessages{{ ID: json.RawMessage("null"), E: jerr, }}) s.metrics.Count("rpc.errors", 1) s.metrics.CountAndSetMax("rpc.bytesWritten", int64(nw)) if err != nil { s.log("Writing error response: %v", err) } } // cancel reports whether id is an active call. If so, it also calls the // cancellation function associated with id and removes it from the // reservations. The caller must hold s.mu. func (s *Server) cancel(id string) bool { cancel, ok := s.used[id] if ok { cancel() delete(s.used, id) } return ok } func (s *Server) versionOK(v string) bool { return v == Version } // A task represents a pending method invocation received by the server. type task struct { m Handler // the assigned handler (after assignment) ctx context.Context // the context passed to the handler hreq *Request // the request passed to the handler batch bool // whether the request was part of a batch val json.RawMessage // the result value (when complete) err error // the error value (when complete) } type tasks []*task func (ts tasks) responses(rpcLog RPCLogger) jmessages { var rsps jmessages for _, task := range ts { if task.hreq.id == nil { // Spec: "The Server MUST NOT reply to a Notification, including // those that are within a batch request. Notifications are not // confirmable by definition, since they do not have a Response // object to be returned. As such, the Client would not be aware of // any errors." // // However, parse and validation errors must still be reported, with // an ID of null if the request ID was not resolvable. if c := code.FromError(task.err); c != code.ParseError && c != code.InvalidRequest { continue } } rsp := &jmessage{ID: task.hreq.id, batch: task.batch} if rsp.ID == nil { rsp.ID = json.RawMessage("null") } if task.m == nil { // No method was ever assigned for this task, so it was never run. rsp.err = errors.New("task not executed") } if task.err == nil { rsp.R = task.val } else if e, ok := task.err.(*Error); ok { rsp.E = e } else if c := code.FromError(task.err); c != code.NoError { rsp.E = &Error{Code: c, Message: task.err.Error()} } else { rsp.E = &Error{Code: code.InternalError, Message: task.err.Error()} } rpcLog.LogResponse(task.ctx, &Response{ id: string(rsp.ID), err: rsp.E, result: rsp.R, }) rsps = append(rsps, rsp) } return rsps } // numToDo reports the number of tasks in ts that need to be executed, and the // number of those that are notifications. func (ts tasks) numToDo() (todo, notes int) { for _, t := range ts { if t.err == nil { todo++ if t.hreq.IsNotification() { notes++ } } } return }