rpc: abstract client and server encodings

R=r
CC=golang-dev, rog
https://golang.org/cl/811046
This commit is contained in:
Russ Cox 2010-04-27 13:51:25 -07:00
parent 72f9b2ebee
commit dcff89057b
2 changed files with 131 additions and 41 deletions

View file

@ -272,7 +272,7 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
return v
}
func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
resp := new(Response)
// Encode the response header
resp.ServiceMethod = req.ServiceMethod
@ -281,13 +281,14 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob
}
resp.Seq = req.Seq
sending.Lock()
enc.Encode(resp)
// Encode the reply value.
enc.Encode(reply)
err := codec.WriteResponse(resp, reply)
if err != nil {
log.Stderr("rpc: writing response: ", err)
}
sending.Unlock()
}
func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) {
func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
mtype.Lock()
mtype.numCalls++
mtype.Unlock()
@ -300,17 +301,40 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg
if errInter != nil {
errmsg = errInter.(os.Error).String()
}
sendResponse(sending, req, replyv.Interface(), enc, errmsg)
sendResponse(sending, req, replyv.Interface(), codec, errmsg)
}
func (server *serverType) input(conn io.ReadWriteCloser) {
dec := gob.NewDecoder(conn)
enc := gob.NewEncoder(conn)
type gobServerCodec struct {
rwc io.ReadWriteCloser
dec *gob.Decoder
enc *gob.Encoder
}
func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
return c.dec.Decode(r)
}
func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
return c.dec.Decode(body)
}
func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error {
if err := c.enc.Encode(r); err != nil {
return err
}
return c.enc.Encode(body)
}
func (c *gobServerCodec) Close() os.Error {
return c.rwc.Close()
}
func (server *serverType) input(codec ServerCodec) {
sending := new(sync.Mutex)
for {
// Grab the request header.
req := new(Request)
err := dec.Decode(req)
err := codec.ReadRequestHeader(req)
if err != nil {
if err == os.EOF || err == io.ErrUnexpectedEOF {
if err == io.ErrUnexpectedEOF {
@ -319,13 +343,13 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
break
}
s := "rpc: server cannot decode request: " + err.String()
sendResponse(sending, req, invalidRequest, enc, s)
continue
sendResponse(sending, req, invalidRequest, codec, s)
break
}
serviceMethod := strings.Split(req.ServiceMethod, ".", 0)
if len(serviceMethod) != 2 {
s := "rpc: service/method request ill:formed: " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, enc, s)
s := "rpc: service/method request ill-formed: " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, codec, s)
continue
}
// Look up the request.
@ -334,27 +358,27 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
server.Unlock()
if !ok {
s := "rpc: can't find service " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, enc, s)
sendResponse(sending, req, invalidRequest, codec, s)
continue
}
mtype, ok := service.method[serviceMethod[1]]
if !ok {
s := "rpc: can't find method " + req.ServiceMethod
sendResponse(sending, req, invalidRequest, enc, s)
sendResponse(sending, req, invalidRequest, codec, s)
continue
}
// Decode the argument value.
argv := _new(mtype.argType)
replyv := _new(mtype.replyType)
err = dec.Decode(argv.Interface())
err = codec.ReadRequestBody(argv.Interface())
if err != nil {
log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err)
sendResponse(sending, req, replyv.Interface(), enc, err.String())
continue
sendResponse(sending, req, replyv.Interface(), codec, err.String())
break
}
go service.call(sending, mtype, req, argv, replyv, enc)
go service.call(sending, mtype, req, argv, replyv, codec)
}
conn.Close()
codec.Close()
}
func (server *serverType) accept(lis net.Listener) {
@ -363,7 +387,7 @@ func (server *serverType) accept(lis net.Listener) {
if err != nil {
log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
}
go server.input(conn)
go ServeConn(conn)
}
}
@ -376,10 +400,34 @@ func (server *serverType) accept(lis net.Listener) {
// suitable methods.
func Register(rcvr interface{}) os.Error { return server.register(rcvr) }
// ServeConn runs the server on a single connection. When the connection
// completes, service terminates. ServeConn blocks; the caller typically
// invokes it in a go statement.
func ServeConn(conn io.ReadWriteCloser) { server.input(conn) }
// A ServerCodec implements reading of RPC requests and writing of
// RPC responses for the server side of an RPC session.
// The server calls ReadRequestHeader and ReadRequestBody in pairs
// to read requests from the connection, and it calls WriteResponse to
// write a response back. The server calls Close when finished with the
// connection.
type ServerCodec interface {
ReadRequestHeader(*Request) os.Error
ReadRequestBody(interface{}) os.Error
WriteResponse(*Response, interface{}) os.Error
Close() os.Error
}
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
func ServeConn(conn io.ReadWriteCloser) {
ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
}
// ServeCodec is like ServeConn but uses the specified codec to
// decode requests and encode responses.
func ServeCodec(codec ServerCodec) {
server.input(codec)
}
// Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks; the caller typically
@ -404,7 +452,7 @@ func serveHTTP(c *http.Conn, req *http.Request) {
return
}
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.input(conn)
ServeConn(conn)
}
// HandleHTTP registers an HTTP handler for RPC messages.