715 lines
21 KiB
Go

package jrpc2
import (
"container/list"
"context"
"encoding/json"
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/creachadair/jrpc2/channel"
"github.com/creachadair/jrpc2/code"
"github.com/creachadair/jrpc2/metrics"
"golang.org/x/sync/semaphore"
)
// logger is the signature of the server's debug logging hook (the log field
// of Server). Call sites in this file pass a Printf-style format string
// followed by its arguments.
type logger = func(string, ...interface{})
// A Server is a JSON-RPC 2.0 server. The server receives requests and sends
// responses on a channel.Channel provided by the caller, and dispatches
// requests to user-defined Handlers.
//
// The configuration fields before mu are populated by NewServer (start may
// also be set by Start). The fields after mu are guarded by mu, except nbar:
// it is a WaitGroup whose Done calls are made by handler goroutines outside
// the lock.
type Server struct {
	wg      sync.WaitGroup      // ready when workers are done at shutdown time
	mux     Assigner            // associates method names with handlers
	sem     *semaphore.Weighted // bounds concurrent execution (default 1)
	allow1  bool                // allow v1 requests with no version marker
	allowP  bool                // allow server notifications to the client
	log     logger              // write debug logs here
	rpcLog  RPCLogger           // log RPC requests and responses here
	dectx   decoder             // decode context from request
	ckreq   verifier            // request checking hook
	expctx  bool                // whether to expect request context
	metrics *metrics.M          // metrics collected during execution
	start   time.Time           // when Start was called
	builtin bool                // whether built-in rpc.* methods are enabled

	mu   *sync.Mutex    // protects the fields below
	nbar sync.WaitGroup // notification barrier (see the dispatch method)
	err  error          // error from a previous operation
	work *sync.Cond     // for signaling message availability (uses mu)
	inq  *list.List     // inbound requests awaiting processing
	ch   channel.Channel // the channel to the client; nil when not running

	// For each request ID currently in-flight, this map carries a cancel
	// function attached to the context that was sent to the handler.
	used map[string]context.CancelFunc

	// For each push-call ID currently in flight, this map carries the response
	// waiting for its reply.
	call   map[string]*Response
	callID int64 // next ID to assign to a push-call (see pushReq)
}
// NewServer returns a new unstarted server that will dispatch incoming
// JSON-RPC requests according to mux. To start serving, call Start.
//
// N.B. It is only safe to modify mux after the server has been started if mux
// itself is safe for concurrent use by multiple goroutines.
//
// This function will panic if mux == nil.
func NewServer(mux Assigner, opts *ServerOptions) *Server {
	if mux == nil {
		panic("nil assigner")
	}
	decodeCtx, expectCtx := opts.decodeContext()

	// The condition variable shares the server's mutex, so build the lock
	// first and wire both into the struct together.
	lock := new(sync.Mutex)
	srv := &Server{
		mux:     mux,
		sem:     semaphore.NewWeighted(opts.concurrency()),
		allow1:  opts.allowV1(),
		allowP:  opts.allowPush(),
		log:     opts.logger(),
		rpcLog:  opts.rpcLog(),
		dectx:   decodeCtx,
		ckreq:   opts.checkRequest(),
		expctx:  expectCtx,
		mu:      lock,
		metrics: opts.metrics(),
		start:   opts.startTime(),
		builtin: opts.allowBuiltin(),
		inq:     list.New(),
		used:    make(map[string]context.CancelFunc),
		call:    make(map[string]*Response),
		callID:  1,
		work:    sync.NewCond(lock),
	}
	return srv
}
// Start enables processing of requests from c and returns s. It panics if the
// server is already running (that is, if it currently has a channel).
//
// Start launches two maintenance goroutines, both tracked by s.wg: one that
// reads requests from c into the queue, and one that drains the queue and
// dispatches requests to handlers.
func (s *Server) Start(c channel.Channel) *Server {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.ch != nil {
		panic("server is already running")
	}

	// Forget any error from a previous run, and attach the new channel.
	s.err = nil
	s.ch = c
	if s.start.IsZero() {
		s.start = time.Now().In(time.UTC)
	}

	// s.wg covers these two goroutines plus one goroutine per request batch
	// in flight (added by serve), so it completes only once the server has
	// fully wound down.
	reader := func() {
		defer s.wg.Done()
		s.read(c) // accept requests from the client and enqueue them
	}
	server := func() {
		defer s.wg.Done()
		s.serve() // remove requests from the queue and dispatch them
	}
	s.wg.Add(2)
	go reader()
	go server()
	return s
}
// serve processes requests from the queue and dispatches them to handlers.
// The responses are written back by the handler goroutines.
//
// The flow of an inbound request is:
//
//	serve             -- main serving loop
//	* nextRequest     -- process the next request batch
//	  * dispatch
//	    * assign      -- assign handlers to requests
//	    | ...
//	    |
//	    * invoke      -- invoke handlers
//	    | \ handler   -- handle an individual request
//	    | ...
//	    * deliver     -- send responses to the client
//
// serve returns when nextRequest reports an error, which happens once the
// server has stopped and the queue is empty. Each batch is completed on its
// own goroutine, tracked by s.wg. Errors from delivering responses have no
// caller to be returned to, so they are logged rather than dropped.
func (s *Server) serve() {
	for {
		next, err := s.nextRequest()
		if err != nil {
			s.log("Reading next request: %v", err)
			return
		}
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			if derr := next(); derr != nil {
				s.log("Delivering responses: %v", derr)
			}
		}()
	}
}
// nextRequest waits until a request batch is queued and returns a function
// that runs that batch through its handlers. A non-nil error is returned only
// when the server has shut down and nothing remains in the queue; errors
// produced by handlers are sent to the client, not returned here.
//
// The caller must invoke the returned function to complete the request.
func (s *Server) nextRequest() (func() error, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Sleep until there is work to do or the server has shut down.
	for s.inq.Len() == 0 && s.ch != nil {
		s.work.Wait()
	}
	// If the queue is still empty, the wait above can only have ended because
	// the channel is gone: report whatever stopped the server.
	if s.inq.Len() == 0 {
		return nil, s.err
	}

	// Capture the channel now; dispatch releases the lock while it waits.
	out := s.ch
	batch := s.inq.Remove(s.inq.Front()).(jmessages)
	s.log("Processing %d requests", len(batch))
	// The dispatcher runs the handlers outside the lock.
	return s.dispatch(batch, out), nil
}
// waitForBarrier blocks until all notification handlers that have been issued
// have completed, then adds n to the barrier.
//
// The caller must hold s.mu, but the lock is released during the wait to avert
// a deadlock with handlers calling back into the server. See #27. The lock is
// held again by the time waitForBarrier returns.
//
// s.nbar counts the number of notifications that have been issued and are not
// yet complete; each counted notification calls s.nbar.Done when its handler
// returns (see dispatch).
func (s *Server) waitForBarrier(n int) {
	s.mu.Unlock()
	defer s.mu.Lock() // reacquire the caller's lock before returning
	s.nbar.Wait()
	s.nbar.Add(n)
}
// dispatch constructs a function that invokes each of the specified tasks.
// The caller must hold s.mu when calling dispatch, but the returned function
// should be executed outside the lock to wait for the handlers to return.
//
// dispatch blocks until any notification received prior to this batch has
// completed, to ensure that notifications are processed in a partial order
// that respects order of receipt. Notifications within a batch are handled
// concurrently. While it waits, s.mu is temporarily released (see
// waitForBarrier).
//
// The returned function runs every valid task, waits for all of them, and
// then sends the batch's responses on ch via deliver, returning its error.
func (s *Server) dispatch(next jmessages, ch channel.Sender) func() error {
	// Resolve all the task handlers or record errors.
	start := time.Now()
	tasks := s.checkAndAssign(next)
	last := len(tasks) - 1

	// Ensure all notifications already issued have completed; see #24.
	// This also registers this batch's valid notifications with s.nbar.
	s.waitForBarrier(tasks.numValidNotifications())
	return func() error {
		var wg sync.WaitGroup
		for i, t := range tasks {
			if t.err != nil {
				continue // nothing to do here; this task has already failed
			}
			t := t // per-iteration copy captured by the closure below
			wg.Add(1)
			run := func() {
				defer wg.Done()
				if t.hreq.IsNotification() {
					// Balances the count added by waitForBarrier above, which
					// counted exactly the error-free notifications.
					defer s.nbar.Done()
				}
				t.val, t.err = s.invoke(t.ctx, t.m, t.hreq)
			}
			// Every task except the last runs on its own goroutine; when the
			// last task is valid, it runs on the current goroutine.
			if i < last {
				go run()
			} else {
				run()
			}
		}
		// Wait for all the handlers to return, then deliver any responses.
		wg.Wait()
		return s.deliver(tasks.responses(s.rpcLog), ch, time.Since(start))
	}
}
// deliver cleans up completed responses and arranges their replies (if any) to
// be sent back to the client. It cancels the contexts of the requests being
// answered, then encodes rsps to ch and records the number of bytes written.
//
// ch is nil when the batch was dispatched after the server stopped: stop
// keeps queued notifications so that they can still run, and nextRequest then
// hands the nil channel to dispatch. An invalid notification can still yield
// an error reply. Such replies cannot be sent, so deliver reports
// ErrConnClosed instead of calling encode on a nil channel.
func (s *Server) deliver(rsps jmessages, ch channel.Sender, elapsed time.Duration) error {
	if len(rsps) == 0 {
		return nil
	}
	s.log("Completed %d requests [%v elapsed]", len(rsps), elapsed)
	s.mu.Lock()
	defer s.mu.Unlock()

	// Ensure all the inflight requests get their contexts cancelled.
	for _, rsp := range rsps {
		s.cancel(string(rsp.ID))
	}
	if ch == nil {
		return ErrConnClosed
	}
	nw, err := encode(ch, rsps)
	s.metrics.CountAndSetMax("rpc.bytesWritten", int64(nw))
	return err
}
// checkAndAssign validates each message of a batch and turns it into a task.
// A task either has a handler and context attached, or records why it failed.
// A message that is a client reply to a pending push-call is handed to the
// waiting caller instead, and produces no task. The caller must hold s.mu.
//
// NOTE(review): a client request whose ID matches a pending push-call ID is
// reported as a duplicate only if that ID is also in s.used; push-call IDs and
// client request IDs are compared as plain strings here.
func (s *Server) checkAndAssign(next jmessages) tasks {
	var out tasks
	for _, req := range next {
		s.log("Checking request for %q: %s", req.M, string(req.P))
		fid := fixID(req.ID)
		id := string(fid)
		t := &task{
			hreq:  &Request{id: fid, method: req.M, params: req.P},
			batch: req.batch,
		}

		// The cases are checked in order; the first match decides the outcome.
		switch {
		case req.err != nil:
			t.err = req.err // deferred validation error
		case id != "" && s.used[id] != nil:
			t.err = Errorf(code.InvalidRequest, "duplicate request id %q", id)
		case !s.versionOK(req.V):
			t.err = ErrInvalidVersion
		case !req.isRequestOrNotification() && s.call[id] != nil:
			// A result or error for a pending push-call: pass it along to the
			// waiter, and don't send a reply for it.
			pending := s.call[id]
			delete(s.call, id)
			pending.ch <- req
			continue
		case req.M == "":
			t.err = Errorf(code.InvalidRequest, "empty method name")
		case s.setContext(t, id):
			if t.m = s.assign(t.ctx, req.M); t.m == nil {
				t.err = Errorf(code.MethodNotFound, "no such method %q", req.M)
			}
		}

		if t.err != nil {
			s.log("Task error: %v", t.err)
			s.metrics.Count("rpc.errors", 1)
		}
		out = append(out, t)
	}
	return out
}
// setContext builds the handler context for t and stores it in t.ctx,
// reporting whether that succeeded. On failure t.err explains why.
//
// The context is decoded from the request by s.dectx, which may also rewrite
// the request parameters, and is then checked with s.ckreq. It carries the
// inbound request itself. When id is non-empty (the request expects a reply),
// the context is also made cancellable and its cancel function is recorded in
// s.used, so that rpc.cancel and shutdown can reach it.
func (s *Server) setContext(t *task, id string) bool {
	ctx, params, err := s.dectx(context.Background(), t.hreq.method, t.hreq.params)
	t.hreq.params = params
	if err != nil {
		t.err = Errorf(code.InternalError, "invalid request context: %v", err)
		return false
	}
	if cerr := s.ckreq(ctx, t.hreq); cerr != nil {
		t.err = cerr
		return false
	}

	ctx = context.WithValue(ctx, inboundRequestKey{}, t.hreq)
	if id != "" {
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		s.used[id] = cancel
	}
	t.ctx = ctx
	return true
}
// invoke runs handler h on req under the server's concurrency limit, and
// returns the handler's result marshaled as JSON.
//
// The handler's context is base extended with a reference to s. An error from
// acquiring the semaphore (for example when the context is cancelled) is
// returned as-is. Errors from a notification's handler are logged and
// discarded, since a notification has no reply.
func (s *Server) invoke(base context.Context, h Handler, req *Request) (json.RawMessage, error) {
	ctx := context.WithValue(base, serverKey{}, s)
	if err := s.sem.Acquire(ctx, 1); err != nil {
		return nil, err
	}
	defer s.sem.Release(1)

	s.rpcLog.LogRequest(ctx, req)
	result, err := h.Handle(ctx, req)
	switch {
	case err == nil:
		return json.Marshal(result)
	case req.IsNotification():
		s.log("Discarding error from notification to %q: %v", req.Method(), err)
		return nil, nil // a notification
	default:
		return nil, err // a call reporting an error
	}
}
// ServerInfo returns an atomic snapshot of the current server info for s.
// The method names come from the server's Assigner, and the counter, max-value
// and label maps are filled from a snapshot of the server's metrics.
func (s *Server) ServerInfo() *ServerInfo {
	names := s.mux.Names()
	snap := metrics.Snapshot{
		Counter:  make(map[string]int64),
		MaxValue: make(map[string]int64),
		Label:    make(map[string]string),
	}
	s.metrics.Snapshot(snap) // the maps are shared, so this fills them in place
	return &ServerInfo{
		Methods:     names,
		UsesContext: s.expctx,
		StartTime:   s.start,
		Counter:     snap.Counter,
		MaxValue:    snap.MaxValue,
		Label:       snap.Label,
	}
}
// Notify posts a single server-side notification to the client.
//
// This is a non-standard extension of JSON-RPC, and may not be supported by
// all clients. Unless s was constructed with the AllowPush option set true,
// this method will always report an error (ErrPushUnsupported) without sending
// anything. If Notify is called after the client connection is closed, it
// returns ErrConnClosed.
func (s *Server) Notify(ctx context.Context, method string, params interface{}) error {
	if s.allowP {
		// A notification carries no ID, so no response is expected.
		_, err := s.pushReq(ctx, false /* no ID */, method, params)
		return err
	}
	return ErrPushUnsupported
}
// Callback posts a single server-side call to the client. It blocks until a
// reply is received or the client connection terminates. A successful
// callback reports a nil error and a non-nil response. Errors returned by the
// client have concrete type *jrpc2.Error.
//
// This is a non-standard extension of JSON-RPC, and may not be supported by
// all clients. Unless s was constructed with the AllowPush option set true,
// this method will always report an error (ErrPushUnsupported) without sending
// anything. If Callback is called after the client connection is closed, it
// returns ErrConnClosed.
//
// NOTE(review): the wait for the reply (rsp.wait) takes no context, so
// cancelling ctx does not currently interrupt a pending Callback.
func (s *Server) Callback(ctx context.Context, method string, params interface{}) (*Response, error) {
	if !s.allowP {
		return nil, ErrPushUnsupported
	}
	rsp, err := s.pushReq(ctx, true /* set ID */, method, params)
	if err != nil {
		return nil, err
	}

	rsp.wait() // block until the reply arrives
	if rerr := rsp.Error(); rerr != nil {
		return nil, filterError(rerr)
	}
	return rsp, nil
}
// pushReq encodes and sends a server-initiated request to the client.
//
// If wantID is true, the request is a call. It is assigned a fresh numeric ID,
// and a pending *Response is registered in s.call to receive the client's
// reply (see checkAndAssign). Otherwise the request is a notification, and rsp
// is nil.
//
// pushReq reports ErrConnClosed if the server is not connected, or any error
// from marshaling params or writing the request. If the write fails for a
// call, the pending reservation is removed, because no reply can ever arrive
// for it, and a nil response is returned.
//
// NOTE(review): ctx is currently unused.
func (s *Server) pushReq(ctx context.Context, wantID bool, method string, params interface{}) (rsp *Response, _ error) {
	var bits []byte
	if params != nil {
		v, err := json.Marshal(params)
		if err != nil {
			return nil, err
		}
		bits = v
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.ch == nil {
		return nil, ErrConnClosed
	}

	kind := "notification"
	var jid json.RawMessage
	if wantID {
		kind = "call"
		id := strconv.FormatInt(s.callID, 10)
		s.callID++
		jid = json.RawMessage(id)
		rsp = &Response{
			ch:     make(chan *jmessage, 1),
			id:     id,
			cancel: func() {},
		}
		s.call[id] = rsp
	}

	s.log("Posting server %s %q %s", kind, method, string(bits))
	nw, err := encode(s.ch, jmessages{{
		V:  Version,
		ID: jid,
		M:  method,
		P:  bits,
	}})
	s.metrics.CountAndSetMax("rpc.bytesWritten", int64(nw))
	s.metrics.Count("rpc."+kind+"s", 1)
	if err != nil {
		if rsp != nil {
			// The call never reached the client; don't leave it pending.
			delete(s.call, rsp.id)
		}
		return nil, err
	}
	return rsp, nil
}
// Stop shuts down the server. It is safe to call this method multiple times or
// from concurrent goroutines; it will only take effect once.
//
// Stop does not wait for the server's goroutines or in-flight handlers to
// finish; call Wait or WaitStatus for that.
func (s *Server) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.stop(errServerStopped)
}
// ServerStatus describes the status of a stopped server.
type ServerStatus struct {
	Err error // the error that caused the server to stop (nil on success)

	stopped bool // whether Stop was called
}

