2021-11-08 16:39:17 +00:00

385 lines
13 KiB
Go

// Package handler provides implementations of the jrpc2.Assigner interface,
// and support for adapting functions to the jrpc2.Handler interface.
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
"strings"
"github.com/creachadair/jrpc2"
"github.com/creachadair/jrpc2/code"
)
// Func is an adapter type: any function with this signature can be used as a
// jrpc2.Handler by converting it to a Func.
type Func func(context.Context, *jrpc2.Request) (interface{}, error)

// Handle satisfies the jrpc2.Handler interface. It invokes the underlying
// function with ctx and req and returns its results unchanged.
func (fn Func) Handle(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
	return fn(ctx, req)
}
// Map is a minimal jrpc2.Assigner backed by a static table from method name
// to jrpc2.Handler.
type Map map[string]jrpc2.Handler

// Assign implements part of the jrpc2.Assigner interface. It reports the
// handler stored under method; a name that is not in the map yields the zero
// (nil) handler. The context argument is unused.
func (m Map) Assign(_ context.Context, method string) jrpc2.Handler {
	h := m[method]
	return h
}

// Names implements part of the jrpc2.Assigner interface. It reports the keys
// of m in sorted order, so the result is stable despite map iteration order.
func (m Map) Names() []string {
	var keys []string
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
// ServiceMap merges several assigners into a single one, so that a server can
// export multiple services, each under its own name.
type ServiceMap map[string]jrpc2.Assigner

// Assign interprets the inbound method name as Service.Method, splitting at
// the first "." only, and delegates the Method portion to the assigner
// registered for Service. It returns nil if method contains no "." or if
// Service is not a key of m.
func (m ServiceMap) Assign(ctx context.Context, method string) jrpc2.Handler {
	pieces := strings.SplitN(method, ".", 2)
	if len(pieces) < 2 {
		return nil // no separator, so there is no service prefix
	}
	service, rest := pieces[0], pieces[1]

	assigner, found := m[service]
	if !found {
		return nil
	}
	return assigner.Assign(ctx, rest)
}

// Names reports the fully-qualified names of every method of every service in
// m, each in the form Service.Method, in sorted order.
func (m ServiceMap) Names() []string {
	var qualified []string
	for service, sub := range m {
		prefix := service + "."
		for _, method := range sub.Names() {
			qualified = append(qualified, prefix+method)
		}
	}
	sort.Strings(qualified)
	return qualified
}
// New adapts a function to a jrpc2.Handler. The concrete value of fn must be a
// function accepted by Check. The resulting Func will handle JSON encoding and
// decoding, call fn, and report appropriate errors.
//
// New is meant to be called while a program is being set up: it panics with
// the error from Check if the type of fn is not one of the accepted forms.
// Callers that want to handle that error themselves should call handler.Check
// directly and use the Wrap method of the resulting FuncInfo instead.
func New(fn interface{}) Func {
	info, err := Check(fn)
	if err == nil {
		return info.Wrap()
	}
	panic(err)
}
// Reflected types used to recognize the pieces of a handler function signature.
var (
	// ctxType is the interface type context.Context.
	ctxType = reflect.TypeOf((*context.Context)(nil)).Elem()

	// errType is the interface type error.
	errType = reflect.TypeOf((*error)(nil)).Elem()

	// reqType is the pointer type *jrpc2.Request.
	reqType = reflect.TypeOf((*jrpc2.Request)(nil))

	// strictType is an interface type with the single method
	// DisallowUnknownFields. Wrap tests argument types against it when strict
	// field checking is enabled.
	strictType = reflect.TypeOf((*interface{ DisallowUnknownFields() })(nil)).Elem()
)

// errNoParameters is reported by a wrapped handler that accepts no parameters
// when the inbound request carries parameters anyway.
var errNoParameters = &jrpc2.Error{
	Code:    code.InvalidParams,
	Message: "no parameters accepted",
}
// FuncInfo captures type signature information from a valid handler function.
//
// Values of this type are produced by Check, which fills in Type, Argument,
// Result, ReportsError and the original function value. The Wrap method uses
// those fields to build a Func around the function.
type FuncInfo struct {
	Type         reflect.Type // the complete function type
	Argument     reflect.Type // the non-context argument type, or nil
	Result       reflect.Type // the non-error result type, or nil
	ReportsError bool         // true if the function reports an error

	// When strictFields is set, Wrap wraps the argument before unmarshaling
	// so that unknown fields are rejected, unless the argument type already
	// implements DisallowUnknownFields itself.
	strictFields bool // enforce strict field checking

	// Wrap panics if fn is nil.
	fn interface{} // the original function value
}
// Wrap adapts the function represented by fi in a Func that satisfies the
// jrpc2.Handler interface.  The wrapped function can obtain the *jrpc2.Request
// value from its context argument using the jrpc2.InboundRequest helper.
//
// The returned Func decodes the request parameters into the argument type of
// the function (if it has one), calls the function, and converts its results
// into the (interface{}, error) pair required of a handler. Parameter decoding
// failures, and parameters sent to a function that accepts none, are reported
// as errors with code.InvalidParams without calling the function.
//
// This method panics if fi == nil or if it does not represent a valid function
// type. A FuncInfo returned by a successful call to Check is always valid.
func (fi *FuncInfo) Wrap() Func {
	if fi == nil || fi.fn == nil {
		panic("handler: invalid FuncInfo value")
	}

	// Although it is not possible to completely eliminate reflection, the
	// intent here is to hoist as much work as possible out of the body of the
	// constructed Func wrapper, since that will be executed every time the
	// handler is invoked.
	//
	// To do this, we "pre-compile" helper functions to unmarshal JSON into the
	// input arguments (newInput) and to convert the results from reflectors
	// back into values (decodeOut).  We pre-check the function signature and
	// types, so that the helpers do only as much reflection as is necessary:
	// for example, we won't allocate a parameter value if the function does not
	// accept a parameter, nor decode a return value if the function returns
	// only an error.

	// Special case: If fn has the exact signature of the Handle method, don't do
	// any (additional) reflection at all.
	if f, ok := fi.fn.(func(context.Context, *jrpc2.Request) (interface{}, error)); ok {
		return Func(f)
	}

	arg := fi.Argument

	// If strict field checking is desired, ensure arguments are wrapped.
	//
	// A function that takes only a context has a nil argument type, and calling
	// a method on a nil reflect.Type panics, so the nil check must come before
	// the Implements test. Such a function never decodes parameters, so the
	// default wrapper is fine for it.
	wrapArg := func(v reflect.Value) interface{} { return v.Interface() }
	if fi.strictFields && arg != nil && !arg.Implements(strictType) {
		wrapArg = func(v reflect.Value) interface{} { return &strict{v: v.Interface()} }
	}

	// Construct a function to unpack the parameters from the request message,
	// based on the signature of the user's callback.
	var newInput func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error)
	switch {
	case arg == nil:
		// Case 1: The function does not want any request parameters.
		// Nothing needs to be decoded, but verify no parameters were passed.
		newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
			if req.HasParams() {
				return nil, errNoParameters
			}
			return []reflect.Value{ctx}, nil
		}

	case arg == reqType:
		// Case 2: The function wants the underlying *jrpc2.Request value.
		newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
			return []reflect.Value{ctx, reflect.ValueOf(req)}, nil
		}

	case arg.Kind() == reflect.Ptr:
		// Case 3a: The function wants a pointer to its argument value.
		elem := arg.Elem() // resolved once here rather than on every call
		newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
			in := reflect.New(elem)
			if err := req.UnmarshalParams(wrapArg(in)); err != nil {
				return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
			}
			return []reflect.Value{ctx, in}, nil
		}

	default:
		// Case 3b: The function wants a bare argument value.
		newInput = func(ctx reflect.Value, req *jrpc2.Request) ([]reflect.Value, error) {
			in := reflect.New(arg) // we still need a pointer to unmarshal
			if err := req.UnmarshalParams(wrapArg(in)); err != nil {
				return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
			}
			// Indirect the pointer back off for the callee.
			return []reflect.Value{ctx, in.Elem()}, nil
		}
	}

	// Construct a function to decode the result values.
	var decodeOut func([]reflect.Value) (interface{}, error)
	switch {
	case fi.Result == nil:
		// The function returns only an error, the result is always nil.
		decodeOut = func(vals []reflect.Value) (interface{}, error) {
			if oerr := vals[0].Interface(); oerr != nil {
				return nil, oerr.(error)
			}
			return nil, nil
		}

	case !fi.ReportsError:
		// The function returns only a single non-error value: err is always nil.
		decodeOut = func(vals []reflect.Value) (interface{}, error) {
			return vals[0].Interface(), nil
		}

	default:
		// The function returns both a value and an error. If the error is
		// non-nil, the value is discarded.
		decodeOut = func(vals []reflect.Value) (interface{}, error) {
			if oerr := vals[1].Interface(); oerr != nil {
				return nil, oerr.(error)
			}
			return vals[0].Interface(), nil
		}
	}

	call := reflect.ValueOf(fi.fn).Call
	return Func(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
		args, ierr := newInput(reflect.ValueOf(ctx), req)
		if ierr != nil {
			return nil, ierr
		}
		return decodeOut(call(args))
	})
}
// Check reports whether fn can serve as a jrpc2.Handler. The concrete value of
// fn must be a function with one of the following type signature schemes, for
// JSON-marshalable types X and Y:
//
//	func(context.Context) error
//	func(context.Context) Y
//	func(context.Context) (Y, error)
//	func(context.Context, X) error
//	func(context.Context, X) Y
//	func(context.Context, X) (Y, error)
//	func(context.Context, *jrpc2.Request) error
//	func(context.Context, *jrpc2.Request) Y
//	func(context.Context, *jrpc2.Request) (Y, error)
//	func(context.Context, *jrpc2.Request) (interface{}, error)
//
// If fn does not have one of these forms, Check reports an error. On success
// it returns a FuncInfo describing the signature of fn.
//
// Note that the JSON-RPC standard restricts encoded parameter values to arrays
// and objects. Check will accept argument types that do not encode to arrays
// or objects, but the wrapper will report an error when decoding the request.
//
// The recommended solution is to define a struct type for your parameters.
// For arbitrary single value types, however, another approach is to wrap it in
// a 1-element array, for example:
//
//	func(ctx context.Context, sp [1]string) error {
//	   s := sp[0] // pull the actual argument out of the array
//	   // ...
//	}
func Check(fn interface{}) (*FuncInfo, error) {
	if fn == nil {
		return nil, errors.New("nil function")
	}
	typ := reflect.TypeOf(fn)
	if typ.Kind() != reflect.Func {
		return nil, errors.New("not a function")
	}
	info := &FuncInfo{Type: typ, fn: fn}

	// Validate the parameters: a context, optionally followed by exactly one
	// non-variadic argument.
	numIn := typ.NumIn()
	switch {
	case numIn == 0 || numIn > 2:
		return nil, errors.New("wrong number of parameters")
	case typ.In(0) != ctxType:
		return nil, errors.New("first parameter is not context.Context")
	case typ.IsVariadic():
		return nil, errors.New("variadic functions are not supported")
	case numIn == 2:
		info.Argument = typ.In(1)
	}

	// Validate the results: one or two values, and if there are two the second
	// must be an error.
	numOut := typ.NumOut()
	switch {
	case numOut < 1 || numOut > 2:
		return nil, errors.New("wrong number of results")
	case numOut == 2 && typ.Out(1) != errType:
		return nil, errors.New("result is not of type error")
	}

	// The function reports an error if its final result has type error. A lone
	// result that is an error leaves Result unset.
	lastOut := typ.Out(numOut - 1)
	info.ReportsError = lastOut == errType
	if numOut == 2 || !info.ReportsError {
		info.Result = typ.Out(0)
	}
	return info, nil
}
// Args is a wrapper that decodes an array of positional parameters into
// concrete locations.
//
// Unmarshaling a JSON value into an Args value v succeeds if the JSON encodes
// an array with length len(v), and unmarshaling each subvalue i into the
// corresponding v[i] succeeds. As a special case, if v[i] == nil the
// corresponding value is discarded.
//
// Marshaling an Args value v into JSON succeeds if each element of the slice
// is JSON marshalable, and yields a JSON array of length len(v) containing the
// JSON values corresponding to the elements of v.
//
// Usage example:
//
//	func Handler(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
//	   var x, y int
//	   var s string
//
//	   if err := req.UnmarshalParams(&handler.Args{&x, &y, &s}); err != nil {
//	      return nil, err
//	   }
//	   // do useful work with x, y, and s
//	}
type Args []interface{}

