3636 _dict_to_access_paths ,
3737)
3838from docarray .utils ._internal .compress import _decompress_bytes , _get_compress_ctx
39- from docarray .utils ._internal .misc import import_library
39+ from docarray .utils ._internal .misc import import_library , ProtocolType
4040
4141if TYPE_CHECKING :
4242 import pandas as pd
5757
5858def _protocol_and_compress_from_file_path (
5959 file_path : Union [pathlib .Path , str ],
60- default_protocol : Optional [str ] = None ,
60+ default_protocol : Optional [ProtocolType ] = None ,
6161 default_compress : Optional [str ] = None ,
62- ) -> Tuple [Optional [str ], Optional [str ]]:
62+ ) -> Tuple [Optional [ProtocolType ], Optional [str ]]:
6363 """Extract protocol and compression algorithm from a string, use defaults if not found.
6464 :param file_path: path of a file.
6565 :param default_protocol: default serialization protocol used in case not found.
@@ -79,7 +79,7 @@ def _protocol_and_compress_from_file_path(
7979 file_extensions = [e .replace ('.' , '' ) for e in pathlib .Path (file_path ).suffixes ]
8080 for extension in file_extensions :
8181 if extension in ALLOWED_PROTOCOLS :
82- protocol = extension
82+ protocol = cast ( ProtocolType , extension )
8383 elif extension in ALLOWED_COMPRESSIONS :
8484 compress = extension
8585
@@ -135,7 +135,7 @@ def to_protobuf(self) -> 'DocListProto':
135135 def from_bytes (
136136 cls : Type [T ],
137137 data : bytes ,
138- protocol : str = 'protobuf-array' ,
138+ protocol : ProtocolType = 'protobuf-array' ,
139139 compress : Optional [str ] = None ,
140140 show_progress : bool = False ,
141141 ) -> T :
@@ -157,7 +157,7 @@ def from_bytes(
157157 def _write_bytes (
158158 self ,
159159 bf : BinaryIO ,
160- protocol : str = 'protobuf-array' ,
160+ protocol : ProtocolType = 'protobuf-array' ,
161161 compress : Optional [str ] = None ,
162162 show_progress : bool = False ,
163163 ) -> None :
@@ -201,7 +201,7 @@ def _write_bytes(
201201
202202 def _to_binary_stream (
203203 self ,
204- protocol : str = 'protobuf' ,
204+ protocol : ProtocolType = 'protobuf' ,
205205 compress : Optional [str ] = None ,
206206 show_progress : bool = False ,
207207 ) -> Iterator [bytes ]:
@@ -241,7 +241,7 @@ def _to_binary_stream(
241241
242242 def to_bytes (
243243 self ,
244- protocol : str = 'protobuf-array' ,
244+ protocol : ProtocolType = 'protobuf-array' ,
245245 compress : Optional [str ] = None ,
246246 file_ctx : Optional [BinaryIO ] = None ,
247247 show_progress : bool = False ,
@@ -273,7 +273,7 @@ def to_bytes(
273273 def from_base64 (
274274 cls : Type [T ],
275275 data : str ,
276- protocol : str = 'protobuf-array' ,
276+ protocol : ProtocolType = 'protobuf-array' ,
277277 compress : Optional [str ] = None ,
278278 show_progress : bool = False ,
279279 ) -> T :
@@ -294,7 +294,7 @@ def from_base64(
294294
295295 def to_base64 (
296296 self ,
297- protocol : str = 'protobuf-array' ,
297+ protocol : ProtocolType = 'protobuf-array' ,
298298 compress : Optional [str ] = None ,
299299 show_progress : bool = False ,
300300 ) -> str :
@@ -383,7 +383,6 @@ def _from_csv_file(
383383 file : Union [StringIO , TextIOWrapper ],
384384 dialect : Union [str , csv .Dialect ],
385385 ) -> 'T' :
386-
387386 rows = csv .DictReader (file , dialect = dialect )
388387
389388 doc_type = cls .doc_type
@@ -576,7 +575,7 @@ def _get_proto_class(cls: Type[T]):
576575 def _load_binary_all (
577576 cls : Type [T ],
578577 file_ctx : Union [ContextManager [io .BufferedReader ], ContextManager [bytes ]],
579- protocol : Optional [str ],
578+ protocol : Optional [ProtocolType ],
580579 compress : Optional [str ],
581580 show_progress : bool ,
582581 tensor_type : Optional [Type ['AbstractTensor' ]] = None ,
@@ -659,7 +658,9 @@ def _load_binary_all(
659658 start_pos = end_doc_pos
660659
661660 # variable length bytes doc
662- load_protocol : str = protocol or 'protobuf'
661+ load_protocol : ProtocolType = protocol or cast (
662+ ProtocolType , 'protobuf'
663+ )
663664 doc = cls .doc_type .from_bytes (
664665 d [start_doc_pos :end_doc_pos ],
665666 protocol = load_protocol ,
@@ -680,7 +681,7 @@ def _load_binary_all(
680681 def _load_binary_stream (
681682 cls : Type [T ],
682683 file_ctx : ContextManager [io .BufferedReader ],
683- protocol : str = 'protobuf' ,
684+ protocol : ProtocolType = 'protobuf' ,
684685 compress : Optional [str ] = None ,
685686 show_progress : bool = False ,
686687 ) -> Generator ['T_doc' , None , None ]:
@@ -728,7 +729,7 @@ def _load_binary_stream(
728729 len_current_doc_in_bytes = int .from_bytes (
729730 f .read (4 ), 'big' , signed = False
730731 )
731- load_protocol : str = protocol
732+ load_protocol : ProtocolType = protocol
732733 yield cls .doc_type .from_bytes (
733734 f .read (len_current_doc_in_bytes ),
734735 protocol = load_protocol ,
@@ -743,10 +744,12 @@ def _load_binary_stream(
743744 @staticmethod
744745 def _get_file_context (
745746 file : Union [str , bytes , pathlib .Path , io .BufferedReader , _LazyRequestReader ],
746- protocol : str ,
747+ protocol : ProtocolType ,
747748 compress : Optional [str ] = None ,
748- ) -> Tuple [Union [nullcontext , io .BufferedReader ], Optional [str ], Optional [str ]]:
749- load_protocol : Optional [str ] = protocol
749+ ) -> Tuple [
750+ Union [nullcontext , io .BufferedReader ], Optional [ProtocolType ], Optional [str ]
751+ ]:
752+ load_protocol : Optional [ProtocolType ] = protocol
750753 load_compress : Optional [str ] = compress
751754 file_ctx : Union [nullcontext , io .BufferedReader ]
752755 if isinstance (file , (io .BufferedReader , _LazyRequestReader , bytes )):
@@ -765,7 +768,7 @@ def _get_file_context(
765768 def load_binary (
766769 cls : Type [T ],
767770 file : Union [str , bytes , pathlib .Path , io .BufferedReader , _LazyRequestReader ],
768- protocol : str = 'protobuf-array' ,
771+ protocol : ProtocolType = 'protobuf-array' ,
769772 compress : Optional [str ] = None ,
770773 show_progress : bool = False ,
771774 streaming : bool = False ,
@@ -814,7 +817,7 @@ def load_binary(
814817 def save_binary (
815818 self ,
816819 file : Union [str , pathlib .Path ],
817- protocol : str = 'protobuf-array' ,
820+ protocol : ProtocolType = 'protobuf-array' ,
818821 compress : Optional [str ] = None ,
819822 show_progress : bool = False ,
820823 ) -> None :
0 commit comments