diff --git a/client/client.go b/client/client.go index bf68e28..2704e56 100644 --- a/client/client.go +++ b/client/client.go @@ -3,6 +3,7 @@ package rpc_client import ( "crypto/tls" "fmt" + "io" "net" netrpc "net/rpc" "sync" @@ -12,25 +13,26 @@ import ( "github.com/stackengine/selog" "github.com/stackengine/serpc" "github.com/stackengine/ssltls" - "github.com/ugorji/go/codec" ) var sLog = selog.Register("clntrpc", 0) +type NewClientCodec func(conn io.ReadWriteCloser) netrpc.ClientCodec + type Conn struct { sync.Mutex - addr net.Addr - key string - lastUsed time.Time - mh *codec.MsgpackHandle - net_con net.Conn - pool *ConnPool - refCount int32 - rpc_clnt *netrpc.Client - shutdown int32 - stream_type string - version int + addr net.Addr + key string + lastUsed time.Time + newClientCodec NewClientCodec + net_con net.Conn + pool *ConnPool + refCount int32 + rpc_clnt *netrpc.Client + shutdown int32 + stream_type string + version int } func (c *Conn) String() string { @@ -38,7 +40,7 @@ func (c *Conn) String() string { c, c.stream_type, c.refCount, c.key, c.addr.String(), c.shutdown) } -func NewConn(mh *codec.MsgpackHandle, +func NewConn(newClientCodec NewClientCodec, addr net.Addr, stream_type string, key string, @@ -92,9 +94,8 @@ func NewConn(mh *codec.MsgpackHandle, // sLog.Printf("Wrote stream type for: '%s'", stream_type) var clnt *netrpc.Client - if mh != nil { - codec := codec.GoRpc.ClientCodec(conn, mh) - clnt = netrpc.NewClientWithCodec(codec) + if newClientCodec != nil { + clnt = netrpc.NewClientWithCodec(newClientCodec(conn)) } else { clnt = netrpc.NewClient(conn) } diff --git a/client/pool.go b/client/pool.go index 3100de9..c4a22a7 100644 --- a/client/pool.go +++ b/client/pool.go @@ -3,25 +3,25 @@ package rpc_client import ( "crypto/tls" "net" + "net/rpc" "strings" "sync" "time" "github.com/stackengine/serpc" - "github.com/ugorji/go/codec" ) type ConnPool struct { sync.Mutex - maxTime time.Duration // The maximum time to keep a connection open - timo time.Duration // The maximum time to attempt net.Dail() - pool map[string]*Conn // Pool maps an address to a open connection - tlsConfig *tls.Config // TLS settings - shutdown bool // Used to indicate the pool is shutdown - shutdownCh chan struct{} - wg sync.WaitGroup - mh *codec.MsgpackHandle + maxTime time.Duration // The maximum time to keep a connection open + timo time.Duration // The maximum time to attempt net.Dail() + pool map[string]*Conn // Pool maps an address to a open connection + tlsConfig *tls.Config // TLS settings + shutdown bool // Used to indicate the pool is shutdown + shutdownCh chan struct{} + wg sync.WaitGroup + newClientCodec NewClientCodec } // Reap is used to close unused conns open over maxTime @@ -64,18 +64,18 @@ func (p *ConnPool) reap() { // Maintain at most one connection per host, for up to maxTime. // Set maxTime to 0 to disable reaping. // If TLS settings are provided outgoing connections use TLS. -func NewPool(mh *codec.MsgpackHandle, +func NewPool(newClientCodec NewClientCodec, maxTime time.Duration, timo time.Duration, tlsConfig *tls.Config) *ConnPool { pool := &ConnPool{ - maxTime: maxTime, - timo: timo, - pool: make(map[string]*Conn), - tlsConfig: tlsConfig, - shutdownCh: make(chan struct{}), - mh: mh, + maxTime: maxTime, + timo: timo, + pool: make(map[string]*Conn), + tlsConfig: tlsConfig, + shutdownCh: make(chan struct{}), + newClientCodec: newClientCodec, } if maxTime > 0 { go pool.reap() @@ -136,7 +136,7 @@ func (p *ConnPool) getClnt(addr net.Addr, st string) (*Conn, error) { key := addr.String() + "/" + st c = p.getConn(key) if c == nil { - c, err = NewConn(p.mh, addr, st, key, p.timo, p.tlsConfig) + c, err = NewConn(p.newClientCodec, addr, st, key, p.timo, p.tlsConfig) if err != nil { return nil, err } @@ -169,3 +169,43 @@ func (p *ConnPool) RPC(addr net.Addr, stream_type string, version rpc_stream.Mux clnt_stream.Release() return err } + +// Call is used to make an RPC call to a remote host +func (p *ConnPool) Call(addr net.Addr, stream_type string, version rpc_stream.MuxVersion, + method string, args interface{}, reply interface{}) error { + + call, clnt_stream := p.Go(addr, stream_type, version, method, args, reply, nil) + call = <-call.Done + if clnt_stream != nil { + clnt_stream.Release() + } + return call.Error +} + +// Go is used to make an RPC Go call to a remote host +func (p *ConnPool) Go(addr net.Addr, stream_type string, version rpc_stream.MuxVersion, + method string, args interface{}, reply interface{}, done chan *rpc.Call) (*rpc.Call, *Conn) { + + st := strings.ToUpper(stream_type) + // sLog.Printf("Go: pool->%p addr: %s stream: %s method: %s", p, addr, st, method) + if reply == nil { + return &rpc.Call{ServiceMethod: method, Args: args, Reply: reply, Done: done, Error: ErrNeedReply}, nil + } + clnt_stream, err := p.getClnt(addr, st) + if err != nil { + sLog.Printf("rpc error: getClnt() %v", err) + return &rpc.Call{ServiceMethod: method, Args: args, Reply: reply, Done: done, Error: ErrNoClient}, clnt_stream + } + // sLog.Printf("@%p -> Go(%s, %s, %d, %s: Args: %#v)", clnt_stream, addr, st, version, method, args) + call := clnt_stream.rpc_clnt.Go(method, args, reply, done) + if call.Error != nil { + p.Shutdown(clnt_stream) + sLog.Printf("error on Go(): %v", err) + return &rpc.Call{ServiceMethod: method, Args: args, Reply: reply, Done: done, Error: ErrCallFailed}, clnt_stream + } + + // caller of this method needs to call this: + // clnt_stream.Release() + + return call, clnt_stream +} diff --git a/server/server.go b/server/server.go index a1ca094..3896662 100644 --- a/server/server.go +++ b/server/server.go @@ -56,15 +56,18 @@ func Register(name string, obj interface{}) error { return nil } +var _ SvcRPC = &RPCImpl{} + type RPCImpl struct { - inboundTLS *tls.Config - isTLS bool - secure bool - outboundTLS *tls.Config - rpc_l net.Listener - rpc_svr *netrpc.Server - lck sync.Mutex - shutdown bool + inboundTLS *tls.Config + isTLS bool + secure bool + outboundTLS *tls.Config + rpc_l net.Listener + rpc_svr *netrpc.Server + lck sync.Mutex + newServerCodec NewServerCodec + shutdown bool } func NewServer() *RPCImpl { @@ -75,9 +78,10 @@ func (impl *RPCImpl) Server() *netrpc.Server { return impl.rpc_svr } -func (impl *RPCImpl) Init(tlscfg *ssltls.Cfg, enforce_secure bool, port int) error { +func (impl *RPCImpl) Init(tlscfg *ssltls.Cfg, enforce_secure bool, port int, newServerCodec NewServerCodec) error { var err error + impl.newServerCodec = newServerCodec if tlscfg != nil { if impl.outboundTLS, err = tlscfg.OutgoingTLSConfig(); err != nil { return err @@ -225,9 +229,12 @@ func (impl *RPCImpl) MuxRPC(conn net.Conn, isTLS bool) { } func (impl *RPCImpl) serviceRPC(conn net.Conn) { - // codec := codec.GoRpc.ServerCodec(conn, impl.mh) sLog.Printf("Processing connection from %s", conn.RemoteAddr()) - impl.rpc_svr.ServeConn(conn) + if impl.newServerCodec == nil { + impl.rpc_svr.ServeConn(conn) + } else { + impl.rpc_svr.ServeCodec(impl.newServerCodec(conn)) + } sLog.Printf("Close connection from %s", conn.RemoteAddr()) conn.Close() } diff --git a/server/svcrpc.go b/server/svcrpc.go index f5de7e3..eb9c176 100644 --- a/server/svcrpc.go +++ b/server/svcrpc.go @@ -1,13 +1,16 @@ package rpc_server import ( + "io" netrpc "net/rpc" "github.com/stackengine/ssltls" ) +type NewServerCodec func(conn io.ReadWriteCloser) netrpc.ServerCodec + type SvcRPC interface { - Init(*ssltls.Cfg, bool, int) error + Init(*ssltls.Cfg, bool, int, NewServerCodec) error Start() error Shutdown() Server() *netrpc.Server