Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rpc_client
import (
"crypto/tls"
"fmt"
"io"
"net"
netrpc "net/rpc"
"sync"
Expand All @@ -12,33 +13,34 @@ import (
"github.com/stackengine/selog"
"github.com/stackengine/serpc"
"github.com/stackengine/ssltls"
"github.com/ugorji/go/codec"
)

var sLog = selog.Register("clntrpc", 0)

type NewClientCodec func(conn io.ReadWriteCloser) netrpc.ClientCodec

type Conn struct {
sync.Mutex

addr net.Addr
key string
lastUsed time.Time
mh *codec.MsgpackHandle
net_con net.Conn
pool *ConnPool
refCount int32
rpc_clnt *netrpc.Client
shutdown int32
stream_type string
version int
addr net.Addr
key string
lastUsed time.Time
newClientCodec NewClientCodec
net_con net.Conn
pool *ConnPool
refCount int32
rpc_clnt *netrpc.Client
shutdown int32
stream_type string
version int
}

func (c *Conn) String() string {
return fmt.Sprintf("Conn:%p type: %s ref: %d key: %s addr: %s shutdown: %d",
c, c.stream_type, c.refCount, c.key, c.addr.String(), c.shutdown)
}

func NewConn(mh *codec.MsgpackHandle,
func NewConn(newClientCodec NewClientCodec,
addr net.Addr,
stream_type string,
key string,
Expand Down Expand Up @@ -92,9 +94,8 @@ func NewConn(mh *codec.MsgpackHandle,
// sLog.Printf("Wrote stream type for: '%s'", stream_type)
var clnt *netrpc.Client

if mh != nil {
codec := codec.GoRpc.ClientCodec(conn, mh)
clnt = netrpc.NewClientWithCodec(codec)
if newClientCodec != nil {
clnt = netrpc.NewClientWithCodec(newClientCodec(conn))
} else {
clnt = netrpc.NewClient(conn)
}
Expand Down
74 changes: 57 additions & 17 deletions client/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ package rpc_client
import (
"crypto/tls"
"net"
"net/rpc"
"strings"
"sync"
"time"

"github.com/stackengine/serpc"
"github.com/ugorji/go/codec"
)

type ConnPool struct {
sync.Mutex

maxTime time.Duration // The maximum time to keep a connection open
timo time.Duration // The maximum time to attempt net.Dail()
pool map[string]*Conn // Pool maps an address to a open connection
tlsConfig *tls.Config // TLS settings
shutdown bool // Used to indicate the pool is shutdown
shutdownCh chan struct{}
wg sync.WaitGroup
mh *codec.MsgpackHandle
maxTime time.Duration // The maximum time to keep a connection open
timo time.Duration // The maximum time to attempt net.Dail()
pool map[string]*Conn // Pool maps an address to a open connection
tlsConfig *tls.Config // TLS settings
shutdown bool // Used to indicate the pool is shutdown
shutdownCh chan struct{}
wg sync.WaitGroup
newClientCodec NewClientCodec
}

// Reap is used to close unused conns open over maxTime
Expand Down Expand Up @@ -64,18 +64,18 @@ func (p *ConnPool) reap() {
// Maintain at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping.
// If TLS settings are provided outgoing connections use TLS.
func NewPool(mh *codec.MsgpackHandle,
func NewPool(newClientCodec NewClientCodec,
maxTime time.Duration,
timo time.Duration,
tlsConfig *tls.Config) *ConnPool {

pool := &ConnPool{
maxTime: maxTime,
timo: timo,
pool: make(map[string]*Conn),
tlsConfig: tlsConfig,
shutdownCh: make(chan struct{}),
mh: mh,
maxTime: maxTime,
timo: timo,
pool: make(map[string]*Conn),
tlsConfig: tlsConfig,
shutdownCh: make(chan struct{}),
newClientCodec: newClientCodec,
}
if maxTime > 0 {
go pool.reap()
Expand Down Expand Up @@ -136,7 +136,7 @@ func (p *ConnPool) getClnt(addr net.Addr, st string) (*Conn, error) {
key := addr.String() + "/" + st
c = p.getConn(key)
if c == nil {
c, err = NewConn(p.mh, addr, st, key, p.timo, p.tlsConfig)
c, err = NewConn(p.newClientCodec, addr, st, key, p.timo, p.tlsConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -169,3 +169,43 @@ func (p *ConnPool) RPC(addr net.Addr, stream_type string, version rpc_stream.Mux
clnt_stream.Release()
return err
}

// Call is used to make an RPC call to a remote host
func (p *ConnPool) Call(addr net.Addr, stream_type string, version rpc_stream.MuxVersion,
method string, args interface{}, reply interface{}) error {

call, clnt_stream := p.Go(addr, stream_type, version, method, args, reply, nil)
call = <-call.Done
if clnt_stream != nil {
clnt_stream.Release()
}
return call.Error
}

// Go is used to make an RPC Go call to a remote host
func (p *ConnPool) Go(addr net.Addr, stream_type string, version rpc_stream.MuxVersion,
method string, args interface{}, reply interface{}, done chan *rpc.Call) (*rpc.Call, *Conn) {

st := strings.ToUpper(stream_type)
// sLog.Printf("Go: pool->%p addr: %s stream: %s method: %s", p, addr, st, method)
if reply == nil {
return &rpc.Call{ServiceMethod: method, Args: args, Reply: reply, Done: done, Error: ErrNeedReply}, nil
}
clnt_stream, err := p.getClnt(addr, st)
if err != nil {
sLog.Printf("rpc error: getClnt() %v", err)
return &rpc.Call{ServiceMethod: method, Args: args, Reply: reply, Done: done, Error: ErrNoClient}, clnt_stream
}
// sLog.Printf("@%p -> Go(%s, %s, %d, %s: Args: %#v)", clnt_stream, addr, st, version, method, args)
call := clnt_stream.rpc_clnt.Go(method, args, reply, done)
if call.Error != nil {
p.Shutdown(clnt_stream)
sLog.Printf("error on Go(): %v", err)
return &rpc.Call{ServiceMethod: method, Args: args, Reply: reply, Done: done, Error: ErrCallFailed}, clnt_stream
}

// caller of this method needs to call this:
// clnt_stream.Release()

return call, clnt_stream
}
29 changes: 18 additions & 11 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,18 @@ func Register(name string, obj interface{}) error {
return nil
}

var _ SvcRPC = &RPCImpl{}

type RPCImpl struct {
inboundTLS *tls.Config
isTLS bool
secure bool
outboundTLS *tls.Config
rpc_l net.Listener
rpc_svr *netrpc.Server
lck sync.Mutex
shutdown bool
inboundTLS *tls.Config
isTLS bool
secure bool
outboundTLS *tls.Config
rpc_l net.Listener
rpc_svr *netrpc.Server
lck sync.Mutex
newServerCodec NewServerCodec
shutdown bool
}

func NewServer() *RPCImpl {
Expand All @@ -75,9 +78,10 @@ func (impl *RPCImpl) Server() *netrpc.Server {
return impl.rpc_svr
}

func (impl *RPCImpl) Init(tlscfg *ssltls.Cfg, enforce_secure bool, port int) error {
func (impl *RPCImpl) Init(tlscfg *ssltls.Cfg, enforce_secure bool, port int, newServerCodec NewServerCodec) error {
var err error

impl.newServerCodec = newServerCodec
if tlscfg != nil {
if impl.outboundTLS, err = tlscfg.OutgoingTLSConfig(); err != nil {
return err
Expand Down Expand Up @@ -225,9 +229,12 @@ func (impl *RPCImpl) MuxRPC(conn net.Conn, isTLS bool) {
}

func (impl *RPCImpl) serviceRPC(conn net.Conn) {
// codec := codec.GoRpc.ServerCodec(conn, impl.mh)
sLog.Printf("Processing connection from %s", conn.RemoteAddr())
impl.rpc_svr.ServeConn(conn)
if impl.newServerCodec == nil {
impl.rpc_svr.ServeConn(conn)
} else {
impl.rpc_svr.ServeCodec(impl.newServerCodec(conn))
}
sLog.Printf("Close connection from %s", conn.RemoteAddr())
conn.Close()
}
5 changes: 4 additions & 1 deletion server/svcrpc.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package rpc_server

import (
"io"
netrpc "net/rpc"

"github.com/stackengine/ssltls"
)

type NewServerCodec func(conn io.ReadWriteCloser) netrpc.ServerCodec

type SvcRPC interface {
Init(*ssltls.Cfg, bool, int) error
Init(*ssltls.Cfg, bool, int, NewServerCodec) error
Start() error
Shutdown()
Server() *netrpc.Server
Expand Down