diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 40d3486699f..497096b74e6 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -6,6 +6,7 @@ import pickle from abc import abstractmethod from contextlib import nullcontext +from io import StringIO, TextIOWrapper from itertools import compress from typing import ( TYPE_CHECKING, @@ -361,40 +362,52 @@ def from_csv( 'unix' (for csv file generated on UNIX systems). :return: DocList """ - from docarray import DocList - if cls.doc_type == AnyDoc: raise TypeError( 'There is no document schema defined. ' 'Please specify the DocList\'s Document type using `DocList[MyDoc]`.' ) + if file_path.startswith('http'): + import urllib.request + + with urllib.request.urlopen(file_path) as f: + file = StringIO(f.read().decode(encoding)) + return cls._from_csv_file(file, dialect) + else: + with open(file_path, 'r', encoding=encoding) as fp: + return cls._from_csv_file(fp, dialect) + + @classmethod + def _from_csv_file( + cls, file: Union[StringIO, TextIOWrapper], dialect: Union[str, csv.Dialect] + ) -> 'DocList': + from docarray import DocList + + rows = csv.DictReader(file, dialect=dialect) + doc_type = cls.doc_type docs = DocList.__class_getitem__(doc_type)() - with open(file_path, 'r', encoding=encoding) as fp: - rows = csv.DictReader(fp, dialect=dialect) - field_names: List[str] = ( - [] if rows.fieldnames is None else [str(f) for f in rows.fieldnames] - ) - if field_names is None or len(field_names) == 0: - raise TypeError("No field names are given.") + field_names: List[str] = ( + [] if rows.fieldnames is None else [str(f) for f in rows.fieldnames] + ) + if field_names is None or len(field_names) == 0: + raise TypeError("No field names are given.") - valid_paths = _all_access_paths_valid( - doc_type=doc_type, access_paths=field_names + valid_paths = _all_access_paths_valid( + doc_type=doc_type, access_paths=field_names + ) + if not all(valid_paths): + raise ValueError( + 
f'Column names do not match the schema of the DocList\'s ' + f'document type ({cls.doc_type.__name__}): ' + f'{list(compress(field_names, [not v for v in valid_paths]))}' ) - if not all(valid_paths): - raise ValueError( - f'Column names do not match the schema of the DocList\'s ' - f'document type ({cls.doc_type.__name__}): ' - f'{list(compress(field_names, [not v for v in valid_paths]))}' - ) - for access_path2val in rows: - doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict( - access_path2val - ) - docs.append(doc_type.parse_obj(doc_dict)) + for access_path2val in rows: + doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict(access_path2val) + docs.append(doc_type.parse_obj(doc_dict)) return docs diff --git a/tests/toydata/books.csv b/tests/toydata/books.csv new file mode 100644 index 00000000000..7467bd4586e --- /dev/null +++ b/tests/toydata/books.csv @@ -0,0 +1,4 @@ +title,author,year +Harry Potter and the Philosopher's Stone,J. K. Rowling,1997 +Klara and the sun,Kazuo Ishiguro,2020 +A little life,Hanya Yanagihara,2015 \ No newline at end of file diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 09ec98b6432..f58dcefd1cb 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -97,3 +97,16 @@ def test_from_csv_without_schema_raise_exception(): def test_from_csv_with_wrong_schema_raise_exception(nested_doc): with pytest.raises(ValueError, match='Column names do not match the schema'): DocList[nested_doc.__class__].from_csv(file_path=str(TOYDATA_DIR / 'docs.csv')) + + +def test_from_remote_csv_file(): + remote_url = 'https://github.com/docarray/docarray/blob/main/tests/toydata/books.csv?raw=true' + + class Book(BaseDoc): + title: str + author: str + year: int + + books = DocList[Book].from_csv(file_path=remote_url) + + assert len(books) == 3