Skip to content

Commit 522811f

Browse files
authored
feat: use literal in type hints (docarray#1827)
Signed-off-by: Ben Shaver <[email protected]>
1 parent d5d928b commit 522811f

5 files changed

Lines changed: 41 additions & 36 deletions

File tree

docarray/array/doc_list/io.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
_dict_to_access_paths,
3737
)
3838
from docarray.utils._internal.compress import _decompress_bytes, _get_compress_ctx
39-
from docarray.utils._internal.misc import import_library
39+
from docarray.utils._internal.misc import import_library, ProtocolType
4040

4141
if TYPE_CHECKING:
4242
import pandas as pd
@@ -57,9 +57,9 @@
5757

5858
def _protocol_and_compress_from_file_path(
5959
file_path: Union[pathlib.Path, str],
60-
default_protocol: Optional[str] = None,
60+
default_protocol: Optional[ProtocolType] = None,
6161
default_compress: Optional[str] = None,
62-
) -> Tuple[Optional[str], Optional[str]]:
62+
) -> Tuple[Optional[ProtocolType], Optional[str]]:
6363
"""Extract protocol and compression algorithm from a string, use defaults if not found.
6464
:param file_path: path of a file.
6565
:param default_protocol: default serialization protocol used in case not found.
@@ -79,7 +79,7 @@ def _protocol_and_compress_from_file_path(
7979
file_extensions = [e.replace('.', '') for e in pathlib.Path(file_path).suffixes]
8080
for extension in file_extensions:
8181
if extension in ALLOWED_PROTOCOLS:
82-
protocol = extension
82+
protocol = cast(ProtocolType, extension)
8383
elif extension in ALLOWED_COMPRESSIONS:
8484
compress = extension
8585

@@ -135,7 +135,7 @@ def to_protobuf(self) -> 'DocListProto':
135135
def from_bytes(
136136
cls: Type[T],
137137
data: bytes,
138-
protocol: str = 'protobuf-array',
138+
protocol: ProtocolType = 'protobuf-array',
139139
compress: Optional[str] = None,
140140
show_progress: bool = False,
141141
) -> T:
@@ -157,7 +157,7 @@ def from_bytes(
157157
def _write_bytes(
158158
self,
159159
bf: BinaryIO,
160-
protocol: str = 'protobuf-array',
160+
protocol: ProtocolType = 'protobuf-array',
161161
compress: Optional[str] = None,
162162
show_progress: bool = False,
163163
) -> None:
@@ -201,7 +201,7 @@ def _write_bytes(
201201

202202
def _to_binary_stream(
203203
self,
204-
protocol: str = 'protobuf',
204+
protocol: ProtocolType = 'protobuf',
205205
compress: Optional[str] = None,
206206
show_progress: bool = False,
207207
) -> Iterator[bytes]:
@@ -241,7 +241,7 @@ def _to_binary_stream(
241241

242242
def to_bytes(
243243
self,
244-
protocol: str = 'protobuf-array',
244+
protocol: ProtocolType = 'protobuf-array',
245245
compress: Optional[str] = None,
246246
file_ctx: Optional[BinaryIO] = None,
247247
show_progress: bool = False,
@@ -273,7 +273,7 @@ def to_bytes(
273273
def from_base64(
274274
cls: Type[T],
275275
data: str,
276-
protocol: str = 'protobuf-array',
276+
protocol: ProtocolType = 'protobuf-array',
277277
compress: Optional[str] = None,
278278
show_progress: bool = False,
279279
) -> T:
@@ -294,7 +294,7 @@ def from_base64(
294294

295295
def to_base64(
296296
self,
297-
protocol: str = 'protobuf-array',
297+
protocol: ProtocolType = 'protobuf-array',
298298
compress: Optional[str] = None,
299299
show_progress: bool = False,
300300
) -> str:
@@ -383,7 +383,6 @@ def _from_csv_file(
383383
file: Union[StringIO, TextIOWrapper],
384384
dialect: Union[str, csv.Dialect],
385385
) -> 'T':
386-
387386
rows = csv.DictReader(file, dialect=dialect)
388387

389388
doc_type = cls.doc_type
@@ -576,7 +575,7 @@ def _get_proto_class(cls: Type[T]):
576575
def _load_binary_all(
577576
cls: Type[T],
578577
file_ctx: Union[ContextManager[io.BufferedReader], ContextManager[bytes]],
579-
protocol: Optional[str],
578+
protocol: Optional[ProtocolType],
580579
compress: Optional[str],
581580
show_progress: bool,
582581
tensor_type: Optional[Type['AbstractTensor']] = None,
@@ -659,7 +658,9 @@ def _load_binary_all(
659658
start_pos = end_doc_pos
660659

661660
# variable length bytes doc
662-
load_protocol: str = protocol or 'protobuf'
661+
load_protocol: ProtocolType = protocol or cast(
662+
ProtocolType, 'protobuf'
663+
)
663664
doc = cls.doc_type.from_bytes(
664665
d[start_doc_pos:end_doc_pos],
665666
protocol=load_protocol,
@@ -680,7 +681,7 @@ def _load_binary_all(
680681
def _load_binary_stream(
681682
cls: Type[T],
682683
file_ctx: ContextManager[io.BufferedReader],
683-
protocol: str = 'protobuf',
684+
protocol: ProtocolType = 'protobuf',
684685
compress: Optional[str] = None,
685686
show_progress: bool = False,
686687
) -> Generator['T_doc', None, None]:
@@ -728,7 +729,7 @@ def _load_binary_stream(
728729
len_current_doc_in_bytes = int.from_bytes(
729730
f.read(4), 'big', signed=False
730731
)
731-
load_protocol: str = protocol
732+
load_protocol: ProtocolType = protocol
732733
yield cls.doc_type.from_bytes(
733734
f.read(len_current_doc_in_bytes),
734735
protocol=load_protocol,
@@ -743,10 +744,12 @@ def _load_binary_stream(
743744
@staticmethod
744745
def _get_file_context(
745746
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
746-
protocol: str,
747+
protocol: ProtocolType,
747748
compress: Optional[str] = None,
748-
) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]:
749-
load_protocol: Optional[str] = protocol
749+
) -> Tuple[
750+
Union[nullcontext, io.BufferedReader], Optional[ProtocolType], Optional[str]
751+
]:
752+
load_protocol: Optional[ProtocolType] = protocol
750753
load_compress: Optional[str] = compress
751754
file_ctx: Union[nullcontext, io.BufferedReader]
752755
if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
@@ -765,7 +768,7 @@ def _get_file_context(
765768
def load_binary(
766769
cls: Type[T],
767770
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
768-
protocol: str = 'protobuf-array',
771+
protocol: ProtocolType = 'protobuf-array',
769772
compress: Optional[str] = None,
770773
show_progress: bool = False,
771774
streaming: bool = False,
@@ -814,7 +817,7 @@ def load_binary(
814817
def save_binary(
815818
self,
816819
file: Union[str, pathlib.Path],
817-
protocol: str = 'protobuf-array',
820+
protocol: ProtocolType = 'protobuf-array',
818821
compress: Optional[str] = None,
819822
show_progress: bool = False,
820823
) -> None:

docarray/array/doc_vec/io.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from docarray.typing import NdArray
3232
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3333
from docarray.utils._internal.pydantic import is_pydantic_v2
34+
from docarray.utils._internal.misc import ProtocolType
3435

3536
if TYPE_CHECKING:
3637
import csv
@@ -134,7 +135,6 @@ def _from_json_col_dict(
134135
json_columns: Dict[str, Any],
135136
tensor_type: Type[AbstractTensor] = NdArray,
136137
) -> T:
137-
138138
tensor_cols = json_columns['tensor_columns']
139139
doc_cols = json_columns['doc_columns']
140140
docs_vec_cols = json_columns['docs_vec_columns']
@@ -351,7 +351,7 @@ def from_csv(
351351
def from_base64(
352352
cls: Type[T],
353353
data: str,
354-
protocol: str = 'protobuf-array',
354+
protocol: ProtocolType = 'protobuf-array',
355355
compress: Optional[str] = None,
356356
show_progress: bool = False,
357357
tensor_type: Type['AbstractTensor'] = NdArray,
@@ -377,7 +377,7 @@ def from_base64(
377377
def from_bytes(
378378
cls: Type[T],
379379
data: bytes,
380-
protocol: str = 'protobuf-array',
380+
protocol: ProtocolType = 'protobuf-array',
381381
compress: Optional[str] = None,
382382
show_progress: bool = False,
383383
tensor_type: Type['AbstractTensor'] = NdArray,
@@ -454,7 +454,7 @@ class Person(BaseDoc):
454454
def load_binary(
455455
cls: Type[T],
456456
file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
457-
protocol: str = 'protobuf-array',
457+
protocol: ProtocolType = 'protobuf-array',
458458
compress: Optional[str] = None,
459459
show_progress: bool = False,
460460
streaming: bool = False,

docarray/base_doc/mixins/io.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
2727
from docarray.utils._internal._typing import safe_issubclass
2828
from docarray.utils._internal.compress import _compress_bytes, _decompress_bytes
29-
from docarray.utils._internal.misc import import_library
29+
from docarray.utils._internal.misc import ProtocolType, import_library
3030
from docarray.utils._internal.pydantic import is_pydantic_v2
3131

3232
if TYPE_CHECKING:
@@ -37,7 +37,6 @@
3737
from docarray.proto import DocProto, NodeProto
3838
from docarray.typing import TensorFlowTensor, TorchTensor
3939

40-
4140
else:
4241
tf = import_library('tensorflow', raise_error=False)
4342
if tf is not None:
@@ -150,7 +149,7 @@ def __bytes__(self) -> bytes:
150149
return self.to_bytes()
151150

152151
def to_bytes(
153-
self, protocol: str = 'protobuf', compress: Optional[str] = None
152+
self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
154153
) -> bytes:
155154
"""Serialize itself into bytes.
156155
@@ -177,7 +176,7 @@ def to_bytes(
177176
def from_bytes(
178177
cls: Type[T],
179178
data: bytes,
180-
protocol: str = 'protobuf',
179+
protocol: ProtocolType = 'protobuf',
181180
compress: Optional[str] = None,
182181
) -> T:
183182
"""Build Document object from binary bytes
@@ -203,7 +202,7 @@ def from_bytes(
203202
)
204203

205204
def to_base64(
206-
self, protocol: str = 'protobuf', compress: Optional[str] = None
205+
self, protocol: ProtocolType = 'protobuf', compress: Optional[str] = None
207206
) -> str:
208207
"""Serialize a Document object into a base64 string
209208
@@ -329,7 +328,6 @@ def _get_content_from_node_proto(
329328
return_field = getattr(value, content_key)
330329

331330
elif content_key in arg_to_container.keys():
332-
333331
if field_name and field_name in cls._docarray_fields():
334332
field_type = cls._get_field_inner_type(field_name)
335333
else:
@@ -347,7 +345,6 @@ def _get_content_from_node_proto(
347345
deser_dict: Dict[str, Any] = dict()
348346

349347
if field_name and field_name in cls._docarray_fields():
350-
351348
if is_pydantic_v2:
352349
dict_args = get_args(
353350
cls._docarray_fields()[field_name].annotation

docarray/store/helpers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from rich import filesize
77
from typing_extensions import TYPE_CHECKING, Protocol
88

9+
from docarray.utils._internal.misc import ProtocolType
910
from docarray.utils._internal.progress_bar import _get_progressbar
1011

1112
if TYPE_CHECKING:
@@ -112,12 +113,12 @@ def raise_req_error(resp: 'requests.Response') -> NoReturn:
112113
class Streamable(Protocol):
113114
"""A protocol for streamable objects."""
114115

115-
def to_bytes(self, protocol: str, compress: Optional[str]) -> bytes:
116+
def to_bytes(self, protocol: ProtocolType, compress: Optional[str]) -> bytes:
116117
...
117118

118119
@classmethod
119120
def from_bytes(
120-
cls: Type[T_Elem], bytes: bytes, protocol: str, compress: Optional[str]
121+
cls: Type[T_Elem], bytes: bytes, protocol: ProtocolType, compress: Optional[str]
121122
) -> 'T_Elem':
122123
...
123124

@@ -133,7 +134,7 @@ def close(self):
133134
def _to_binary_stream(
134135
iterator: Iterator['Streamable'],
135136
total: Optional[int] = None,
136-
protocol: str = 'protobuf',
137+
protocol: ProtocolType = 'protobuf',
137138
compress: Optional[str] = None,
138139
show_progress: bool = False,
139140
) -> Iterator[bytes]:
@@ -170,7 +171,7 @@ def _from_binary_stream(
170171
cls: Type[T],
171172
stream: ReadableBytes,
172173
total: Optional[int] = None,
173-
protocol: str = 'protobuf',
174+
protocol: ProtocolType = 'protobuf',
174175
compress: Optional[str] = None,
175176
show_progress: bool = False,
176177
) -> Iterator['T']:

docarray/utils/_internal/misc.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import re
44
import types
5-
from typing import Any, Optional
5+
from typing import Any, Optional, Literal
66

77
import numpy as np
88

@@ -52,6 +52,10 @@
5252
'pymilvus': '"docarray[milvus]"',
5353
}
5454

55+
ProtocolType = Literal[
56+
'protobuf', 'pickle', 'json', 'json-array', 'protobuf-array', 'pickle-array'
57+
]
58+
5559

5660
def import_library(
5761
package: str, raise_error: bool = True

0 commit comments

Comments
 (0)