diff --git a/splitio/api/client.py b/splitio/api/client.py index 86945a27..e670bba3 100644 --- a/splitio/api/client.py +++ b/splitio/api/client.py @@ -95,7 +95,6 @@ def get(self, server, path, apikey, query=None, extra_headers=None): #pylint: d :rtype: HttpResponse """ headers = self._build_basic_headers(apikey) - if extra_headers is not None: headers.update(extra_headers) diff --git a/splitio/push/__init__.py b/splitio/push/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/splitio/push/splitsse.py b/splitio/push/splitsse.py new file mode 100644 index 00000000..9286166e --- /dev/null +++ b/splitio/push/splitsse.py @@ -0,0 +1,126 @@ +"""An SSE client wrapper to be used with split endpoint.""" +import logging +import threading +from enum import Enum +import six +from splitio.push.sse import SSEClient, SSE_EVENT_ERROR +from splitio.util.threadutil import EventGroup + + +_LOGGER = logging.getLogger(__name__) + + +class SplitSSEClient(object): + """Split streaming endpoint SSE client.""" + + class _Status(Enum): + IDLE = 0 + CONNECTING = 1 + ERRORED = 2 + CONNECTED = 3 + + def __init__(self, callback, base_url='https://streaming.split.io'): + """ + Construct a split sse client. + + :param callback: fuction to call when an event is received. + :type callback: callable + + :param base_url: scheme + :// + host + :type base_url: str + """ + self._client = SSEClient(self._raw_event_handler) + self._callback = callback + self._base_url = base_url + self._status = SplitSSEClient._Status.IDLE + self._sse_first_event = None + self._sse_connection_closed = None + + def _raw_event_handler(self, event): + """ + Handle incoming raw sse event. + + :param event: Incoming raw sse event. + :type event: splitio.push.sse.SSEEvent + """ + if self._status == SplitSSEClient._Status.CONNECTING: + self._status = SplitSSEClient._Status.CONNECTED if event.event != SSE_EVENT_ERROR \ + else SplitSSEClient._Status.ERRORED + self._sse_first_event.set() + + if event.data is not None: + self._callback(event) + + @staticmethod + def _format_channels(channels): + """ + Format channels into a list from the raw object retrieved in the token. + + :param channels: object as extracted from the JWT capabilities. + :type channels: dict[str,list[str]] + + :returns: channels as a list of strings. + :rtype: list[str] + """ + regular = [k for (k, v) in six.iteritems(channels) if v == ['subscribe']] + occupancy = ['[?occupancy=metrics.publishers]' + k + for (k, v) in six.iteritems(channels) + if 'channel-metadata:publishers' in v] + return regular + occupancy + + def _build_url(self, token): + """ + Build the url to connect to and return it as a string. + + :param token: (parsed) JWT + :type token: splitio.models.token.Token + + :returns: true if the connection was successful. False otherwise. + :rtype: bool + """ + return '{base}/event-stream?v=1.1&accessToken={token}&channels={channels}'.format( + base=self._base_url, + token=token.token, + channels=','.join(self._format_channels(token.channels))) + + def start(self, token): + """ + Open a connection to start listening for events. + + :param token: (parsed) JWT + :type token: splitio.models.token.Token + + :returns: true if the connection was successful. False otherwise. + :rtype: bool + """ + if self._status != SplitSSEClient._Status.IDLE: + raise Exception('SseClient already started.') + + self._status = SplitSSEClient._Status.CONNECTING + + event_group = EventGroup() + self._sse_first_event = event_group.make_event() + self._sse_connection_closed = event_group.make_event() + + def connect(url): + """Connect to sse in a blocking manner.""" + try: + self._client.start(url) + finally: + self._sse_connection_closed.set() + self._status = SplitSSEClient._Status.IDLE + + url = self._build_url(token) + task = threading.Thread(target=connect, args=(url,)) + task.setDaemon(True) + task.start() + event_group.wait() + return self._status == SplitSSEClient._Status.CONNECTED + + def stop(self, blocking=False, timeout=None): + """Abort the ongoing connection.""" + if self._status == SplitSSEClient._Status.IDLE: + raise Exception('SseClient not running') + self._client.shutdown() + if blocking: + self._sse_connection_closed.wait(timeout) diff --git a/splitio/push/sse.py b/splitio/push/sse.py new file mode 100644 index 00000000..344d41f5 --- /dev/null +++ b/splitio/push/sse.py @@ -0,0 +1,167 @@ +"""Low-level SSE Client.""" +import logging +import socket +import sys +from collections import namedtuple + +try: # try to import python3 names. fallback to python2 + from http.client import HTTPConnection, HTTPSConnection + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + from httplib import HTTPConnection, HTTPSConnection + + +_LOGGER = logging.getLogger(__name__) + + +SSE_EVENT_ERROR = 'error' +SSE_EVENT_MESSAGE = 'message' + + +SSEEvent = namedtuple('SSEEvent', ['event_id', 'event', 'retry', 'data']) + + +__ENDING_CHARS = set(['\n', '']) +def __httpresponse_readline_py2(response): + """ + Hacky `readline` implementation to be used with chunked transfers in python2. + + This makes syscalls in a loop, so not particularly efficient. Migrate to py3 now! + + :param response: HTTPConnection's response after a .request() call + :type response: httplib.HTTPResponse + + :returns: a string with the read line + :rtype: str + """ + buf = [] + while True: + read = response.read(1) + buf.append(read) + if read in __ENDING_CHARS: + break + + return ''.join(buf) + + +_http_response_readline = (__httpresponse_readline_py2 if sys.version_info.major <= 2 #pylint:disable=invalid-name + else lambda response: response.readline()) + + +class EventBuilder(object): + """Event builder class.""" + + _SEPARATOR = b':' + + def __init__(self): + """Construct a builder.""" + self._lines = {} + + def process_line(self, line): + """ + Process a new line. + + :param line: Line to process + :type line: bytes + """ + try: + key, val = line.split(self._SEPARATOR, 1) + self._lines[key.decode('utf8').strip()] = val.decode('utf8').strip() + except ValueError: # key without a value + self._lines[line.decode('utf8').strip()] = None + + def build(self): + """Construct an event with relevant fields.""" + return SSEEvent(self._lines.get('id'), self._lines.get('event'), + self._lines.get('retry'), self._lines.get('data')) + + +class SSEClient(object): + """SSE Client implementation.""" + + _DEFAULT_HEADERS = {'Accept': 'text/event-stream'} + _EVENT_SEPARATORS = set([b'\n', b'\r\n']) + + def __init__(self, callback): + """ + Construct an SSE client. + + :param callback: function to call when an event is received + :type callback: callable + """ + self._connection = None + self._event_callback = callback + self._shutdown_requested = False + + def _read_events(self): + """ + Read events from the supplied connection. + + :returns: True if the connection was ended by us. False if it was closed by the serve. + :rtype: bool + """ + try: + response = self._connection.getresponse() + event_builder = EventBuilder() + while True: + line = _http_response_readline(response) + if line is None or len(line) <= 0: # connection ended + _LOGGER.info("sse connection has ended.") + break + elif line.startswith(b':'): # comment. Skip + _LOGGER.debug("skipping sse comment") + continue + elif line in self._EVENT_SEPARATORS: + event = event_builder.build() + _LOGGER.debug("dispatching event: %s", event) + self._event_callback(event) + event_builder = EventBuilder() + else: + event_builder.process_line(line) + except Exception: #pylint:disable=broad-except + _LOGGER.info('sse connection ended.') + _LOGGER.debug('stack trace: ', exc_info=True) + finally: + self._connection.close() + self._connection = None # clear so it can be started again + + return self._shutdown_requested + + def start(self, url, extra_headers=None): #pylint:disable=dangerous-default-value + """ + Connect and start listening for events. + + :param url: url to connect to + :type url: str + + :param extra_headers: additional headers + :type extra_headers: dict[str, str] + + :returns: True if the connection was ended by us. False if it was closed by the serve. + :rtype: bool + """ + if self._connection is not None: + raise RuntimeError('Client already started.') + + url = urlparse(url) + headers = self._DEFAULT_HEADERS.copy() + headers.update(extra_headers if extra_headers is not None else {}) + self._connection = HTTPSConnection(url.hostname, url.port) if url.scheme == 'https' \ + else HTTPConnection(url.hostname, port=url.port) + + self._connection.request('GET', '%s?%s' % (url.path, url.query), headers=headers) + return self._read_events() + + def shutdown(self): + """Shutdown the current connection.""" + if self._connection is None: + _LOGGER.warn("no sse connection has been started on this SSEClient instance. Ignoring") + return + + if self._shutdown_requested: + _LOGGER.warn("shutdown already requested") + return + + self._shutdown_requested = True + self._connection.sock.shutdown(socket.SHUT_RDWR) diff --git a/splitio/util/__init__.py b/splitio/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/splitio/util/threadutil.py b/splitio/util/threadutil.py new file mode 100644 index 00000000..f76b4590 --- /dev/null +++ b/splitio/util/threadutil.py @@ -0,0 +1,56 @@ +"""Threading utilities.""" +from inspect import isclass +import threading + + +# python2 workaround +_EventClass = threading.Event if isclass(threading.Event) else threading._Event #pylint:disable=protected-access,invalid-name + + +class EventGroup(object): + """EventGroup that can be waited with an OR condition.""" + + class Event(_EventClass): #pylint:disable=too-few-public-methods + """Threading event meant to be used in an group.""" + + def __init__(self, shared_condition): + """ + Construct an event. + + :param shared_condition: shared condition varaible. + :type shared_condition: threading.Condition + """ + _EventClass.__init__(self) + self._shared_cond = shared_condition + + def set(self): + """Set the event.""" + _EventClass.set(self) + with self._shared_cond: + self._shared_cond.notify() + + def __init__(self): + """Construct an event group.""" + self._cond = threading.Condition() + + def make_event(self): + """ + Make a new event associated to this waitable group. + + :returns: an event that can be awaited as part of a group + :rtype: EventGroup.Event + """ + return EventGroup.Event(self._cond) + + def wait(self, timeout=None): + """ + Wait until one of the events is triggered. + + :param timeout: how many seconds to wait. None means forever. + :type timeout: int + + :returns: True if the condition was notified within the specified timeout. False otherwise. + :rtype: bool + """ + with self._cond: + return self._cond.wait(timeout) diff --git a/tests/models/test_token.py b/tests/models/test_token.py index 5ab90ed5..935de52b 100644 --- a/tests/models/test_token.py +++ b/tests/models/test_token.py @@ -6,10 +6,7 @@ class TokenTests(object): """Token model tests.""" - raw_false = { - 'pushEnabled': False, - 'token': 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjVZOU05US45QnJtR0EiLCJ0eXAiOiJKV1QifQ.eyJ4LWFibHktY2FwYWJpbGl0eSI6IntcIk56TTJNREk1TXpjMF9NVGd5TlRnMU1UZ3dOZz09X3NlZ21lbnRzXCI6W1wic3Vic2NyaWJlXCJdLFwiTnpNMk1ESTVNemMwX01UZ3lOVGcxTVRnd05nPT1fc3BsaXRzXCI6W1wic3Vic2NyaWJlXCJdLFwiY29udHJvbF9wcmlcIjpbXCJzdWJzY3JpYmVcIixcImNoYW5uZWwtbWV0YWRhdGE6cHVibGlzaGVyc1wiXSxcImNvbnRyb2xfc2VjXCI6W1wic3Vic2NyaWJlXCIsXCJjaGFubmVsLW1ldGFkYXRhOnB1Ymxpc2hlcnNcIl19IiwieC1hYmx5LWNsaWVudElkIjoiY2xpZW50SWQiLCJleHAiOjE2MDIwODgxMjcsImlhdCI6MTYwMjA4NDUyN30.5_MjWonhs6yoFhw44hNJm3H7_YMjXpSW105DwjjppqE', - } + raw_false = {'pushEnabled': False} def test_from_raw_false(self): """Test token model parsing.""" diff --git a/tests/push/__init__.py b/tests/push/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/push/mockserver.py b/tests/push/mockserver.py new file mode 100644 index 00000000..6e76b946 --- /dev/null +++ b/tests/push/mockserver.py @@ -0,0 +1,91 @@ +"""asd.""" +import queue +import threading + +from http.server import HTTPServer, BaseHTTPRequestHandler + + +class SSEMockServer(object): + """SSE server for testing purposes.""" + + protocol_version = 'HTTP/1.1' + + GRACEFUL_REQUEST_END = 'REQ-END' + VIOLENT_REQUEST_END = 'REQ-KILL' + + def __init__(self, req_queue=None): + """Consruct a mock server.""" + self._queue = queue.Queue() + self._server = HTTPServer(('localhost', 0), + lambda *xs: SSEHandler(self._queue, *xs, req_queue=req_queue)) + self._server_thread = threading.Thread(target=self._blocking_run) + self._server_thread.setDaemon(True) + self._done_event = threading.Event() + + def _blocking_run(self): + """Execute.""" + self._server.serve_forever() + self._done_event.set() + + def port(self): + """Return the assigned port.""" + return self._server.server_port + + def publish(self, event): + """Publish an event.""" + self._queue.put(event, block=False) + + def start(self): + """Start the server asyncrhonously.""" + self._server_thread.start() + + def wait(self, timeout=None): + """Wait for the server to shutdown.""" + return self._done_event.wait(timeout) + + def stop(self): + """Stop the server.""" + self._server.shutdown() + + +class SSEHandler(BaseHTTPRequestHandler): + """Handler.""" + + def __init__(self, event_queue, *args, **kwargs): + """Construct a handler.""" + self._queue = event_queue + self._req_queue = kwargs.get('req_queue') + BaseHTTPRequestHandler.__init__(self, *args) + + def do_GET(self): #pylint:disable=invalid-name + """Respond to a GET request.""" + self.send_response(200) + self.send_header("Content-type", "text/event-stream") + self.send_header("Transfer-Encoding", "chunked") + self.send_header("Connection", "keep-alive") + self.end_headers() + + if self._req_queue is not None: + self._req_queue.put(self.path) + + def write_chunk(chunk): + """Write an event/chunk.""" + tosend = '%X\r\n%s\r\n'%(len(chunk), chunk) + self.wfile.write(tosend.encode('utf-8')) + + while True: + event = self._queue.get() + if event == SSEMockServer.GRACEFUL_REQUEST_END: + break + elif event == SSEMockServer.VIOLENT_REQUEST_END: + raise Exception('exploding') + + chunk = '' + chunk += 'id: % s\n' % event['id'] if 'id' in event else '' + chunk += 'event: % s\n' % event['event'] if 'event' in event else '' + chunk += 'retry: % s\n' % event['retry'] if 'retry' in event else '' + chunk += 'data: % s\n' % event['data'] if 'data' in event else '' + if chunk != '': + write_chunk(chunk + '\r\n') + + self.wfile.write('0\r\n\r\n'.encode('utf-8')) diff --git a/tests/push/test_splitsse.py b/tests/push/test_splitsse.py new file mode 100644 index 00000000..7bd4fab0 --- /dev/null +++ b/tests/push/test_splitsse.py @@ -0,0 +1,81 @@ +"""SSEClient unit tests.""" + +import time +import threading +from queue import Queue +import pytest +from splitio.models.token import Token +from splitio.push.splitsse import SplitSSEClient +from splitio.push.sse import SSEEvent + +from .mockserver import SSEMockServer + + +class SSEClientTests(object): + """SSEClient test cases.""" + + def test_split_sse_success(self): + """Test correct initialization. Client ends the connection.""" + + events = [] + def handler(event): + """Handler.""" + events.append(event) + + request_queue = Queue() + server = SSEMockServer(request_queue) + server.start() + + client = SplitSSEClient(handler, 'http://localhost:' + str(server.port())) + + token = Token(True, 'some', {'chan1': ['subscribe'], 'chan2': ['subscribe', 'channel-metadata:publishers']}, + 1, 2) + + server.publish({'id': '1'}) # send a non-error event early to unblock start + assert client.start(token) + with pytest.raises(Exception): + client.start(token) + + server.publish({'id': '1', 'data': 'a', 'retry': '1', 'event': 'message'}) + server.publish({'id': '2', 'data': 'a', 'retry': '1', 'event': 'message'}) + time.sleep(1) + client.stop() + + assert request_queue.get() == '/event-stream?v=1.1&accessToken=some&channels=chan1,[?occupancy=metrics.publishers]chan2' + + assert events == [ + SSEEvent('1', 'message', '1', 'a'), + SSEEvent('2', 'message', '1', 'a') + ] + + server.publish(SSEMockServer.VIOLENT_REQUEST_END) + server.stop() + + def test_split_sse_error(self): + """Test correct initialization. Client ends the connection.""" + + events = [] + def handler(event): + """Handler.""" + events.append(event) + + request_queue = Queue() + server = SSEMockServer(request_queue) + server.start() + + client = SplitSSEClient(handler, 'http://localhost:' + str(server.port())) + + token = Token(True, 'some', {'chan1': ['subscribe'], 'chan2': ['subscribe', 'channel-metadata:publishers']}, + 1, 2) + + server.publish({'event': 'error'}) # send an error event early to unblock start + assert not client.start(token) + client.stop(True) + with pytest.raises(Exception): + client.stop() + + assert request_queue.get() == ('/event-stream?v=1.1&accessToken=some' + '&channels=chan1,[?occupancy=metrics.publishers]chan2') + + server.publish(SSEMockServer.VIOLENT_REQUEST_END) + server.stop() diff --git a/tests/push/test_sse.py b/tests/push/test_sse.py new file mode 100644 index 00000000..928fefbf --- /dev/null +++ b/tests/push/test_sse.py @@ -0,0 +1,128 @@ +"""SSEClient unit tests.""" + +import time +import threading +import pytest +from splitio.push.sse import SSEClient, SSEEvent +from .mockserver import SSEMockServer + + +class SSEClientTests(object): + """SSEClient test cases.""" + + def test_sse_client_disconnects(self): + """Test correct initialization. Client ends the connection.""" + server = SSEMockServer() + server.start() + + events = [] + def callback(event): + """Callback.""" + events.append(event) + + client = SSEClient(callback) + + def runner(): + """SSE client runner thread.""" + assert client.start('http://127.0.0.1:' + str(server.port())) + client_task = threading.Thread(target=runner) + client_task.setDaemon(True) + client_task.setName('client') + client_task.start() + with pytest.raises(RuntimeError): + client_task.start() + + server.publish({'id': '1'}) + server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) + server.publish({'id': '3', 'event': 'message', 'data': 'def'}) + server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + time.sleep(1) + client.shutdown() + time.sleep(1) + + assert events == [ + SSEEvent('1', None, None, None), + SSEEvent('2', 'message', None, 'abc'), + SSEEvent('3', 'message', None, 'def'), + SSEEvent('4', 'message', None, 'ghi') + ] + + assert client._connection is None + server.publish(server.GRACEFUL_REQUEST_END) + server.stop() + + def test_sse_server_disconnects(self): + """Test correct initialization. Server ends connection.""" + server = SSEMockServer() + server.start() + + events = [] + def callback(event): + """Callback.""" + events.append(event) + + client = SSEClient(callback) + + def runner(): + """SSE client runner thread.""" + assert client.start('http://127.0.0.1:' + str(server.port())) + client_task = threading.Thread(target=runner) + client_task.setDaemon(True) + client_task.setName('client') + client_task.start() + + server.publish({'id': '1'}) + server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) + server.publish({'id': '3', 'event': 'message', 'data': 'def'}) + server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + time.sleep(1) + server.publish(server.GRACEFUL_REQUEST_END) + server.stop() + time.sleep(1) + + assert events == [ + SSEEvent('1', None, None, None), + SSEEvent('2', 'message', None, 'abc'), + SSEEvent('3', 'message', None, 'def'), + SSEEvent('4', 'message', None, 'ghi') + ] + + assert client._connection is None + + def test_sse_server_disconnects_abruptly(self): + """Test correct initialization. Server ends connection.""" + server = SSEMockServer() + server.start() + + events = [] + def callback(event): + """Callback.""" + events.append(event) + + client = SSEClient(callback) + + def runner(): + """SSE client runner thread.""" + assert client.start('http://127.0.0.1:' + str(server.port())) + client_task = threading.Thread(target=runner) + client_task.setDaemon(True) + client_task.setName('client') + client_task.start() + + server.publish({'id': '1'}) + server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) + server.publish({'id': '3', 'event': 'message', 'data': 'def'}) + server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + time.sleep(1) + server.publish(server.VIOLENT_REQUEST_END) + server.stop() + time.sleep(1) + + assert events == [ + SSEEvent('1', None, None, None), + SSEEvent('2', 'message', None, 'abc'), + SSEEvent('3', 'message', None, 'def'), + SSEEvent('4', 'message', None, 'ghi') + ] + + assert client._connection is None diff --git a/tests/util/test_threadutil.py b/tests/util/test_threadutil.py new file mode 100644 index 00000000..7473aa96 --- /dev/null +++ b/tests/util/test_threadutil.py @@ -0,0 +1,44 @@ +"""threading utilities unit tests.""" + +import time +import threading + +from splitio.util.threadutil import EventGroup + + +class EventGroupTests(object): + """EventGroup class test cases.""" + + def test_basic_functionality(self): + """Test basic functionality.""" + + def fun(event): #pylint:disable=missing-docstring + time.sleep(1) + event.set() + + group = EventGroup() + event1 = group.make_event() + event2 = group.make_event() + + task = threading.Thread(target=fun, args=(event1,)) + task.start() + group.wait(3) + assert event1.is_set() + assert not event2.is_set() + + group = EventGroup() + event1 = group.make_event() + event2 = group.make_event() + + task = threading.Thread(target=fun, args=(event2,)) + task.start() + group.wait(3) + assert not event1.is_set() + assert event2.is_set() + + group = EventGroup() + event1 = group.make_event() + event2 = group.make_event() + group.wait(3) + assert not event1.is_set() + assert not event2.is_set()