diff --git a/.circleci/config.yml b/.circleci/config.yml index d7d80ca..6f502d2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,10 +8,7 @@ workflows: - test-3.7 - test-3.6 - test-3.5 - - test-3.4 - - test-2.7 - test-pypy3 - - test-pypy2 jobs: test-3.9: &test-template docker: @@ -31,11 +28,16 @@ jobs: command: | . venv/bin/activate pip install -e .[tests] + - run: + name: Typecheck + command: | + . venv/bin/activate + venv/bin/mypy -p smpplib - run: name: Run tests command: | . venv/bin/activate - pytest -v + pytest -v smpplib test-3.8: <<: *test-template docker: @@ -52,19 +54,7 @@ jobs: <<: *test-template docker: - image: python:3.5-alpine - test-3.4: - <<: *test-template - docker: - - image: python:3.4-alpine - test-2.7: - <<: *test-template - docker: - - image: python:2.7-alpine test-pypy3: <<: *test-template docker: - image: pypy:3-slim - test-pypy2: - <<: *test-template - docker: - - image: pypy:2-slim diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a058f13 --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ + +venv: + python -m venv venv + +deps: + pip install -e .[tests] + +typecheck: + . venv/bin/activate + venv/bin/mypy -p smpplib + +test: + . venv/bin/activate + pytest -v smpplib diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..274bbc2 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,21 @@ +[mypy] + +strict_optional = True +no_implicit_optional = True +warn_unused_configs = True +strict_equality = True +warn_unused_ignores = True + +check_untyped_defs = True + +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +disallow_untyped_decorators = True +disallow_any_generics = True + +[mypy-smpplib/tests/*] + +disallow_untyped_calls = False +disallow_untyped_defs = False +disallow_incomplete_defs = False diff --git a/setup.py b/setup.py index 06ba978..20e2e7c 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,8 @@ url='https://github.com/podshumok/python-smpplib', description='SMPP library for python', packages=find_packages(), - install_requires=['six'], extras_require=dict( - tests=('pytest', 'mock'), + tests=('pytest', 'mock', 'mypy', 'types-mock'), ), zip_safe=True, classifiers=( diff --git a/smpplib/client.py b/smpplib/client.py index b89af50..9a4ff1d 100644 --- a/smpplib/client.py +++ b/smpplib/client.py @@ -24,8 +24,11 @@ import socket import struct import warnings +from typing import Optional, Callable +from ssl import SSLContext -from smpplib import consts, exceptions, smpp +from smpplib import consts, exceptions, smpp, pdu +from typing import Any, Dict, NoReturn, List class SimpleSequenceGenerator(object): @@ -33,14 +36,14 @@ class SimpleSequenceGenerator(object): MIN_SEQUENCE = 0x00000001 MAX_SEQUENCE = 0x7FFFFFFF - def __init__(self): + def __init__(self) -> None: self._sequence = self.MIN_SEQUENCE @property - def sequence(self): + def sequence(self) -> int: return self._sequence - def next_sequence(self): + def next_sequence(self) -> int: if self._sequence == self.MAX_SEQUENCE: self._sequence = self.MIN_SEQUENCE else: @@ -53,23 +56,23 @@ class Client(object): state = consts.SMPP_CLIENT_STATE_CLOSED - host = None - port = None + host: str + port: int vendor = None - _socket = None + _socket: Optional[socket.socket] = None _ssl_context = None - sequence_generator = None + sequence_generator: Any def __init__( self, - host, - port, - timeout=5, - sequence_generator=None, - logger_name=None, - ssl_context=None, - allow_unknown_opt_params=None, - ): + host: str, + port: int, + timeout: float =5, + sequence_generator: Optional[Any]=None, + logger_name: Optional[str]=None, + ssl_context: Optional[SSLContext]=None, + allow_unknown_opt_params: Optional[bool]=None, + ) -> None: self.host = host self.port = int(port) self._ssl_context = ssl_context @@ -94,10 +97,10 @@ def __init__( self._socket = self._create_socket() - def __enter__(self): + def __enter__(self) -> "Client": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self._socket is not None: try: self.unbind() @@ -108,18 +111,18 @@ def __exit__(self, exc_type, exc_value, traceback): self.logger.warning('%s. Ignored', e) self.disconnect() - def __del__(self): + def __del__(self) -> None: if self._socket is not None: self.logger.warning('%s was not closed', self) @property - def sequence(self): + def sequence(self) -> int: return self.sequence_generator.sequence - def next_sequence(self): + def next_sequence(self) -> int: return self.sequence_generator.next_sequence() - def _create_socket(self): + def _create_socket(self) -> socket.socket: raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) raw_socket.settimeout(self.timeout) @@ -128,7 +131,7 @@ def _create_socket(self): return self._ssl_context.wrap_socket(raw_socket) - def connect(self): + def connect(self) -> None: """Connect to SMSC""" self.logger.info('Connecting to %s:%s...', self.host, self.port) @@ -141,7 +144,7 @@ def connect(self): except socket.error: raise exceptions.ConnectionError("Connection refused") - def disconnect(self): + def disconnect(self) -> None: """Disconnect from the SMSC""" self.logger.info('Disconnecting...') @@ -152,7 +155,7 @@ def disconnect(self): self._socket = None self.state = consts.SMPP_CLIENT_STATE_CLOSED - def _bind(self, command_name, **kwargs): + def _bind(self, command_name: str, **kwargs: Any) -> Any: """Send bind_transmitter command to the SMSC""" if command_name in ('bind_receiver', 'bind_transceiver'): @@ -174,19 +177,19 @@ def _bind(self, command_name, **kwargs): ) return resp - def bind_transmitter(self, **kwargs): + def bind_transmitter(self, **kwargs: Any) -> Any: """Bind as a transmitter""" return self._bind('bind_transmitter', **kwargs) - def bind_receiver(self, **kwargs): + def bind_receiver(self, **kwargs: Any) -> Any: """Bind as a receiver""" return self._bind('bind_receiver', **kwargs) - def bind_transceiver(self, **kwargs): + def bind_transceiver(self, **kwargs: Any) -> Any: """Bind as a transmitter and receiver at once""" return self._bind('bind_transceiver', **kwargs) - def unbind(self): + def unbind(self) -> Any: """Unbind from the SMSC""" p = smpp.make_pdu('unbind', client=self) @@ -197,7 +200,7 @@ def unbind(self): except socket.timeout: raise exceptions.ConnectionError() - def send_pdu(self, p): + def send_pdu(self, p: pdu.PDU) -> bool: """Send PDU to the SMSC""" if self.state not in consts.COMMAND_STATES[p.command]: @@ -216,6 +219,7 @@ def send_pdu(self, p): while sent < len(generated): try: + assert self._socket is not None sent_last = self._socket.send(generated[sent:]) except socket.error as e: self.logger.warning(e) @@ -226,12 +230,13 @@ def send_pdu(self, p): return True - def read_pdu(self): + def read_pdu(self) -> Any: """Read PDU from the SMSC""" self.logger.debug('Waiting for PDU...') try: + assert self._socket is not None raw_len = self._socket.recv(4) except socket.timeout: raise @@ -249,6 +254,7 @@ def read_pdu(self): raw_pdu = raw_len while len(raw_pdu) < length: + assert self._socket is not None raw_pdu += self._socket.recv(length - len(raw_pdu)) self.logger.debug('<<%s (%d bytes)', binascii.b2a_hex(raw_pdu), len(raw_pdu)) @@ -269,11 +275,11 @@ def read_pdu(self): return pdu - def accept(self, obj): + def accept(self, obj: Any) -> NoReturn: """Accept an object""" raise NotImplementedError('not implemented') - def _message_received(self, pdu): + def _message_received(self, pdu: pdu.PDU) -> None: """Handler for received message event""" status = self.message_received_handler(pdu=pdu) if status is None: @@ -282,56 +288,58 @@ def _message_received(self, pdu): dsmr.sequence = pdu.sequence self.send_pdu(dsmr) - def _enquire_link_received(self, pdu): + def _enquire_link_received(self, pdu: pdu.PDU) -> None: """Response to enquire_link""" ler = smpp.make_pdu('enquire_link_resp', client=self) ler.sequence = pdu.sequence self.send_pdu(ler) - def _alert_notification(self, pdu): + def _alert_notification(self, pdu: pdu.PDU) -> None: """Handler for alert notification event""" self.message_received_handler(pdu=pdu) - def set_message_received_handler(self, func): + def set_message_received_handler(self, func: Callable[..., Optional[int]]) -> None: """Set new function to handle message receive event""" - self.message_received_handler = func + self.message_received_handler = func # type: ignore - def set_message_sent_handler(self, func): + def set_message_sent_handler(self, func: Callable[..., None]) -> None: """Set new function to handle message sent event""" - self.message_sent_handler = func + self.message_sent_handler = func # type: ignore - def set_query_resp_handler(self, func): + def set_query_resp_handler(self, func: Callable[..., None]) -> None: """Set new function to handle query resp event""" - self.query_resp_handler = func + self.query_resp_handler = func # type: ignore - def set_error_pdu_handler(self, func): + def set_error_pdu_handler(self, func: Callable[..., None]) -> None: """Set new function to handle PDUs with an error status""" - self.error_pdu_handler = func + self.error_pdu_handler = func # type: ignore - def message_received_handler(self, pdu, **kwargs): + def message_received_handler(self, pdu: pdu.PDU, **kwargs: Any) -> Optional[int]: """Custom handler to process received message. May be overridden""" self.logger.warning('Message received handler (Override me)') + return None - def message_sent_handler(self, pdu, **kwargs): + def message_sent_handler(self, pdu: pdu.PDU, **kwargs: Any) -> None: """ Called when SMPP server accept message (SUBMIT_SM_RESP). May be overridden """ self.logger.warning('Message sent handler (Override me)') - def query_resp_handler(self, pdu, **kwargs): + def query_resp_handler(self, pdu: pdu.PDU, **kwargs: Any) -> None: """Custom handler to process response to queries. May be overridden""" self.logger.warning('Query resp handler (Override me)') - def error_pdu_handler(self, pdu): + def error_pdu_handler(self, pdu: pdu.PDU) -> NoReturn: raise exceptions.PDUError('({}) {}: {}'.format( pdu.status, pdu.command, + 'Unknown status' if pdu.status is None else consts.DESCRIPTIONS.get(pdu.status, 'Unknown status')), - int(pdu.status), + pdu.status, ) - def read_once(self, ignore_error_codes=None, auto_send_enquire_link=True): + def read_once(self, ignore_error_codes: Optional[List[int]]=None, auto_send_enquire_link: bool=True) -> None: """Read a PDU and act""" if ignore_error_codes is not None: @@ -378,7 +386,7 @@ def read_once(self, ignore_error_codes=None, auto_send_enquire_link=True): else: raise - def poll(self, ignore_error_codes=None, auto_send_enquire_link=True): + def poll(self, ignore_error_codes: Optional[List[int]]=None, auto_send_enquire_link: bool=True) -> None: """Act on available PDUs and return""" while True: readable, _writable, _exceptional = select.select([self._socket], [], [], 0) @@ -386,12 +394,12 @@ def poll(self, ignore_error_codes=None, auto_send_enquire_link=True): break self.read_once(ignore_error_codes, auto_send_enquire_link) - def listen(self, ignore_error_codes=None, auto_send_enquire_link=True): + def listen(self, ignore_error_codes: Optional[List[int]]=None, auto_send_enquire_link: bool=True) -> None: """Listen for PDUs and act""" while True: self.read_once(ignore_error_codes, auto_send_enquire_link) - def send_message(self, **kwargs): + def send_message(self, **kwargs: Any) -> Any: """Send message Required Arguments: @@ -406,7 +414,7 @@ def send_message(self, **kwargs): self.send_pdu(ssm) return ssm - def query_message(self, **kwargs): + def query_message(self, **kwargs: Any) -> Any: """Query message state Required Arguments: diff --git a/smpplib/command.py b/smpplib/command.py index 5bc6379..5478bcf 100644 --- a/smpplib/command.py +++ b/smpplib/command.py @@ -21,16 +21,16 @@ import logging import struct - -import six +from typing import Dict, Tuple, Any, Optional, Callable from smpplib import consts, exceptions, pdu from smpplib.ptypes import flag, ostr +from typing import NoReturn, TypeVar logger = logging.getLogger('smpplib.command') -def factory(command_name, **kwargs): +def factory(command_name: str, **kwargs: Any) -> pdu.PDU: """Return instance of a specific command class""" try: @@ -60,18 +60,18 @@ def factory(command_name, **kwargs): raise exceptions.UnknownCommandError('Command "%s" is not supported' % command_name) -def get_optional_name(code): +def get_optional_name(code: int) -> str: """Return optional_params name by given code. If code is unknown, raise UnkownCommandError exception""" - for key, value in six.iteritems(consts.OPTIONAL_PARAMS): + for key, value in consts.OPTIONAL_PARAMS.items(): if value == code: return key raise exceptions.UnknownCommandError('Unknown SMPP command code "0x%x"' % code) -def get_optional_code(name): +def get_optional_code(name: str) -> int: """Return optional_params code by given command name. If name is unknown, raise UnknownCommandError exception""" @@ -81,16 +81,17 @@ def get_optional_code(name): raise exceptions.UnknownCommandError('Unknown SMPP command name "%s"' % name) -def unpack_short(data, pos): +def unpack_short(data: bytes, pos: int) -> Tuple[int, int]: return struct.unpack('>H', data[pos:pos+2])[0], pos + 2 class Command(pdu.PDU): """SMPP PDU Command class""" - params = {} + params: Dict[str, "Param"] = {} + params_order: Tuple[str, ...] - def __init__(self, command, need_sequence=True, allow_unknown_opt_params=False, **kwargs): + def __init__(self, command: str, need_sequence: bool=True, allow_unknown_opt_params: bool=False, **kwargs: Any) -> None: super(Command, self).__init__(**kwargs) self.allow_unknown_opt_params = allow_unknown_opt_params @@ -104,17 +105,17 @@ def __init__(self, command, need_sequence=True, allow_unknown_opt_params=False, self._set_vars(**kwargs) - def _set_vars(self, **kwargs): + def _set_vars(self, **kwargs: Any) -> None: """set attributes accordingly to kwargs""" - for key, value in six.iteritems(kwargs): + for key, value in kwargs.items(): if not hasattr(self, key) or getattr(self, key) is None: setattr(self, key, value) - def generate_params(self): + def generate_params(self) -> bytes: """Generate binary data from the object""" - if hasattr(self, 'prep') and callable(self.prep): - self.prep() + if hasattr(self, 'prep') and callable(self.prep): # type: ignore + self.prep() # type: ignore body = consts.EMPTY_STRING @@ -146,87 +147,93 @@ def generate_params(self): body += value return body - def _generate_opt_header(self, field): + def _generate_opt_header(self, field: str) -> NoReturn: """Generate a header for an optional parameter""" raise NotImplementedError('Vendors not supported') - def _generate_int(self, field): + def _generate_int(self, field: str) -> bytes: """Generate integer value""" fmt = self._int_pack_format(field) - data = getattr(self, field) + data: int = getattr(self, field) if data: return struct.pack(">" + fmt, data) else: return consts.NULL_STRING - def _generate_string(self, field): + def _generate_string(self, field: str) -> bytes: """Generate string value""" - field_value = getattr(self, field) + field_value: bytes = getattr(self, field) if hasattr(self.params[field], 'size'): - size = self.params[field].size - value = field_value.ljust(size, chr(0)) + size = self.params[field].size # type: ignore + value = field_value.ljust(size, consts.NULL_STRING) elif hasattr(self.params[field], 'max'): - if len(field_value or '') >= self.params[field].max: - field_value = field_value[0:self.params[field].max - 1] + if len(field_value or '') >= self.params[field].max: # type: ignore + field_value = field_value[0:self.params[field].max - 1] # type: ignore if field_value: - value = field_value + chr(0) + value = field_value + consts.NULL_STRING else: - value = chr(0) + value = consts.NULL_STRING + else: + assert False, "Param must have either size or max." setattr(self, field, field_value) - return six.b(value) + return value - def _generate_ostring(self, field): + def _generate_ostring(self, field: str) -> Optional[bytes]: """Generate octet string value (no null terminator)""" - value = getattr(self, field) + value: bytes = getattr(self, field) if value: return value else: - return None # chr(0) + return None # consts.NULL_STRING - def _generate_int_tlv(self, field): + def _generate_int_tlv(self, field: str) -> Optional[bytes]: """Generate integer value""" fmt = self._int_pack_format(field) - data = getattr(self, field) + data: int = getattr(self, field) field_code = get_optional_code(field) - field_length = self.params[field].size + field_length = self.params[field].size # type: ignore value = None if data is not None: value = struct.pack(">HH" + fmt, field_code, field_length, data) return value - def _generate_string_tlv(self, field): + def _generate_string_tlv(self, field: str) -> Optional[bytes]: """Generate string value""" - field_value = getattr(self, field) + field_value: bytes = getattr(self, field) field_code = get_optional_code(field) + value: Optional[bytes] if hasattr(self.params[field], 'size'): - size = self.params[field].size - fvalue = field_value.ljust(size, chr(0)) + size = self.params[field].size # type: ignore + fvalue = field_value.ljust(size, consts.NULL_STRING) value = struct.pack(">HH", field_code, size) + fvalue elif hasattr(self.params[field], 'max'): - if len(field_value or '') > self.params[field].max: - field_value = field_value[0:self.params[field].max - 1] + if len(field_value or '') > self.params[field].max: # type: ignore + field_value = field_value[0:self.params[field].max - 1] # type: ignore if field_value: - fvalue = field_value + chr(0) + fvalue = field_value + consts.NULL_STRING field_length = len(fvalue) - value = struct.pack(">HH", field_code, field_length) + fvalue.encode() + value = struct.pack(">HH", field_code, field_length) + fvalue else: - value = None # chr(0) + value = None # consts.NULL_STRING + else: + assert False, "Param must have either size or max." + return value - def _generate_ostring_tlv(self, field): + def _generate_ostring_tlv(self, field: str) -> Optional[bytes]: """Generate octet string value (no null terminator)""" try: - field_value = getattr(self, field) + field_value: bytes = getattr(self, field) except: return None field_code = get_optional_code(field) @@ -237,17 +244,17 @@ def _generate_ostring_tlv(self, field): value = struct.pack(">HH", field_code, field_length) + field_value return value - def _int_pack_format(self, field): + def _int_pack_format(self, field: str) -> str: """Return format type""" - return consts.INT_PACK_FORMATS[self.params[field].size] + return consts.INT_PACK_FORMATS[self.params[field].size] # type: ignore - def _parse_int(self, field, data, pos): + def _parse_int(self, field: str, data: bytes, pos: int) -> Tuple[bytes, int]: """ Parse fixed-length chunk from a PDU. Return (data, pos) tuple. """ - size = self.params[field].size + size = self.params[field].size # type: ignore fmt = self._int_pack_format(field) field_value, = struct.unpack(">" + fmt, data[pos:pos + size]) setattr(self, field, field_value) @@ -255,7 +262,7 @@ def _parse_int(self, field, data, pos): return data, pos - def _parse_string(self, field, data, pos, length=None): + def _parse_string(self, field: str, data: bytes, pos: int, length: Optional[int]=None) -> Tuple[bytes, int]: """ Parse variable-length string from a PDU. Return (data, pos) tuple. @@ -272,14 +279,14 @@ def _parse_string(self, field, data, pos, length=None): return data, pos - def _parse_ostring(self, field, data, pos, length=None): + def _parse_ostring(self, field: str, data: bytes, pos: int, length: Optional[int]=None) -> Tuple[bytes, int]: """ Parse an octet string from a PDU. Return (data, pos) tuple. """ if length is None: - length_field = self.params[field].len_field + length_field = self.params[field].len_field # type: ignore length = int(getattr(self, length_field)) setattr(self, field, data[pos:pos + length]) @@ -287,14 +294,14 @@ def _parse_ostring(self, field, data, pos, length=None): return data, pos - def is_fixed(self, field): + def is_fixed(self, field: str) -> bool: """Return True if field has fixed length, False otherwise""" if hasattr(self.params[field], 'size'): return True return False - def parse_params(self, data): + def parse_params(self, data: bytes) -> None: """Parse data into the object structure""" pos = 0 @@ -314,7 +321,7 @@ def parse_params(self, data): if pos < dlen: self.parse_optional_params(data[pos:]) - def parse_optional_params(self, data): + def parse_optional_params(self, data: bytes) -> None: """Parse optional parameters. Optional parameters have the following format: @@ -346,14 +353,14 @@ def parse_optional_params(self, data): elif param.type is ostr: data, pos = self._parse_ostring(field, data, pos, length) - def field_exists(self, field): + def field_exists(self, field: str) -> bool: """Return True if field exists, False otherwise""" return hasattr(self.params, field) - def field_is_optional(self, field): + def field_is_optional(self, field: str) -> bool: """Return True if field is optional, False otherwise""" - if hasattr(self, 'mandatory_fields') and field in self.mandatory_fields: + if hasattr(self, 'mandatory_fields') and field in self.mandatory_fields: # type: ignore return False elif field in consts.OPTIONAL_PARAMS: return True @@ -363,11 +370,16 @@ def field_is_optional(self, field): return False + def __repr__(self) -> str: + args = ', '.join(p + ":" + str(getattr(self, p)) for p in self.params_order) + return f'<{self.command} {args}>' + + class Param(object): """Command parameter info class""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: if 'type' not in kwargs: raise KeyError('Parameter Type not defined') @@ -385,7 +397,7 @@ def __init__(self, **kwargs): if param in kwargs: setattr(self, param, kwargs[param]) - def __repr__(self): + def __repr__(self) -> str: """Shows type of Param in console""" return ''.join(('')) @@ -409,7 +421,7 @@ class BindTransmitter(Command): 'interface_version', 'addr_ton', 'addr_npi', 'address_range', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(BindTransmitter, self).__init__(command, need_sequence=False, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) @@ -418,13 +430,13 @@ def __init__(self, command, **kwargs): class BindReceiver(BindTransmitter): """Bind as a receiver command""" - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(BindReceiver, self).__init__(command, **kwargs) class BindTransceiver(BindTransmitter): """Bind as receiver and transmitter command""" - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(BindTransceiver, self).__init__(command, **kwargs) @@ -438,7 +450,7 @@ class BindTransmitterResp(Command): params_order = ('system_id', 'sc_interface_version') - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(BindTransmitterResp, self).__init__(command, need_sequence=False, **kwargs) @@ -447,13 +459,13 @@ def __init__(self, command, **kwargs): class BindReceiverResp(BindTransmitterResp): """Response for bind as a reciever command""" - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(BindReceiverResp, self).__init__(command, **kwargs) class BindTransceiverResp(BindTransmitterResp): """Response for bind as a transceiver command""" - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(BindTransceiverResp, self).__init__(command, **kwargs) @@ -533,7 +545,7 @@ class DataSM(Command): 'its_session_info', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(DataSM, self).__init__(command, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) @@ -559,7 +571,7 @@ class DataSMResp(Command): 'dpf_result', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(DataSMResp, self).__init__(command, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) @@ -567,9 +579,10 @@ def __init__(self, command, **kwargs): class GenericNAck(Command): """General Negative Acknowledgement class""" - _defs = [] + # TODO: seems unused + _defs: Any = [] - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(GenericNAck, self).__init__(command, need_sequence=False, **kwargs) @@ -697,7 +710,7 @@ class SubmitSM(Command): 'ussd_service_op': Param(type=int, size=1), } - params_order = ( + params_order: Tuple[str, ...] = ( 'service_type', 'source_addr_ton', 'source_addr_npi', 'source_addr', 'dest_addr_ton', 'dest_addr_npi', 'destination_addr', 'esm_class', 'protocol_id', 'priority_flag', @@ -718,11 +731,11 @@ class SubmitSM(Command): 'ussd_service_op', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(SubmitSM, self).__init__(command, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) - def prep(self): + def prep(self) -> None: """Prepare to generate binary data""" if self.short_message: @@ -742,7 +755,7 @@ class SubmitSMResp(Command): params_order = ('message_id',) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(SubmitSMResp, self).__init__(command, need_sequence=False, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) @@ -814,7 +827,7 @@ class DeliverSM(SubmitSM): 'source_network_type', 'dest_network_type', 'more_messages_to_send', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(DeliverSM, self).__init__(command, need_sequence=False, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) @@ -823,7 +836,7 @@ class DeliverSMResp(SubmitSMResp): """deliver_sm_response response class, same as submit_sm""" message_id = None - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(DeliverSMResp, self).__init__(command, **kwargs) class QuerySM(Command): @@ -858,11 +871,11 @@ class QuerySM(Command): 'source_addr', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(QuerySM, self).__init__(command, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) - def prep(self): + def prep(self) -> None: """Prepare to generate binary data""" if not self.message_id: @@ -872,6 +885,7 @@ def prep(self): class QuerySMResp(Command): """Response command for query_sm""" + # TODO: this seems like a bug, misssing a , to make it a tuple mandatory_fields = ('message_state') params = { @@ -886,7 +900,7 @@ class QuerySMResp(Command): 'error_code', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(QuerySMResp, self).__init__(command, need_sequence=False, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) @@ -894,38 +908,38 @@ def __init__(self, command, **kwargs): class Unbind(Command): """Unbind command""" - params = {} + params: Dict[str, Param] = {} params_order = () - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(Unbind, self).__init__(command, need_sequence=False, **kwargs) class UnbindResp(Command): """Unbind response command""" - params = {} + params: Dict[str, Param] = {} params_order = () - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(UnbindResp, self).__init__(command, need_sequence=False, **kwargs) class EnquireLink(Command): """Enquire link command""" - params = {} + params: Dict[str, Param] = {} params_order = () - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(EnquireLink, self).__init__(command, need_sequence=False, **kwargs) class EnquireLinkResp(Command): """Enquire link command response""" - params = {} + params: Dict[str, Param] = {} params_order = () - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(EnquireLinkResp, self).__init__(command, need_sequence=False, **kwargs) @@ -973,6 +987,6 @@ class AlertNotification(Command): 'ms_availability_status', ) - def __init__(self, command, **kwargs): + def __init__(self, command: str, **kwargs: Any) -> None: super(AlertNotification, self).__init__(command, **kwargs) self._set_vars(**(dict.fromkeys(self.params))) diff --git a/smpplib/command_codes.py b/smpplib/command_codes.py index eab6868..677f01b 100644 --- a/smpplib/command_codes.py +++ b/smpplib/command_codes.py @@ -1,5 +1,3 @@ -import six - from smpplib import exceptions # @@ -36,20 +34,20 @@ } -def get_command_name(code): +def get_command_name(code: int) -> str: """ Return command name by given code. If code is unknown, raise UnknownCommandError exception. """ - for key, value in six.iteritems(commands): + for key, value in commands.items(): if value == code: return key raise exceptions.UnknownCommandError("Unknown SMPP command code '0x%x'" % code) -def get_command_code(name): +def get_command_code(name: str) -> int: """ Return command code by given command name. If name is unknown, raise UnknownCommandError exception. diff --git a/smpplib/gsm.py b/smpplib/gsm.py index c3336a4..37e3057 100644 --- a/smpplib/gsm.py +++ b/smpplib/gsm.py @@ -1,12 +1,11 @@ -# -*- coding: utf8 -*- +# -*- coding: utf-8 -*- import random -import six - from smpplib import consts, exceptions +from typing import Any, List, Tuple, TypeVar, Union -def make_parts(text, encoding=consts.SMPP_ENCODING_DEFAULT, use_udhi=True): +def make_parts(text: str, encoding: int=consts.SMPP_ENCODING_DEFAULT, use_udhi: bool=True) -> Tuple[Any, int, int]: """Returns tuple(parts, encoding, esm_class)""" try: # Try to encode with the user-defined encoding first. @@ -51,11 +50,11 @@ def make_parts(text, encoding=consts.SMPP_ENCODING_DEFAULT, use_udhi=True): ) -def gsm_encode(plaintext): +def gsm_encode(plaintext: str) -> bytes: """Performs default GSM 7-bit encoding. Beware it's vendor-specific and not recommended for use.""" try: return b''.join( - six.int2byte(index) if index < 0x80 else b'\x1B' + six.int2byte(index - 0x80) + bytes((index, )) if index < 0x80 else b'\x1B' + bytes((index - 0x80, )) for index in map(GSM_CHARACTER_TABLE.index, plaintext) ) except ValueError: @@ -71,18 +70,18 @@ def gsm_encode(plaintext): } -def make_parts_encoded(encoded_text, part_size): +def make_parts_encoded(encoded_text: bytes, part_size: int) -> List[bytes]: """Splits encoded text into SMS parts""" chunks = split_sequence(encoded_text, part_size) if len(chunks) > 255: raise exceptions.MessageTooLong() uid = random.randint(0, 255) - header = b''.join((b'\x05\x00\x03', six.int2byte(uid), six.int2byte(len(chunks)))) + header = b''.join((b'\x05\x00\x03', bytes((uid, )), bytes((len(chunks), )))) - return [b''.join((header, six.int2byte(i), chunk)) for i, chunk in enumerate(chunks, start=1)] + return [b''.join((header, bytes((i, )), chunk)) for i, chunk in enumerate(chunks, start=1)] -def split_sequence(sequence, part_size): +def split_sequence(sequence: bytes, part_size: int) -> List[bytes]: """Splits the sequence into equal parts""" return [sequence[i:i + part_size] for i in range(0, len(sequence), part_size)] diff --git a/smpplib/pdu.py b/smpplib/pdu.py index 4557210..b3db084 100644 --- a/smpplib/pdu.py +++ b/smpplib/pdu.py @@ -19,12 +19,17 @@ """PDU module""" import struct +from typing import TYPE_CHECKING, Union from smpplib import command_codes, consts from smpplib.consts import SMPP_ESME_ROK +from typing import Any, Optional +if TYPE_CHECKING: + from smpplib.client import Client -def extract_command(pdu): + +def extract_command(pdu: Any) -> str: """Extract command from a PDU""" code = struct.unpack('>L', pdu[4:8])[0] @@ -36,62 +41,67 @@ class default_client(object): """Dummy client""" sequence = 0 + def next_sequence(self) -> int: + raise NotImplementedError() + class PDU(object): """PDU class""" length = 0 - command = None - status = None - _sequence = None + command: str + status: Optional[int] = None + _sequence: Optional[int]= None + _client: Union["Client", default_client] - def __init__(self, client=default_client(), **kwargs): + def __init__(self, client: Union["Client", default_client]=default_client(), **kwargs: Any) -> None: """Singleton dummy client will be used if omitted""" if client is None: self._client = default_client() else: self._client = client - def _get_sequence(self): + def _get_sequence(self) -> int: """Return global sequence number""" return self._sequence if self._sequence is not None else \ self._client.sequence - def _set_sequence(self, sequence): + def _set_sequence(self, sequence: int) -> None: """Setter for sequence""" self._sequence = sequence sequence = property(_get_sequence, _set_sequence) - def _next_seq(self): + def _next_seq(self) -> int: """Return next sequence number""" return self._client.next_sequence() - def is_vendor(self): + def is_vendor(self) -> bool: """Return True if this is a vendor PDU, False otherwise""" return hasattr(self, 'vendor') - def is_request(self): + def is_request(self) -> bool: """Return True if this is a request PDU, False otherwise""" return not self.is_response() - def is_response(self): + def is_response(self) -> bool: """Return True if this is a response PDU, False otherwise""" if command_codes.get_command_code(self.command) & 0x80000000: return True return False - def is_error(self): + def is_error(self) -> bool: """Return True if this is an error response, False otherwise""" if self.status != SMPP_ESME_ROK: return True return False - def get_status_desc(self, status=None): + def get_status_desc(self, status: Optional[int]=None) -> str: """Return status description""" if status is None: status = self.status + assert status is not None try: desc = consts.DESCRIPTIONS[status] @@ -100,7 +110,13 @@ def get_status_desc(self, status=None): return desc - def parse(self, data): + def parse_params(self, data: bytes) -> None: + raise NotImplementedError() + + def generate_params(self) -> bytes: + raise NotImplementedError() + + def parse(self, data: bytes) -> None: """Parse raw PDU""" # @@ -126,7 +142,7 @@ def parse(self, data): if len(data) > 16: self.parse_params(data[16:]) - def generate(self): + def generate(self) -> bytes: """Generate raw PDU""" body = self.generate_params() diff --git a/smpplib/smpp.py b/smpplib/smpp.py index c93b3f5..0ff81e3 100644 --- a/smpplib/smpp.py +++ b/smpplib/smpp.py @@ -19,9 +19,10 @@ """SMPP module""" from smpplib import command, pdu +from typing import Any -def make_pdu(command_name, **kwargs): +def make_pdu(command_name: str, **kwargs: Any) -> pdu.PDU: """Return PDU instance""" f = command.factory(command_name, **kwargs) @@ -29,7 +30,7 @@ def make_pdu(command_name, **kwargs): return f -def parse_pdu(data, **kwargs): +def parse_pdu(data: bytes, **kwargs: Any) -> pdu.PDU: """Parse binary PDU""" command = pdu.extract_command(data) diff --git a/smpplib/tests/test_client.py b/smpplib/tests/test_client.py index a10666f..1051a3c 100644 --- a/smpplib/tests/test_client.py +++ b/smpplib/tests/test_client.py @@ -10,6 +10,9 @@ def test_client_construction_allow_unknown_opt_params_warning(): with warnings.catch_warnings(record=True) as w: + # TODO: we should probably switch to assertWarns if we drop python 2 + # support entirely + warnings.simplefilter("always") client = Client("localhost", 5679) assert len(w) == 1 @@ -21,7 +24,7 @@ def test_client_error_pdu_default(): client = Client("localhost", 5679) error_pdu = make_pdu("submit_sm_resp") error_pdu.status = consts.SMPP_ESME_RINVMSGLEN - client.read_pdu = Mock(return_value=error_pdu) + client.read_pdu = Mock(return_value=error_pdu) # type: ignore with pytest.raises(exceptions.PDUError) as exec_info: client.read_once() @@ -36,7 +39,7 @@ def test_client_error_pdu_custom_handler(): client = Client("localhost", 5679) error_pdu = make_pdu("submit_sm_resp") error_pdu.status = consts.SMPP_ESME_RINVMSGLEN - client.read_pdu = Mock(return_value=error_pdu) + client.read_pdu = Mock(return_value=error_pdu) # type: ignore mock_error_pdu_handler = Mock() client.set_error_pdu_handler(mock_error_pdu_handler) diff --git a/smpplib/tests/test_command.py b/smpplib/tests/test_command.py index a0e2433..1cf2ae6 100644 --- a/smpplib/tests/test_command.py +++ b/smpplib/tests/test_command.py @@ -1,12 +1,43 @@ from smpplib import consts, exceptions -from smpplib.command import DeliverSM +from smpplib.command import DeliverSM, SubmitSM, SubmitSMResp import pytest +from mock import Mock +def test_parse_submit_sm(): + # Example from smpp.org + raw = bytes.fromhex( + "000000480000000400000000000000020005004d656c726f73654c61627300" + "01013434373731323334353637380000000000000100000010" + "48656c6c6f20576f726c64201b650201" + ) + pdu = SubmitSM('submit_sm', client=Mock()) + pdu.parse(raw) + + assert pdu.source_addr == b'MelroseLabs' + assert pdu.destination_addr == b'447712345678' + assert pdu.data_coding == consts.SMPP_ENCODING_DEFAULT + assert pdu.short_message == b'Hello World \x1be\x02\x01' + + assert pdu.generate() == raw + + +def test_parse_submit_sm_resp(): + # Another example from smpp.org + raw = bytes.fromhex( + "00000051800000040000000000000002" + "30393537326130613039626337336632653930653933386263366561386361326463663" + "06364343562343039383165343632396638343035353534376561333100" + ) + pdu = SubmitSMResp('submit_sm_resp', client=Mock()) + pdu.parse(raw) + + assert pdu.message_id == b'09572a0a09bc73f2e90e938bc6ea8ca2dcf0cd45b40981e4629f84055547ea31' # type: ignore + + assert pdu.generate() == raw def test_parse_deliver_sm(): - pdu = DeliverSM('deliver_sm') - pdu.parse( + raw = ( b"\x00\x00\x00\xcb\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x01\x00" b"\x01\x0131600000000\x00\x05\x00XXX YYYY\x00\x04\x00\x00\x00\x00\x00" b"\x00\x00\x00\x00\x00\x0e\x00\x01\x01\x00\x06\x00\x01\x01\x00\x1e\x00" @@ -14,16 +45,21 @@ def test_parse_deliver_sm(): b" dlvrd:001 submit date:1810151907 done date:1810151907 stat:DELIVRD" b" err:000 text:\x04\x1f\x04@\x048\x042\x045\x04B\x04&\x00\x01\x01" ) + pdu = DeliverSM('deliver_sm') + pdu.parse(raw) assert pdu.source_addr_ton == consts.SMPP_TON_INTL assert pdu.source_addr_npi == consts.SMPP_NPI_ISDN assert pdu.source_addr == b'31600000000' assert pdu.destination_addr == b'XXX YYYY' - assert pdu.receipted_message_id == b'1d305b4c' - assert pdu.source_network_type == consts.SMPP_NETWORK_TYPE_GSM - assert pdu.message_state == consts.SMPP_MESSAGE_STATE_DELIVERED - assert pdu.user_message_reference is None + assert pdu.receipted_message_id == b'1d305b4c' # type: ignore + assert pdu.source_network_type == consts.SMPP_NETWORK_TYPE_GSM # type: ignore + assert pdu.message_state == consts.SMPP_MESSAGE_STATE_DELIVERED # type: ignore + assert pdu.user_message_reference is None # type: ignore + # TODO: not sure why this doesn't re-generate the raw input, but it seems + # worth having this test anyway. + assert pdu.generate() == b"\x00\x00\x00\xcb\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x0131600000000\x00\x05\x00XXX YYYY\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04$\x00rid:0489708364 sub:001 dlvrd:001 submit date:1810151907 done date:1810151907 stat:DELIVRD err:000 text:\x04\x1f\x04@\x048\x042\x045\x04B\x04'\x00\x01\x02\x00\x1e\x00\t1d305b4c\x00\x00\x0e\x00\x01\x01\x00\x06\x00\x01\x01\x04&\x00\x01\x01" def test_unrecognised_optional_parameters(): pdu = DeliverSM("deliver_sm", allow_unknown_opt_params=True) diff --git a/smpplib/tests/test_gsm.py b/smpplib/tests/test_gsm.py index b6641fd..44b8c66 100644 --- a/smpplib/tests/test_gsm.py +++ b/smpplib/tests/test_gsm.py @@ -1,4 +1,4 @@ -# -*- coding: utf8 -*- +# -*- coding: utf-8 -*- import mock from pytest import mark, raises