// UnmarshalJSON supports JSON unmarshaling for a. It first splits data into
// raw array elements, requires that their count equal len(a), and then decodes
// each element into the target at the same position, skipping nil targets.
// Decoding errors are reported with the 1-based position of the argument.
func (a Args) UnmarshalJSON(data []byte) error {
	var raw []json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return filterJSONError("args", "array", err)
	}
	if got, want := len(raw), len(a); got != want {
		return fmt.Errorf("wrong number of args (got %d, want %d)", got, want)
	}
	for pos, target := range a {
		if target == nil {
			continue // the caller asked for this position to be discarded
		}
		if err := json.Unmarshal(raw[pos], target); err != nil {
			return fmt.Errorf("decoding argument %d: %w", pos+1, err)
		}
	}
	return nil
}

// MarshalJSON supports JSON marshaling for a. An empty or nil Args encodes as
// an empty array rather than null.
func (a Args) MarshalJSON() ([]byte, error) {
	if len(a) > 0 {
		// Convert to a plain slice so that encoding does not re-enter this method.
		return json.Marshal([]interface{}(a))
	}
	return []byte(`[]`), nil
}
// Obj is a wrapper that maps object fields into concrete locations.
//
// Unmarshaling a JSON text into an Obj value v succeeds if the JSON encodes an
// object, and unmarshaling the value for each key k of the object into v[k]
// succeeds. If k does not exist in v, it is ignored.
//
// Marshaling an Obj into JSON works as for an ordinary map.
type Obj map[string]interface{}

