Skip to content

Commit 9702bca

Browse files
committed
Middleware extended with per task context concept
It's now possible to implement a tracer system for example.
1 parent dbc8660 commit 9702bca

File tree

3 files changed

+339
-7
lines changed

3 files changed

+339
-7
lines changed

tests/test_middleware.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525

2626
from nose.tools import assert_raises
2727
import gevent
28+
import gevent.local
29+
import random
30+
import md5
2831

2932
from zerorpc import zmq
3033
import zerorpc
@@ -250,3 +253,274 @@ def test(self, argument):
250253
#FIXME: These seems to be broken
251254
# publisher.close()
252255
# subscriber.close()
256+
257+
258+
class Tracer:
259+
'''Used by test_task_context_* tests'''
260+
def __init__(self, identity):
261+
self._identity = identity
262+
self._locals = gevent.local.local()
263+
self._log = []
264+
265+
@property
266+
def trace_id(self):
267+
return self._locals.__dict__.get('trace_id', None)
268+
269+
def load_task_context(self, event_header):
270+
self._locals.trace_id = event_header.get('trace_id', None)
271+
print self._identity, 'load_task_context', self.trace_id
272+
self._log.append(('load', self.trace_id))
273+
274+
def get_task_context(self):
275+
if self.trace_id is None:
276+
# just an ugly code to generate a beautiful little hash.
277+
self._locals.trace_id = '<{0}>'.format(md5.md5(
278+
str(random.random())[3:]
279+
).hexdigest()[0:6].upper())
280+
print self._identity, 'get_task_context! [make a new one]', self.trace_id
281+
self._log.append(('new', self.trace_id))
282+
else:
283+
print self._identity, 'get_task_context! [reuse]', self.trace_id
284+
self._log.append(('reuse', self.trace_id))
285+
return { 'trace_id': self.trace_id }
286+
287+
288+
def test_task_context():
289+
endpoint = random_ipc_endpoint()
290+
srv_ctx = zerorpc.Context()
291+
cli_ctx = zerorpc.Context()
292+
293+
srv_tracer = Tracer('[server]')
294+
srv_ctx.register_middleware(srv_tracer)
295+
cli_tracer = Tracer('[client]')
296+
cli_ctx.register_middleware(cli_tracer)
297+
298+
class Srv:
299+
def echo(self, msg):
300+
return msg
301+
302+
@zerorpc.stream
303+
def stream(self):
304+
yield 42
305+
306+
srv = zerorpc.Server(Srv(), context=srv_ctx)
307+
srv.bind(endpoint)
308+
srv_task = gevent.spawn(srv.run)
309+
310+
c = zerorpc.Client(context=cli_ctx)
311+
c.connect(endpoint)
312+
313+
assert c.echo('hello') == 'hello'
314+
for x in c.stream():
315+
assert x == 42
316+
317+
srv.stop()
318+
srv_task.join()
319+
320+
assert cli_tracer._log == [
321+
('new', cli_tracer.trace_id),
322+
('reuse', cli_tracer.trace_id),
323+
]
324+
assert srv_tracer._log == [
325+
('load', cli_tracer.trace_id),
326+
('reuse', cli_tracer.trace_id),
327+
('load', cli_tracer.trace_id),
328+
('reuse', cli_tracer.trace_id),
329+
]
330+
331+
def test_task_context_relay():
332+
endpoint1 = random_ipc_endpoint()
333+
endpoint2 = random_ipc_endpoint()
334+
srv_ctx = zerorpc.Context()
335+
srv_relay_ctx = zerorpc.Context()
336+
cli_ctx = zerorpc.Context()
337+
338+
srv_tracer = Tracer('[server]')
339+
srv_ctx.register_middleware(srv_tracer)
340+
srv_relay_tracer = Tracer('[server_relay]')
341+
srv_relay_ctx.register_middleware(srv_relay_tracer)
342+
cli_tracer = Tracer('[client]')
343+
cli_ctx.register_middleware(cli_tracer)
344+
345+
class Srv:
346+
def echo(self, msg):
347+
return msg
348+
349+
srv = zerorpc.Server(Srv(), context=srv_ctx)
350+
srv.bind(endpoint1)
351+
srv_task = gevent.spawn(srv.run)
352+
353+
c_relay = zerorpc.Client(context=srv_relay_ctx)
354+
c_relay.connect(endpoint1)
355+
356+
class SrvRelay:
357+
def echo(self, msg):
358+
return c_relay.echo('relay' + msg) + 'relayed'
359+
360+
srv_relay = zerorpc.Server(SrvRelay(), context=srv_relay_ctx)
361+
srv_relay.bind(endpoint2)
362+
srv_relay_task = gevent.spawn(srv_relay.run)
363+
364+
c = zerorpc.Client(context=cli_ctx)
365+
c.connect(endpoint2)
366+
367+
assert c.echo('hello') == 'relayhellorelayed'
368+
369+
srv_relay.stop()
370+
srv.stop()
371+
srv_relay_task.join()
372+
srv_task.join()
373+
374+
assert cli_tracer._log == [
375+
('new', cli_tracer.trace_id),
376+
]
377+
assert srv_relay_tracer._log == [
378+
('load', cli_tracer.trace_id),
379+
('reuse', cli_tracer.trace_id),
380+
('reuse', cli_tracer.trace_id),
381+
]
382+
assert srv_tracer._log == [
383+
('load', cli_tracer.trace_id),
384+
('reuse', cli_tracer.trace_id),
385+
]
386+
387+
def test_task_context_relay_fork():
388+
endpoint1 = random_ipc_endpoint()
389+
endpoint2 = random_ipc_endpoint()
390+
srv_ctx = zerorpc.Context()
391+
srv_relay_ctx = zerorpc.Context()
392+
cli_ctx = zerorpc.Context()
393+
394+
srv_tracer = Tracer('[server]')
395+
srv_ctx.register_middleware(srv_tracer)
396+
srv_relay_tracer = Tracer('[server_relay]')
397+
srv_relay_ctx.register_middleware(srv_relay_tracer)
398+
cli_tracer = Tracer('[client]')
399+
cli_ctx.register_middleware(cli_tracer)
400+
401+
class Srv:
402+
def echo(self, msg):
403+
return msg
404+
405+
srv = zerorpc.Server(Srv(), context=srv_ctx)
406+
srv.bind(endpoint1)
407+
srv_task = gevent.spawn(srv.run)
408+
409+
c_relay = zerorpc.Client(context=srv_relay_ctx)
410+
c_relay.connect(endpoint1)
411+
412+
class SrvRelay:
413+
def echo(self, msg):
414+
def dothework(msg):
415+
return c_relay.echo(msg) + 'relayed'
416+
g = gevent.spawn(zerorpc.fork_task_context(dothework,
417+
srv_relay_ctx), 'relay' + msg)
418+
print 'relaying in separate task:', g
419+
r = g.get()
420+
print 'back to main task'
421+
return r
422+
423+
srv_relay = zerorpc.Server(SrvRelay(), context=srv_relay_ctx)
424+
srv_relay.bind(endpoint2)
425+
srv_relay_task = gevent.spawn(srv_relay.run)
426+
427+
c = zerorpc.Client(context=cli_ctx)
428+
c.connect(endpoint2)
429+
430+
assert c.echo('hello') == 'relayhellorelayed'
431+
432+
srv_relay.stop()
433+
srv.stop()
434+
srv_relay_task.join()
435+
srv_task.join()
436+
437+
assert cli_tracer._log == [
438+
('new', cli_tracer.trace_id),
439+
]
440+
assert srv_relay_tracer._log == [
441+
('load', cli_tracer.trace_id),
442+
('reuse', cli_tracer.trace_id),
443+
('load', cli_tracer.trace_id),
444+
('reuse', cli_tracer.trace_id),
445+
('reuse', cli_tracer.trace_id),
446+
]
447+
assert srv_tracer._log == [
448+
('load', cli_tracer.trace_id),
449+
('reuse', cli_tracer.trace_id),
450+
]
451+
452+
453+
def test_task_context_pushpull():
454+
endpoint = random_ipc_endpoint()
455+
puller_ctx = zerorpc.Context()
456+
pusher_ctx = zerorpc.Context()
457+
458+
puller_tracer = Tracer('[puller]')
459+
puller_ctx.register_middleware(puller_tracer)
460+
pusher_tracer = Tracer('[pusher]')
461+
pusher_ctx.register_middleware(pusher_tracer)
462+
463+
trigger = gevent.event.Event()
464+
465+
class Puller:
466+
def echo(self, msg):
467+
trigger.set()
468+
469+
puller = zerorpc.Puller(Puller(), context=puller_ctx)
470+
puller.bind(endpoint)
471+
puller_task = gevent.spawn(puller.run)
472+
473+
c = zerorpc.Pusher(context=pusher_ctx)
474+
c.connect(endpoint)
475+
476+
trigger.clear()
477+
c.echo('hello')
478+
trigger.wait()
479+
480+
puller.stop()
481+
puller_task.join()
482+
483+
assert pusher_tracer._log == [
484+
('new', pusher_tracer.trace_id),
485+
]
486+
assert puller_tracer._log == [
487+
('load', pusher_tracer.trace_id),
488+
]
489+
490+
491+
def test_task_context_pubsub():
492+
endpoint = random_ipc_endpoint()
493+
subscriber_ctx = zerorpc.Context()
494+
publisher_ctx = zerorpc.Context()
495+
496+
subscriber_tracer = Tracer('[subscriber]')
497+
subscriber_ctx.register_middleware(subscriber_tracer)
498+
publisher_tracer = Tracer('[publisher]')
499+
publisher_ctx.register_middleware(publisher_tracer)
500+
501+
trigger = gevent.event.Event()
502+
503+
class Subscriber:
504+
def echo(self, msg):
505+
trigger.set()
506+
507+
subscriber = zerorpc.Subscriber(Subscriber(), context=subscriber_ctx)
508+
subscriber.bind(endpoint)
509+
subscriber_task = gevent.spawn(subscriber.run)
510+
511+
c = zerorpc.Publisher(context=publisher_ctx)
512+
c.connect(endpoint)
513+
514+
trigger.clear()
515+
c.echo('pub...')
516+
trigger.wait()
517+
518+
subscriber.stop()
519+
subscriber_task.join()
520+
521+
assert publisher_tracer._log == [
522+
('new', publisher_tracer.trace_id),
523+
]
524+
assert subscriber_tracer._log == [
525+
('load', publisher_tracer.trace_id),
526+
]

zerorpc/context.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def __init__(self):
3737
self._middlewares_hooks = {
3838
'resolve_endpoint': [],
3939
'raise_error': [],
40-
'call_procedure': []
40+
'call_procedure': [],
41+
'load_task_context': [],
42+
'get_task_context': [],
4143
}
4244

4345
@staticmethod
@@ -86,3 +88,13 @@ def __call__(self, *args, **kwargs):
8688
for functor in self._middlewares_hooks['call_procedure']:
8789
procedure = chain(functor, procedure)
8890
return procedure(*args, **kwargs)
91+
92+
def middleware_load_task_context(self, event_header):
93+
for functor in self._middlewares_hooks['load_task_context']:
94+
functor(event_header)
95+
96+
def middleware_get_task_context(self):
97+
event_header = {}
98+
for functor in self._middlewares_hooks['get_task_context']:
99+
event_header.update(functor())
100+
return event_header

0 commit comments

Comments
 (0)