365 lines
9.7 KiB
Go
Raw Normal View History

2021-11-22 16:05:02 +00:00
// Package rpc2 provides bi-directional RPC client and server similar to net/rpc.
package rpc2
import (
"context"
"errors"
"io"
"log"
"reflect"
"sync"
)
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
//
// A Client works in both directions: besides making calls with Call, Go and
// Notify, it dispatches requests arriving from the remote side to handlers
// registered with Handle. Incoming messages are read by Run.
type Client struct {
	mutex      sync.Mutex          // protects pending, seq, closing and shutdown
	sending    sync.Mutex          // serializes request writes through codec; guards request
	request    Request             // temp area used in send() and Notify(); guarded by sending
	seq        uint64              // next sequence number; starts at 1 because 0 marks a notification
	pending    map[uint64]*Call    // calls awaiting a response, keyed by sequence number
	closing    bool                // set by Close
	shutdown   bool                // set when the read loop exits
	server     bool                // NOTE(review): not set in this file; only suppresses readLoop's protocol-error debug log — presumably true for server-side clients, confirm
	codec      Codec               // encodes outgoing and decodes incoming messages
	handlers   map[string]*handler // method name -> handler, filled by Handle; read without locking
	disconnect chan struct{}       // closed when the read loop exits
	State      *State              // additional information to associate with client
	blocking   bool                // whether to block request handling (see SetBlocking); read without locking
}
// NewClient wraps conn in the codec produced by NewGobCodec and returns a
// Client that talks to the peer at the other end of conn.
// Per the original documentation, the gob codec buffers the write side so a
// message header and its payload are sent as a unit.
// Start Run (usually in its own goroutine) before making calls.
func NewClient(conn io.ReadWriteCloser) *Client {
	codec := NewGobCodec(conn)
	return NewClientWithCodec(codec)
}
// NewClientWithCodec is like NewClient but encodes requests and decodes
// responses with the supplied codec.
// The returned Client has no handlers and no pending calls. Its first request
// uses sequence number 1, because 0 is reserved for notifications.
func NewClientWithCodec(codec Codec) *Client {
	c := &Client{codec: codec}
	c.seq = 1 // 0 means notification.
	c.pending = make(map[uint64]*Call)
	c.handlers = make(map[string]*handler)
	c.disconnect = make(chan struct{})
	return c
}
// SetBlocking selects how incoming requests are handled.
//
// When blocking is true, each handler runs on the read loop itself, so
// requests are processed one at a time in arrival order. While a handler is
// running, no further messages are read, including responses to this
// client's own calls. A slow handler can therefore make later requests time
// out, and a handler that itself makes a Call on this client will wait
// forever. When blocking is false (the default), each handler runs in its own
// goroutine.
//
// The flag is read by the read loop without locking, so set it before Run.
func (c *Client) SetBlocking(blocking bool) {
	c.blocking = blocking
}
// Run executes the client's read loop on the calling goroutine.
// It returns once reading from the codec fails or the client is closed. By
// then every pending call has been failed and DisconnectNotify's channel has
// been closed. Responses and incoming requests are only processed while Run
// is executing, so start it (typically with `go client.Run()`) before calling
// methods on the remote side.
func (c *Client) Run() {
	c.readLoop()
}
// DisconnectNotify returns the channel that the read loop closes when it
// exits, i.e. when the connection has gone away or the client was closed.
// Callers can select on it to react to the disconnect.
func (c *Client) DisconnectNotify() chan struct{} {
	ch := c.disconnect
	return ch
}
// Handle registers handlerFunc to serve incoming requests for method.
// Registering a second handler for the same method panics.
// The handler table is read by the read loop without locking, so register
// all handlers before calling Run.
func (c *Client) Handle(method string, handlerFunc interface{}) {
	table := c.handlers
	addHandler(table, method, handlerFunc)
}
// readLoop reads messages from the codec until a read fails.
// Each message is either a request from the remote side, which is passed to
// readRequest and dispatched to the registered handler, or a response to one
// of our earlier requests, which readResponse delivers to the matching Call.
//
// When the loop ends it:
//   - marks the client as shut down;
//   - fails every pending call. A clean EOF becomes ErrShutdown if Close was
//     called, or io.ErrUnexpectedEOF otherwise; any other error is passed
//     through unchanged.
//   - closes the disconnect channel;
//   - closes the codec, unless Close already did so.
func (c *Client) readLoop() {
	var err error
	var req Request
	var resp Response
	for err == nil {
		req = Request{}
		resp = Response{}
		if err = c.codec.ReadHeader(&req, &resp); err != nil {
			break
		}

		if req.Method != "" {
			// request comes to server
			if err = c.readRequest(&req); err != nil {
				debugln("rpc2: error reading request:", err.Error())
			}
		} else {
			// response comes to client
			if err = c.readResponse(&resp); err != nil {
				debugln("rpc2: error reading response:", err.Error())
			}
		}
	}

	// Remember whether the peer hung up cleanly before err is rewritten below.
	// Comparing the rewritten err with io.EOF afterwards would always be true.
	eof := err == io.EOF

	// Terminate pending calls. Holding sending as well as mutex means no send
	// can be between registering a call and writing it while we fail the calls.
	c.sending.Lock()
	c.mutex.Lock()
	c.shutdown = true
	closing := c.closing
	if eof {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range c.pending {
		call.Error = err
		call.done()
	}
	c.mutex.Unlock()
	c.sending.Unlock()

	if !eof && !closing && !c.server {
		debugln("rpc2: client protocol error:", err)
	}
	close(c.disconnect)
	if !closing {
		// Best effort: the connection is already unusable, so a Close error
		// has no one to be reported to.
		c.codec.Close()
	}
}
// handleRequest calls the registered handler for one decoded request as
// fn(client, argv, replyPtr), where replyPtr is a freshly allocated reply value.
// Unless the request is a notification (Seq 0), it then writes the reply and
// the handler's error text back through the codec. The handler's first
// return value is treated as its error. A failure to write the response is
// only logged.
// NOTE(review): this may run on many goroutines at once and calls
// WriteResponse without holding c.sending — assumes the Codec serializes its
// own writes; confirm.
func (c *Client) handleRequest(req Request, method *handler, argv reflect.Value) {
	reply := reflect.New(method.replyType.Elem())
	in := []reflect.Value{reflect.ValueOf(c), argv, reply}
	out := method.fn.Call(in)

	// Notifications are fire-and-forget: never answer them.
	if req.Seq == 0 {
		return
	}

	var errmsg string
	if e := out[0].Interface(); e != nil {
		errmsg = e.(error).Error()
	}

	resp := &Response{Seq: req.Seq, Error: errmsg}
	if err := c.codec.WriteResponse(resp, reply.Interface()); err != nil {
		debugln("rpc2: error writing response:", err.Error())
	}
}
// readRequest handles an incoming request whose header has already been read
// into req.
//
// If a handler is registered for req.Method, the request body is decoded into
// a fresh value of the handler's argument type. The handler then runs either
// inline (blocking mode) or in a new goroutine.
//
// If no handler is registered, the body is read and discarded so the codec
// stays positioned at the next message. Unless the request is a notification
// (Seq 0), an error response is then sent back.
//
// A non-nil return is a codec error; readLoop treats it as fatal.
func (c *Client) readRequest(req *Request) error {
	method, ok := c.handlers[req.Method]
	if !ok {
		// Consume the unread body. Otherwise the next ReadHeader starts inside
		// this request's payload and the stream is out of sync.
		if err := c.codec.ReadRequestBody(nil); err != nil {
			return err
		}
		// Notifications never get a response, consistent with handleRequest.
		if req.Seq == 0 {
			debugln("rpc2: can't find method", req.Method)
			return nil
		}
		resp := &Response{
			Seq:   req.Seq,
			Error: "rpc2: can't find method " + req.Method,
		}
		return c.codec.WriteResponse(resp, resp)
	}

	// Decode the argument value. Handlers may take their argument by pointer
	// or by value; either way the codec decodes through a pointer.
	var argv reflect.Value
	argIsValue := false // if true, need to indirect before calling.
	if method.argType.Kind() == reflect.Ptr {
		argv = reflect.New(method.argType.Elem())
	} else {
		argv = reflect.New(method.argType)
		argIsValue = true
	}
	// argv guaranteed to be a pointer now.
	if err := c.codec.ReadRequestBody(argv.Interface()); err != nil {
		return err
	}
	if argIsValue {
		argv = argv.Elem()
	}

	if c.blocking {
		c.handleRequest(*req, method, argv)
	} else {
		go c.handleRequest(*req, method, argv)
	}
	return nil
}
// readResponse delivers the response whose header is in resp to the pending
// Call with the same sequence number, removing that Call from the pending
// table.
//
// On success the body is decoded into call.Reply. If the peer reported an
// error, or no call is waiting, the body is read and discarded. The returned
// error is a codec error from reading the body; readLoop treats it as fatal.
func (c *Client) readResponse(resp *Response) error {
	c.mutex.Lock()
	call := c.pending[resp.Seq]
	delete(c.pending, resp.Seq)
	c.mutex.Unlock()

	if call == nil {
		// Nobody is waiting for this response (for example, WriteRequest
		// failed part-way and send already removed the call). Still consume
		// the body so the stream stays in sync.
		if err := c.codec.ReadResponseBody(nil); err != nil {
			return errors.New("reading error body: " + err.Error())
		}
		return nil
	}

	if resp.Error != "" {
		// The remote side failed the call. The caller gets that error; a
		// failure to read the ignored body is reported to readLoop instead.
		call.Error = ServerError(resp.Error)
		err := c.codec.ReadResponseBody(nil)
		call.done()
		if err != nil {
			return errors.New("reading error body: " + err.Error())
		}
		return nil
	}

	err := c.codec.ReadResponseBody(call.Reply)
	if err != nil {
		call.Error = errors.New("reading body " + err.Error())
	}
	call.done()
	return err
}
// Close marks the client as closing and closes the codec.
//
// Close does NOT wait for outstanding calls. Closing the codec is expected to
// make the read loop's next read fail; the read loop then fails every pending
// call (with ErrShutdown if the failure is a clean EOF). Closing the codec is
// also presumably what unblocks Run.
//
// Close returns ErrShutdown if the client is already closing or its read loop
// has already exited. Otherwise it returns the codec's Close error.
func (c *Client) Close() error {
	c.mutex.Lock()
	alreadyDown := c.shutdown || c.closing
	if !alreadyDown {
		c.closing = true
	}
	c.mutex.Unlock()

	if alreadyDown {
		return ErrShutdown
	}
	// Close outside the lock so the read loop, which needs mutex to fail
	// pending calls, is never blocked behind a slow Close.
	return c.codec.Close()
}
// Go starts an asynchronous call of method with args and returns the Call
// that tracks it; the result is decoded into reply. When the call completes,
// the Call itself is sent on done.
//
// If done is nil, a new channel with a buffer of 10 is allocated. A non-nil
// done must be buffered, otherwise Go panics. The caller must size the buffer
// for all simultaneous calls sharing the channel, because completions that do
// not fit are dropped.
func (c *Client) Go(method string, args interface{}, reply interface{}, done chan *Call) *Call {
	switch {
	case done == nil:
		done = make(chan *Call, 10) // buffered.
	case cap(done) == 0:
		// call.done never blocks, so with an unbuffered channel the result
		// would almost certainly be discarded. Refuse to run at all.
		log.Panic("rpc2: done channel is unbuffered")
	}

	call := &Call{
		Method: method,
		Args:   args,
		Reply:  reply,
		Done:   done,
	}
	c.send(call)
	return call
}
// CallWithContext invokes the named method on the remote side and waits for
// it to complete. It returns the call's error status, or ctx.Err() if ctx is
// done first.
//
// On cancellation the call is removed from the pending table. A response that
// arrives later is then read and discarded rather than decoded into reply.
// If the response was already being decoded when ctx fired, reply may still
// be written after CallWithContext returns.
func (c *Client) CallWithContext(ctx context.Context, method string, args interface{}, reply interface{}) error {
	call := c.Go(method, args, reply, make(chan *Call, 1))
	select {
	case <-call.Done:
		return call.Error
	case <-ctx.Done():
		// Without this, an unanswered call would stay in c.pending for the
		// lifetime of the connection. Call does not record its sequence
		// number, so look it up by identity.
		c.mutex.Lock()
		for seq, pending := range c.pending {
			if pending == call {
				delete(c.pending, seq)
				break
			}
		}
		c.mutex.Unlock()
		return ctx.Err()
	}
}
// Call invokes the named method on the remote side, waits for it to complete
// and returns its error status. It is CallWithContext with a background
// context, so it has no timeout of its own.
func (c *Client) Call(method string, args interface{}, reply interface{}) error {
	ctx := context.Background()
	return c.CallWithContext(ctx, method, args, reply)
}
// done publishes call on its Done channel without ever blocking.
// If the channel has no free buffer space and no ready receiver, the
// completion is dropped and only a debug message is emitted. Sizing the
// channel is the caller's job (see Go).
func (call *Call) done() {
	select {
	case call.Done <- call:
		return
	default:
	}
	debugln("rpc2: discarding Call reply due to insufficient Done chan capacity")
}
// ServerError is an error reported by the remote side of the RPC connection.
// Its text is the Error string carried in the remote side's response.
type ServerError string

// Error returns the remote error text unchanged.
func (e ServerError) Error() string {
	msg := string(e)
	return msg
}
var (
	// ErrShutdown is the error returned for calls and notifications attempted
	// after the client has started closing or its connection has shut down.
	ErrShutdown = errors.New("connection is shut down")
)
// Call represents an active RPC.
type Call struct {
	Method string      // Name of the method to call on the remote side.
	Args   interface{} // The argument to the remote function.
	Reply  interface{} // Passed to the codec's ReadResponseBody to receive the result (presumably a pointer).
	Error  error       // After completion, the error status.
	Done   chan *Call  // Receives this Call on completion; the send is dropped if the buffer is full.
}
// send registers call under a fresh sequence number and writes the request
// through the codec.
//
// If the client is closing or shut down, the call is failed immediately with
// ErrShutdown. If the write fails and the call is still pending, it is
// removed and completed with the write error.
//
// sending is held throughout, so requests are written one at a time and the
// shared request buffer is not clobbered.
func (c *Client) send(call *Call) {
	c.sending.Lock()
	defer c.sending.Unlock()

	// Register this call.
	c.mutex.Lock()
	if c.shutdown || c.closing {
		call.Error = ErrShutdown
		c.mutex.Unlock()
		call.done()
		return
	}
	seq := c.seq
	c.seq++
	c.pending[seq] = call
	c.mutex.Unlock()

	// Encode and send the request.
	c.request.Seq, c.request.Method = seq, call.Method
	err := c.codec.WriteRequest(&c.request, call.Args)
	if err == nil {
		return
	}

	// The write failed. Only complete the call if it is still ours: someone
	// else (e.g. readResponse) may already have taken it out of pending.
	c.mutex.Lock()
	owned := c.pending[seq]
	delete(c.pending, seq)
	c.mutex.Unlock()
	if owned != nil {
		owned.Error = err
		owned.done()
	}
}
// Notify sends a request to the remote side without waiting for, or
// expecting, a response. Notifications use sequence number 0, which the
// receiving side's handleRequest treats as "do not reply".
//
// It returns ErrShutdown if the client is closing or shut down; otherwise it
// returns the error from writing the request.
func (c *Client) Notify(method string, args interface{}) error {
	c.sending.Lock()
	defer c.sending.Unlock()

	// closing is written by Close while holding only mutex, so the flags must
	// be read under mutex too (same lock order as send: sending, then mutex).
	c.mutex.Lock()
	stopped := c.shutdown || c.closing
	c.mutex.Unlock()
	if stopped {
		return ErrShutdown
	}

	c.request.Seq = 0
	c.request.Method = method
	return c.codec.WriteRequest(&c.request, args)
}