// Success reports whether the server exited without error.
func (s ServerStatus) Success() bool { return s.Err == nil }

// Stopped reports whether the server exited cleanly because Stop was called.
func (s ServerStatus) Stopped() bool { return s.Success() && s.stopped }

// Closed reports whether the server exited cleanly because its channel was
// closed, rather than because Stop was called.
func (s ServerStatus) Closed() bool { return s.Success() && !s.stopped }
// WaitStatus blocks until the server terminates, and returns the resulting
// status. After WaitStatus returns, whether or not there was an error, it is
// safe to call s.Start again to restart the server with a fresh channel.
//
// A shutdown caused by Stop, EOF, or a closing channel is not treated as a
// failure: the reported Err is nil in those cases.
func (s *Server) WaitStatus() ServerStatus {
	s.wg.Wait()

	// Postcondition check.
	if s.inq.Len() != 0 {
		panic("s.inq is not empty at shutdown")
	}

	final := s.err
	wasStopped := final == errServerStopped
	// Don't remark on a closed channel or EOF as a noteworthy failure.
	if wasStopped || final == io.EOF || channel.IsErrClosing(final) {
		final = nil
	}
	return ServerStatus{Err: final, stopped: wasStopped}
}
// Wait blocks until the server terminates and returns the resulting error.
// It is equivalent to s.WaitStatus().Err.
func (s *Server) Wait() error {
	status := s.WaitStatus()
	return status.Err
}
// stop shuts down the connection and records err as its final state. The
// caller must hold s.mu. If multiple callers invoke stop, only the first will
// successfully record its error status.
//
// stop does the following:
//   - Discards queued requests and cancels their contexts. Queued
//     notifications are kept so that the serve loop can still run them.
//   - Cancels the contexts of in-flight requests.
//   - Completes any push-calls still waiting for a client reply (see Callback)
//     with an error derived from err. Otherwise their callers would block
//     forever on a reply that can no longer arrive.
func (s *Server) stop(err error) {
	if s.ch == nil {
		return // nothing is running
	}
	s.log("Server signaled to stop with err=%v", err)
	s.ch.Close()

	// Remove any pending requests from the queue, but retain notifications.
	// The server will process pending notifications before giving up.
	//
	// TODO(@creachadair): We need better tests for this behaviour.
	var keep jmessages
	for cur := s.inq.Front(); cur != nil; cur = s.inq.Front() {
		for _, req := range cur.Value.(jmessages) {
			if req.isNotification() {
				keep = append(keep, req)
				s.log("Retaining notification %p", req)
			} else {
				s.cancel(string(req.ID))
			}
		}
		s.inq.Remove(cur)
	}
	for _, elt := range keep {
		s.inq.PushBack(jmessages{elt})
	}
	s.work.Broadcast()

	// Cancel any in-flight requests that made it out of the queue.
	for id, cancel := range s.used {
		cancel()
		delete(s.used, id)
	}

	// Complete any push-calls that are still awaiting a reply from the client.
	if len(s.call) != 0 {
		msg := "server stopped"
		if err != nil {
			msg = err.Error()
		}
		for id, rsp := range s.call {
			delete(s.call, id)
			// Each pending channel has capacity 1, and entries are removed from
			// s.call before anything is sent to them, so this cannot block.
			rsp.ch <- &jmessage{
				V:  Version,
				ID: json.RawMessage(id),
				E:  &Error{code: code.FromError(err), message: msg},
			}
		}
	}

	// Postcondition check.
	if len(s.used) != 0 {
		panic("s.used is not empty at shutdown")
	}
	s.err = err
	s.ch = nil
}
// read is the main receiver loop, decoding requests from the client and adding
// them to the queue. Decoding errors and message-format problems are handled
// and reported back to the client directly, so that any message that survives
// into the request queue is structurally valid.
//
// read returns after a receive failure, which it records by stopping the
// server. It also returns as soon as it finds the server already stopped
// (for example by a concurrent Stop while Recv was in progress). In that case
// there is no channel to report errors on, and queueing new work after
// shutdown would be wrong.
func (s *Server) read(ch channel.Receiver) {
	for {
		// If the message is not sensible, report an error; otherwise enqueue it
		// for processing. Errors in individual requests are handled later.
		var in jmessages
		var derr error
		bits, err := ch.Recv()
		s.metrics.CountAndSetMax("rpc.bytesRead", int64(len(bits)))
		if err == nil || (err == io.EOF && len(bits) != 0) {
			err = nil
			derr = in.parseJSON(bits)
			s.metrics.Count("rpc.requests", int64(len(in)))
		}

		s.mu.Lock()
		switch {
		case err != nil: // receive failure; shut down
			s.stop(err)
			s.mu.Unlock()
			return
		case s.ch == nil: // stopped while receiving; drop the input
			s.mu.Unlock()
			return
		case derr != nil: // parse failure; report and continue
			s.pushError(derr)
		case len(in) == 0:
			s.pushError(Errorf(code.InvalidRequest, "empty request batch"))
		default:
			s.log("Received %d new requests", len(in))
			s.inq.PushBack(in)
			s.work.Broadcast()
		}
		s.mu.Unlock()
	}
}
// ServerInfo is the concrete type of responses from the rpc.serverInfo method.
type ServerInfo struct {
	// The list of method names exported by this server.
	Methods []string `json:"methods,omitempty"`

	// Whether this server understands context wrappers.
	UsesContext bool `json:"usesContext"`

	// Metric values defined by the evaluation of methods.
	Counter  map[string]int64  `json:"counters,omitempty"`
	MaxValue map[string]int64  `json:"maxValue,omitempty"`
	Label    map[string]string `json:"labels,omitempty"`

	// When the server started.
	//
	// NOTE(review): encoding/json's omitempty never omits struct values, so a
	// zero StartTime is still encoded (as "0001-01-01T00:00:00Z").
	StartTime time.Time `json:"startTime,omitempty"`
}
// assign returns the Handler for the method called name, or nil if there is
// none. When built-in methods are enabled, names with the "rpc." prefix are
// reserved: only the server's own rpc.* methods resolve, and any other such
// name yields nil. All other names are resolved by the server's Assigner.
// The caller must hold s.mu.
func (s *Server) assign(ctx context.Context, name string) Handler {
	if !s.builtin || !strings.HasPrefix(name, "rpc.") {
		return s.mux.Assign(ctx, name)
	}
	switch name {
	case rpcServerInfo:
		return methodFunc(s.handleRPCServerInfo)
	case rpcCancel:
		return methodFunc(s.handleRPCCancel)
	}
	return nil // reserved
}
// pushError sends err directly back to the client as an error response with a
// null request ID, bypassing the normal request handling mechanism. An err
// that is not already an *Error is wrapped with a code from code.FromError. A
// failure to write the response is logged, not returned. The caller must hold
// s.mu when calling this method.
func (s *Server) pushError(err error) {
	s.log("Invalid request: %v", err)
	jerr, isRPCError := err.(*Error)
	if !isRPCError {
		jerr = &Error{code: code.FromError(err), message: err.Error()}
	}

	reply := jmessages{{V: Version, ID: json.RawMessage("null"), E: jerr}}
	nw, werr := encode(s.ch, reply)
	s.metrics.Count("rpc.errors", 1)
	s.metrics.CountAndSetMax("rpc.bytesWritten", int64(nw))
	if werr != nil {
		s.log("Writing error response: %v", werr)
	}
}
// cancel reports whether id names an active call. If it does, cancel also
// invokes the cancellation function recorded for id and removes id from the
// reservations. The caller must hold s.mu.
func (s *Server) cancel(id string) bool {
	cancelFn, active := s.used[id]
	if !active {
		return false
	}
	cancelFn()
	delete(s.used, id)
	return true
}
// versionOK reports whether v is an acceptable version marker. The marker
// must be exactly Version, or empty when the server allows requests with no
// version marker (allow1).
func (s *Server) versionOK(v string) bool {
	switch v {
	case "":
		return s.allow1 // an empty version is OK if the server allows it
	case Version:
		return true
	default:
		return false // ... otherwise it must match the spec
	}
}
// A task represents a pending method invocation received by the server.
// checkAndAssign creates tasks, dispatch runs them, and responses turns the
// completed tasks into reply messages.
type task struct {
	m     Handler         // the assigned handler (after assignment)
	ctx   context.Context // the context passed to the handler
	hreq  *Request        // the request passed to the handler
	batch bool            // whether the request was part of a batch

	val json.RawMessage // the result value (when complete)
	err error           // the error value (when complete); set early for invalid requests
}
// tasks is the set of tasks resolved from a single request batch.
type tasks []*task
// responses builds the reply messages for the completed tasks in ts, and logs
// each reply through rpcLog. Notifications get no reply unless they failed
// with a parse or invalid-request error. A missing ID is sent as null. A
// non-*Error failure is wrapped using its code.FromError code, or
// InternalError if that code is NoError.
func (ts tasks) responses(rpcLog RPCLogger) jmessages {
	var out jmessages
	for _, t := range ts {
		if t.hreq.id == nil {
			// Spec: "The Server MUST NOT reply to a Notification, including
			// those that are within a batch request. Notifications are not
			// confirmable by definition, since they do not have a Response
			// object to be returned. As such, the Client would not be aware of
			// any errors."
			//
			// However, parse and validation errors must still be reported, with
			// an ID of null if the request ID was not resolvable.
			switch code.FromError(t.err) {
			case code.ParseError, code.InvalidRequest:
				// report it below
			default:
				continue
			}
		}

		msg := &jmessage{V: Version, ID: t.hreq.id, batch: t.batch}
		if msg.ID == nil {
			msg.ID = json.RawMessage("null")
		}
		switch e := t.err.(type) {
		case nil:
			msg.R = t.val
		case *Error:
			msg.E = e
		default:
			c := code.FromError(e)
			if c == code.NoError {
				c = code.InternalError
			}
			msg.E = &Error{code: c, message: e.Error()}
		}

		rpcLog.LogResponse(t.ctx, &Response{
			id:     string(msg.ID),
			err:    msg.E,
			result: msg.R,
		})
		out = append(out, msg)
	}
	return out
}
// numValidNotifications counts the tasks in ts that are notifications and did
// not fail validation, i.e. the ones whose handlers will actually run.
func (ts tasks) numValidNotifications() int {
	count := 0
	for _, t := range ts {
		if t.err != nil || !t.hreq.IsNotification() {
			continue
		}
		count++
	}
	return count
}