// UnmarshalJSON supports JSON unmarshaling into o. It splits data into raw
// per-key values and decodes each value whose key is present in o into the
// corresponding target. Keys of the input that are not in o are skipped.
//
// An error decoding an individual value is reported with the key quoted, and
// wraps the underlying error so that callers can inspect it with errors.Is and
// errors.As, matching the behavior of Args.
func (o Obj) UnmarshalJSON(data []byte) error {
	var base map[string]json.RawMessage
	if err := json.Unmarshal(data, &base); err != nil {
		return filterJSONError("decoding", "object", err)
	}
	for key, val := range base {
		arg, ok := o[key]
		if !ok {
			continue // the caller did not ask for this key
		}
		if err := json.Unmarshal(val, arg); err != nil {
			return fmt.Errorf("decoding %q: %w", key, err)
		}
	}
	return nil
}
func filterJSONError(tag, want string, err error) error {
if t, ok := err.(*json.UnmarshalTypeError); ok {
return fmt.Errorf("%s: cannot decode %s as %s", tag, t.Value, want)
}
return err
}
// strict wraps an arbitrary decoding target so that unmarshaling JSON into it
// goes through a decoder with unknown-field checking turned on.
type strict struct{ v interface{} }

// UnmarshalJSON decodes data into the wrapped value using a json.Decoder on
// which DisallowUnknownFields has been called.
func (s *strict) UnmarshalJSON(data []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(data))
	decoder.DisallowUnknownFields()
	return decoder.Decode(s.v)
}