329 lines
10 KiB
Go
329 lines
10 KiB
Go
|
// Package handler provides implementations of the jrpc2.Assigner interface,
|
||
|
// and support for adapting functions to the jrpc2.Handler interface.
|
||
|
package handler
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
|
||
|
"bitbucket.org/creachadair/stringset"
|
||
|
"github.com/creachadair/jrpc2"
|
||
|
"github.com/creachadair/jrpc2/code"
|
||
|
)
|
||
|
|
||
|
// A Func adapts a function having the correct signature to a jrpc2.Handler.
|
||
|
type Func func(context.Context, *jrpc2.Request) (interface{}, error)
|
||
|
|
||
|
// Handle implements the jrpc2.Handler interface by calling m.
|
||
|
func (m Func) Handle(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
|
||
|
return m(ctx, req)
|
||
|
}
|
||
|
|
||
|
// A Map is a trivial implementation of the jrpc2.Assigner interface that looks
|
||
|
// up method names in a map of static jrpc2.Handler values.
|
||
|
type Map map[string]jrpc2.Handler
|
||
|
|
||
|
// Assign implements part of the jrpc2.Assigner interface.
|
||
|
func (m Map) Assign(_ context.Context, method string) jrpc2.Handler { return m[method] }
|
||
|
|
||
|
// Names implements part of the jrpc2.Assigner interface.
|
||
|
func (m Map) Names() []string { return stringset.FromKeys(m).Elements() }
|
||
|
|
||
|
// A ServiceMap combines multiple assigners into one, permitting a server to
|
||
|
// export multiple services under different names.
|
||
|
//
|
||
|
// Example:
|
||
|
// m := handler.ServiceMap{
|
||
|
// "Foo": handler.NewService(fooService), // methods Foo.A, Foo.B, etc.
|
||
|
// "Bar": handler.NewService(barService), // methods Bar.A, Bar.B, etc.
|
||
|
// }
|
||
|
//
|
||
|
type ServiceMap map[string]jrpc2.Assigner
|
||
|
|
||
|
// Assign splits the inbound method name as Service.Method, and passes the
|
||
|
// Method portion to the corresponding Service assigner. If method does not
|
||
|
// have the form Service.Method, or if Service is not set in m, the lookup
|
||
|
// fails and returns nil.
|
||
|
func (m ServiceMap) Assign(ctx context.Context, method string) jrpc2.Handler {
|
||
|
parts := strings.SplitN(method, ".", 2)
|
||
|
if len(parts) == 1 {
|
||
|
return nil
|
||
|
} else if ass, ok := m[parts[0]]; ok {
|
||
|
return ass.Assign(ctx, parts[1])
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Names reports the composed names of all the methods in the service, each
|
||
|
// having the form Service.Method.
|
||
|
func (m ServiceMap) Names() []string {
|
||
|
var all stringset.Set
|
||
|
for svc, assigner := range m {
|
||
|
for _, name := range assigner.Names() {
|
||
|
all.Add(svc + "." + name)
|
||
|
}
|
||
|
}
|
||
|
return all.Elements()
|
||
|
}
|
||
|
|
||
|
// New adapts a function to a jrpc2.Handler. The concrete value of fn must be a
|
||
|
// function with one of the following type signatures:
|
||
|
//
|
||
|
// func(context.Context) error
|
||
|
// func(context.Context) Y
|
||
|
// func(context.Context) (Y, error)
|
||
|
// func(context.Context, X) error
|
||
|
// func(context.Context, X) Y
|
||
|
// func(context.Context, X) (Y, error)
|
||
|
// func(context.Context, ...X) (Y, error)
|
||
|
// func(context.Context, *jrpc2.Request) (Y, error)
|
||
|
// func(context.Context, *jrpc2.Request) (interface{}, error)
|
||
|
//
|
||
|
// for JSON-marshalable types X and Y. New will panic if the type of fn does
|
||
|
// not have one of these forms. The resulting method will handle encoding and
|
||
|
// decoding of JSON and report appropriate errors.
|
||
|
//
|
||
|
// Functions adapted by in this way can obtain the *jrpc2.Request value using
|
||
|
// the jrpc2.InboundRequest helper on the context value supplied by the server.
|
||
|
func New(fn interface{}) Func {
|
||
|
m, err := newHandler(fn)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
// NewService adapts the methods of a value to a map from method names to
|
||
|
// Handler implementations as constructed by New. It will panic if obj has no
|
||
|
// exported methods with a suitable signature.
|
||
|
func NewService(obj interface{}) Map {
|
||
|
out := make(Map)
|
||
|
val := reflect.ValueOf(obj)
|
||
|
typ := val.Type()
|
||
|
|
||
|
// This considers only exported methods, as desired.
|
||
|
for i, n := 0, val.NumMethod(); i < n; i++ {
|
||
|
mi := val.Method(i)
|
||
|
if v, err := newHandler(mi.Interface()); err == nil {
|
||
|
out[typ.Method(i).Name] = v
|
||
|
}
|
||
|
}
|
||
|
if len(out) == 0 {
|
||
|
panic("no matching exported methods")
|
||
|
}
|
||
|
return out
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
ctxType = reflect.TypeOf((*context.Context)(nil)).Elem() // type context.Context
|
||
|
errType = reflect.TypeOf((*error)(nil)).Elem() // type error
|
||
|
reqType = reflect.TypeOf((*jrpc2.Request)(nil)) // type *jrpc2.Request
|
||
|
)
|
||
|
|
||
|
func newHandler(fn interface{}) (Func, error) {
|
||
|
if fn == nil {
|
||
|
return nil, errors.New("nil method")
|
||
|
}
|
||
|
|
||
|
// Special case: If fn has the exact signature of the Handle method, don't do
|
||
|
// any (additional) reflection at all.
|
||
|
if f, ok := fn.(func(context.Context, *jrpc2.Request) (interface{}, error)); ok {
|
||
|
return Func(f), nil
|
||
|
}
|
||
|
|
||
|
// Check that fn is a function of one of the correct forms.
|
||
|
typ, err := checkFunctionType(fn)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Construct a function to unpack the parameters from the request message,
|
||
|
// based on the signature of the user's callback.
|
||
|
var newinput func(req *jrpc2.Request) ([]reflect.Value, error)
|
||
|
|
||
|
if typ.NumIn() == 1 {
|
||
|
// Case 1: The function does not want any request parameters.
|
||
|
// Nothing needs to be decoded, but verify no parameters were passed.
|
||
|
newinput = func(req *jrpc2.Request) ([]reflect.Value, error) {
|
||
|
if req.HasParams() {
|
||
|
return nil, jrpc2.Errorf(code.InvalidParams, "no parameters accepted")
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
} else if a := typ.In(1); a == reqType {
|
||
|
// Case 2: The function wants the underlying *jrpc2.Request value.
|
||
|
newinput = func(req *jrpc2.Request) ([]reflect.Value, error) {
|
||
|
return []reflect.Value{reflect.ValueOf(req)}, nil
|
||
|
}
|
||
|
|
||
|
} else {
|
||
|
// Check whether the function wants a pointer to its argument. We need
|
||
|
// to create one either way to support unmarshaling, but we need to
|
||
|
// indirect it back off if the callee didn't want it.
|
||
|
|
||
|
// Case 3a: The function wants a bare value, not a pointer.
|
||
|
argType := typ.In(1)
|
||
|
undo := reflect.Value.Elem
|
||
|
|
||
|
if argType.Kind() == reflect.Ptr {
|
||
|
// Case 3b: The function wants a pointer.
|
||
|
undo = func(v reflect.Value) reflect.Value { return v }
|
||
|
argType = argType.Elem()
|
||
|
}
|
||
|
|
||
|
newinput = func(req *jrpc2.Request) ([]reflect.Value, error) {
|
||
|
in := reflect.New(argType).Interface()
|
||
|
if err := req.UnmarshalParams(in); err != nil {
|
||
|
return nil, jrpc2.Errorf(code.InvalidParams, "invalid parameters: %v", err)
|
||
|
}
|
||
|
arg := reflect.ValueOf(in)
|
||
|
return []reflect.Value{undo(arg)}, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Construct a function to decode the result values.
|
||
|
var decodeOut func([]reflect.Value) (interface{}, error)
|
||
|
|
||
|
switch typ.NumOut() {
|
||
|
case 1:
|
||
|
if typ.Out(0) == errType {
|
||
|
// A function that returns only error: Result is always nil.
|
||
|
decodeOut = func(vals []reflect.Value) (interface{}, error) {
|
||
|
oerr := vals[0].Interface()
|
||
|
if oerr != nil {
|
||
|
return nil, oerr.(error)
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
} else {
|
||
|
// A function that returns a single non-error: err is always nil.
|
||
|
decodeOut = func(vals []reflect.Value) (interface{}, error) {
|
||
|
return vals[0].Interface(), nil
|
||
|
}
|
||
|
}
|
||
|
default:
|
||
|
// A function that returns a value and an error.
|
||
|
decodeOut = func(vals []reflect.Value) (interface{}, error) {
|
||
|
out, oerr := vals[0].Interface(), vals[1].Interface()
|
||
|
if oerr != nil {
|
||
|
return nil, oerr.(error)
|
||
|
}
|
||
|
return out, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
f := reflect.ValueOf(fn)
|
||
|
call := f.Call
|
||
|
if typ.IsVariadic() {
|
||
|
call = f.CallSlice
|
||
|
}
|
||
|
|
||
|
return Func(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
|
||
|
rest, ierr := newinput(req)
|
||
|
if ierr != nil {
|
||
|
return nil, ierr
|
||
|
}
|
||
|
args := append([]reflect.Value{reflect.ValueOf(ctx)}, rest...)
|
||
|
return decodeOut(call(args))
|
||
|
}), nil
|
||
|
}
|
||
|
|
||
|
func checkFunctionType(fn interface{}) (reflect.Type, error) {
|
||
|
typ := reflect.TypeOf(fn)
|
||
|
if typ.Kind() != reflect.Func {
|
||
|
return nil, errors.New("not a function")
|
||
|
} else if np := typ.NumIn(); np == 0 || np > 2 {
|
||
|
return nil, errors.New("wrong number of parameters")
|
||
|
} else if no := typ.NumOut(); no < 1 || no > 2 {
|
||
|
return nil, errors.New("wrong number of results")
|
||
|
} else if typ.In(0) != ctxType {
|
||
|
return nil, errors.New("first parameter is not context.Context")
|
||
|
} else if no == 2 && typ.Out(1) != errType {
|
||
|
return nil, errors.New("result is not of type error")
|
||
|
}
|
||
|
return typ, nil
|
||
|
}
|
||
|
|
||
|
// Args is a wrapper for decoding a JSON array of positional parameters into a
// matching list of concrete locations.
//
// Unmarshaling a JSON value into an Args value v succeeds if the JSON encodes
// an array of exactly len(v) elements, and each element i unmarshals
// successfully into the corresponding v[i]. As a special case, an element
// whose v[i] is nil is discarded without being decoded.
//
// Marshaling an Args value v into JSON succeeds if every element of the slice
// is JSON marshalable, and yields a JSON array of length len(v) holding the
// JSON values of the elements of v.
//
// Usage example:
//
//	func Handler(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
//		var x, y int
//		var s string
//
//		if err := req.UnmarshalParams(&handler.Args{&x, &y, &s}); err != nil {
//			return nil, err
//		}
//		// do useful work with x, y, and s
//	}
type Args []interface{}

// UnmarshalJSON supports JSON unmarshaling for a. It reports an error if data
// is not a JSON array, if the array length differs from len(a), or if any
// element fails to decode into its non-nil destination.
func (a Args) UnmarshalJSON(data []byte) error {
	var raw []json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("decoding args: %w", err)
	}
	if got, want := len(raw), len(a); got != want {
		return fmt.Errorf("wrong number of args (got %d, want %d)", got, want)
	}
	for i, dst := range a {
		if dst == nil {
			continue // the caller does not want this position
		}
		if err := json.Unmarshal(raw[i], dst); err != nil {
			return fmt.Errorf("decoding argument %d: %w", i+1, err)
		}
	}
	return nil
}

// MarshalJSON supports JSON marshaling for a. An empty or nil a is encoded as
// the empty array "[]".
func (a Args) MarshalJSON() ([]byte, error) {
	if len(a) > 0 {
		return json.Marshal([]interface{}(a))
	}
	return []byte("[]"), nil
}
|
||
|
|
||
|
// Obj is a wrapper that maps object fields into concrete locations.
//
// Unmarshaling a JSON text into an Obj value v succeeds if the JSON encodes an
// object, and unmarshaling the value for each key k of the object into v[k]
// succeeds. If k does not exist in v, it is ignored.
//
// Marshaling an Obj into JSON works as for an ordinary map.
type Obj map[string]interface{}

// UnmarshalJSON supports JSON unmarshaling into o. It reports an error if data
// does not encode a JSON object, or if the value of any key present in o
// fails to decode into o[key].
//
// Errors wrap the underlying encoding/json error, as Args.UnmarshalJSON does,
// so callers can inspect the cause with errors.Is and errors.As.
func (o Obj) UnmarshalJSON(data []byte) error {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(data, &fields); err != nil {
		return fmt.Errorf("decoding object: %w", err)
	}
	for key, raw := range fields {
		dst, ok := o[key]
		if !ok {
			continue // no destination registered for this key
		}
		if err := json.Unmarshal(raw, dst); err != nil {
			return fmt.Errorf("decoding %q: %w", key, err)
		}
	}
	return nil
}
|