mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
rpc: make more tolerant of errors.
Add Error type to enable clients to distinguish between local and remote errors. Also return "connection shut down error" after the first error return rather than returning the same error each time. R=r CC=golang-dev https://golang.org/cl/4080058
This commit is contained in:
parent
d916cca327
commit
d40ae94993
3 changed files with 131 additions and 120 deletions
|
|
@ -15,6 +15,16 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ServerError represents an error that has been returned from
|
||||||
|
// the remote side of the RPC connection.
|
||||||
|
type ServerError string
|
||||||
|
|
||||||
|
func (e ServerError) String() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
const ErrShutdown = os.ErrorString("connection is shut down")
|
||||||
|
|
||||||
// Call represents an active RPC.
|
// Call represents an active RPC.
|
||||||
type Call struct {
|
type Call struct {
|
||||||
ServiceMethod string // The name of the service and method to call.
|
ServiceMethod string // The name of the service and method to call.
|
||||||
|
|
@ -30,12 +40,12 @@ type Call struct {
|
||||||
// with a single Client.
|
// with a single Client.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
mutex sync.Mutex // protects pending, seq
|
mutex sync.Mutex // protects pending, seq
|
||||||
shutdown os.Error // non-nil if the client is shut down
|
|
||||||
sending sync.Mutex
|
sending sync.Mutex
|
||||||
seq uint64
|
seq uint64
|
||||||
codec ClientCodec
|
codec ClientCodec
|
||||||
pending map[uint64]*Call
|
pending map[uint64]*Call
|
||||||
closing bool
|
closing bool
|
||||||
|
shutdown bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// A ClientCodec implements writing of RPC requests and
|
// A ClientCodec implements writing of RPC requests and
|
||||||
|
|
@ -55,8 +65,8 @@ type ClientCodec interface {
|
||||||
func (client *Client) send(c *Call) {
|
func (client *Client) send(c *Call) {
|
||||||
// Register this call.
|
// Register this call.
|
||||||
client.mutex.Lock()
|
client.mutex.Lock()
|
||||||
if client.shutdown != nil {
|
if client.shutdown {
|
||||||
c.Error = client.shutdown
|
c.Error = ErrShutdown
|
||||||
client.mutex.Unlock()
|
client.mutex.Unlock()
|
||||||
c.done()
|
c.done()
|
||||||
return
|
return
|
||||||
|
|
@ -79,6 +89,7 @@ func (client *Client) send(c *Call) {
|
||||||
|
|
||||||
func (client *Client) input() {
|
func (client *Client) input() {
|
||||||
var err os.Error
|
var err os.Error
|
||||||
|
var marker struct{}
|
||||||
for err == nil {
|
for err == nil {
|
||||||
response := new(Response)
|
response := new(Response)
|
||||||
err = client.codec.ReadResponseHeader(response)
|
err = client.codec.ReadResponseHeader(response)
|
||||||
|
|
@ -93,20 +104,27 @@ func (client *Client) input() {
|
||||||
c := client.pending[seq]
|
c := client.pending[seq]
|
||||||
client.pending[seq] = c, false
|
client.pending[seq] = c, false
|
||||||
client.mutex.Unlock()
|
client.mutex.Unlock()
|
||||||
|
|
||||||
|
if response.Error == "" {
|
||||||
err = client.codec.ReadResponseBody(c.Reply)
|
err = client.codec.ReadResponseBody(c.Reply)
|
||||||
if response.Error != "" {
|
if err != nil {
|
||||||
c.Error = os.ErrorString(response.Error)
|
c.Error = os.ErrorString("reading body " + err.String())
|
||||||
} else if err != nil {
|
}
|
||||||
c.Error = err
|
|
||||||
} else {
|
} else {
|
||||||
// Empty strings should turn into nil os.Errors
|
// We've got an error response. Give this to the request;
|
||||||
c.Error = nil
|
// any subsequent requests will get the ReadResponseBody
|
||||||
|
// error if there is one.
|
||||||
|
c.Error = ServerError(response.Error)
|
||||||
|
err = client.codec.ReadResponseBody(&marker)
|
||||||
|
if err != nil {
|
||||||
|
err = os.ErrorString("reading error body: " + err.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.done()
|
c.done()
|
||||||
}
|
}
|
||||||
// Terminate pending calls.
|
// Terminate pending calls.
|
||||||
client.mutex.Lock()
|
client.mutex.Lock()
|
||||||
client.shutdown = err
|
client.shutdown = true
|
||||||
for _, call := range client.pending {
|
for _, call := range client.pending {
|
||||||
call.Error = err
|
call.Error = err
|
||||||
call.done()
|
call.done()
|
||||||
|
|
@ -209,10 +227,11 @@ func Dial(network, address string) (*Client, os.Error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) Close() os.Error {
|
func (client *Client) Close() os.Error {
|
||||||
if client.shutdown != nil || client.closing {
|
|
||||||
return os.ErrorString("rpc: already closed")
|
|
||||||
}
|
|
||||||
client.mutex.Lock()
|
client.mutex.Lock()
|
||||||
|
if client.shutdown || client.closing {
|
||||||
|
client.mutex.Unlock()
|
||||||
|
return ErrShutdown
|
||||||
|
}
|
||||||
client.closing = true
|
client.closing = true
|
||||||
client.mutex.Unlock()
|
client.mutex.Unlock()
|
||||||
return client.codec.Close()
|
return client.codec.Close()
|
||||||
|
|
@ -239,8 +258,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Done = done
|
c.Done = done
|
||||||
if client.shutdown != nil {
|
if client.shutdown {
|
||||||
c.Error = client.shutdown
|
c.Error = ErrShutdown
|
||||||
c.done()
|
c.done()
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
@ -250,8 +269,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
|
||||||
|
|
||||||
// Call invokes the named function, waits for it to complete, and returns its error status.
|
// Call invokes the named function, waits for it to complete, and returns its error status.
|
||||||
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
|
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
|
||||||
if client.shutdown != nil {
|
if client.shutdown {
|
||||||
return client.shutdown
|
return ErrShutdown
|
||||||
}
|
}
|
||||||
call := <-client.Go(serviceMethod, args, reply, nil).Done
|
call := <-client.Go(serviceMethod, args, reply, nil).Done
|
||||||
return call.Error
|
return call.Error
|
||||||
|
|
|
||||||
|
|
@ -299,7 +299,7 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E
|
||||||
|
|
||||||
// A value sent as a placeholder for the response when the server receives an invalid request.
|
// A value sent as a placeholder for the response when the server receives an invalid request.
|
||||||
type InvalidRequest struct {
|
type InvalidRequest struct {
|
||||||
marker int
|
Marker int
|
||||||
}
|
}
|
||||||
|
|
||||||
var invalidRequest = InvalidRequest{1}
|
var invalidRequest = InvalidRequest{1}
|
||||||
|
|
@ -316,6 +316,7 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se
|
||||||
resp.ServiceMethod = req.ServiceMethod
|
resp.ServiceMethod = req.ServiceMethod
|
||||||
if errmsg != "" {
|
if errmsg != "" {
|
||||||
resp.Error = errmsg
|
resp.Error = errmsg
|
||||||
|
reply = invalidRequest
|
||||||
}
|
}
|
||||||
resp.Seq = req.Seq
|
resp.Seq = req.Seq
|
||||||
sending.Lock()
|
sending.Lock()
|
||||||
|
|
@ -389,9 +390,28 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
|
||||||
func (server *Server) ServeCodec(codec ServerCodec) {
|
func (server *Server) ServeCodec(codec ServerCodec) {
|
||||||
sending := new(sync.Mutex)
|
sending := new(sync.Mutex)
|
||||||
for {
|
for {
|
||||||
// Grab the request header.
|
req, service, mtype, err := server.readRequest(codec)
|
||||||
req := new(Request)
|
if err != nil {
|
||||||
err := codec.ReadRequestHeader(req)
|
if err != os.EOF {
|
||||||
|
log.Println("rpc:", err)
|
||||||
|
}
|
||||||
|
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// discard body
|
||||||
|
codec.ReadRequestBody(new(interface{}))
|
||||||
|
|
||||||
|
// send a response if we actually managed to read a header.
|
||||||
|
if req != nil {
|
||||||
|
sendResponse(sending, req, invalidRequest, codec, err.String())
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the argument value.
|
||||||
|
argv := _new(mtype.ArgType)
|
||||||
|
replyv := _new(mtype.ReplyType)
|
||||||
|
err = codec.ReadRequestBody(argv.Interface())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
||||||
if err == io.ErrUnexpectedEOF {
|
if err == io.ErrUnexpectedEOF {
|
||||||
|
|
@ -399,44 +419,45 @@ func (server *Server) ServeCodec(codec ServerCodec) {
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
s := "rpc: server cannot decode request: " + err.String()
|
|
||||||
sendResponse(sending, req, invalidRequest, codec, s)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
|
|
||||||
if len(serviceMethod) != 2 {
|
|
||||||
s := "rpc: service/method request ill-formed: " + req.ServiceMethod
|
|
||||||
sendResponse(sending, req, invalidRequest, codec, s)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Look up the request.
|
|
||||||
server.Lock()
|
|
||||||
service, ok := server.serviceMap[serviceMethod[0]]
|
|
||||||
server.Unlock()
|
|
||||||
if !ok {
|
|
||||||
s := "rpc: can't find service " + req.ServiceMethod
|
|
||||||
sendResponse(sending, req, invalidRequest, codec, s)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mtype, ok := service.method[serviceMethod[1]]
|
|
||||||
if !ok {
|
|
||||||
s := "rpc: can't find method " + req.ServiceMethod
|
|
||||||
sendResponse(sending, req, invalidRequest, codec, s)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Decode the argument value.
|
|
||||||
argv := _new(mtype.ArgType)
|
|
||||||
replyv := _new(mtype.ReplyType)
|
|
||||||
err = codec.ReadRequestBody(argv.Interface())
|
|
||||||
if err != nil {
|
|
||||||
log.Println("rpc: tearing down", serviceMethod[0], "connection:", err)
|
|
||||||
sendResponse(sending, req, replyv.Interface(), codec, err.String())
|
sendResponse(sending, req, replyv.Interface(), codec, err.String())
|
||||||
break
|
continue
|
||||||
}
|
}
|
||||||
go service.call(sending, mtype, req, argv, replyv, codec)
|
go service.call(sending, mtype, req, argv, replyv, codec)
|
||||||
}
|
}
|
||||||
codec.Close()
|
codec.Close()
|
||||||
}
|
}
|
||||||
|
func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) {
|
||||||
|
// Grab the request header.
|
||||||
|
req = new(Request)
|
||||||
|
err = codec.ReadRequestHeader(req)
|
||||||
|
if err != nil {
|
||||||
|
req = nil
|
||||||
|
if err == os.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = os.ErrorString("rpc: server cannot decode request: " + err.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
|
||||||
|
if len(serviceMethod) != 2 {
|
||||||
|
err = os.ErrorString("rpc: service/method request ill-formed: " + req.ServiceMethod)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Look up the request.
|
||||||
|
server.Lock()
|
||||||
|
service = server.serviceMap[serviceMethod[0]]
|
||||||
|
server.Unlock()
|
||||||
|
if service == nil {
|
||||||
|
err = os.ErrorString("rpc: can't find service " + req.ServiceMethod)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mtype = service.method[serviceMethod[1]]
|
||||||
|
if mtype == nil {
|
||||||
|
err = os.ErrorString("rpc: can't find method " + req.ServiceMethod)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Accept accepts connections on the listener and serves requests
|
// Accept accepts connections on the listener and serves requests
|
||||||
// for each incoming connection. Accept blocks; the caller typically
|
// for each incoming connection. Accept blocks; the caller typically
|
||||||
|
|
|
||||||
|
|
@ -134,14 +134,25 @@ func testRPC(t *testing.T, addr string) {
|
||||||
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Nonexistent method
|
||||||
|
args = &Args{7, 0}
|
||||||
|
reply = new(Reply)
|
||||||
|
err = client.Call("Arith.BadOperation", args, reply)
|
||||||
|
// expect an error
|
||||||
|
if err == nil {
|
||||||
|
t.Error("BadOperation: expected error")
|
||||||
|
} else if !strings.HasPrefix(err.String(), "rpc: can't find method ") {
|
||||||
|
t.Errorf("BadOperation: expected can't find method error; got %q", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown service
|
||||||
args = &Args{7, 8}
|
args = &Args{7, 8}
|
||||||
reply = new(Reply)
|
reply = new(Reply)
|
||||||
err = client.Call("Arith.Mul", args, reply)
|
err = client.Call("Arith.Unknown", args, reply)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
t.Errorf("Mul: expected no error but got string %q", err.String())
|
t.Error("expected error calling unknown service")
|
||||||
}
|
} else if strings.Index(err.String(), "method") < 0 {
|
||||||
if reply.C != args.A*args.B {
|
t.Error("expected error about method; got", err)
|
||||||
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Out of order.
|
// Out of order.
|
||||||
|
|
@ -178,6 +189,15 @@ func testRPC(t *testing.T, addr string) {
|
||||||
t.Error("Div: expected divide by zero error; got", err)
|
t.Error("Div: expected divide by zero error; got", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bad type.
|
||||||
|
reply = new(Reply)
|
||||||
|
err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error calling Arith.Add with wrong arg type")
|
||||||
|
} else if strings.Index(err.String(), "type") < 0 {
|
||||||
|
t.Error("expected error about type; got", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Non-struct argument
|
// Non-struct argument
|
||||||
const Val = 12345
|
const Val = 12345
|
||||||
str := fmt.Sprint(Val)
|
str := fmt.Sprint(Val)
|
||||||
|
|
@ -200,9 +220,19 @@ func testRPC(t *testing.T, addr string) {
|
||||||
if str != expect {
|
if str != expect {
|
||||||
t.Errorf("String: expected %s got %s", expect, str)
|
t.Errorf("String: expected %s got %s", expect, str)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
args = &Args{7, 8}
|
||||||
|
reply = new(Reply)
|
||||||
|
err = client.Call("Arith.Mul", args, reply)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Mul: expected no error but got string %q", err.String())
|
||||||
|
}
|
||||||
|
if reply.C != args.A*args.B {
|
||||||
|
t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPRPC(t *testing.T) {
|
func TestHTTP(t *testing.T) {
|
||||||
once.Do(startServer)
|
once.Do(startServer)
|
||||||
testHTTPRPC(t, "")
|
testHTTPRPC(t, "")
|
||||||
newOnce.Do(startNewServer)
|
newOnce.Do(startNewServer)
|
||||||
|
|
@ -233,65 +263,6 @@ func testHTTPRPC(t *testing.T, path string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckUnknownService(t *testing.T) {
|
|
||||||
once.Do(startServer)
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", "", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("dialing:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClient(conn)
|
|
||||||
|
|
||||||
args := &Args{7, 8}
|
|
||||||
reply := new(Reply)
|
|
||||||
err = client.Call("Unknown.Add", args, reply)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error calling unknown service")
|
|
||||||
} else if strings.Index(err.String(), "service") < 0 {
|
|
||||||
t.Error("expected error about service; got", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckUnknownMethod(t *testing.T) {
|
|
||||||
once.Do(startServer)
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", "", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("dialing:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClient(conn)
|
|
||||||
|
|
||||||
args := &Args{7, 8}
|
|
||||||
reply := new(Reply)
|
|
||||||
err = client.Call("Arith.Unknown", args, reply)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error calling unknown service")
|
|
||||||
} else if strings.Index(err.String(), "method") < 0 {
|
|
||||||
t.Error("expected error about method; got", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckBadType(t *testing.T) {
|
|
||||||
once.Do(startServer)
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", "", serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("dialing:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client := NewClient(conn)
|
|
||||||
|
|
||||||
reply := new(Reply)
|
|
||||||
err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error calling Arith.Add with wrong arg type")
|
|
||||||
} else if strings.Index(err.String(), "type") < 0 {
|
|
||||||
t.Error("expected error about type; got", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ArgNotPointer int
|
type ArgNotPointer int
|
||||||
type ReplyNotPointer int
|
type ReplyNotPointer int
|
||||||
type ArgNotPublic int
|
type ArgNotPublic int
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue