From f2bb55ef696a18af1e63fbd15813eaf43ffb9f34 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 11 Apr 2023 21:38:55 +0200 Subject: [PATCH 1/5] test: add books.csv to toydata dir Signed-off-by: anna-charlotte --- docarray/array/doc_list/io.py | 9 +++++++++ tests/toydata/books.csv | 4 ++++ 2 files changed, 13 insertions(+) create mode 100644 tests/toydata/books.csv diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 40d3486699f..12d2a822490 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -372,6 +372,15 @@ def from_csv( doc_type = cls.doc_type docs = DocList.__class_getitem__(doc_type)() + if file_path.startswith('http'): + import urllib3 + + http = urllib3.PoolManager() + + response = http.request('GET', file_path) + print(f"response.data = {response.data}") + file_path = response.data + with open(file_path, 'r', encoding=encoding) as fp: rows = csv.DictReader(fp, dialect=dialect) field_names: List[str] = ( diff --git a/tests/toydata/books.csv b/tests/toydata/books.csv new file mode 100644 index 00000000000..7467bd4586e --- /dev/null +++ b/tests/toydata/books.csv @@ -0,0 +1,4 @@ +title,author,year +Harry Potter and the Philosopher's Stone,J. K. Rowling,1997 +Klara and the sun,Kazuo Ishiguro,2020 +A little life,Hanya Yanagihara,2015 \ No newline at end of file From 9d175727423291b54c2891b12b6de444c40c83fa Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Tue, 11 Apr 2023 23:09:09 +0200 Subject: [PATCH 2/5] fix: remote csv file and test Signed-off-by: anna-charlotte --- docarray/array/doc_list/io.py | 64 +++++++++++---------- tests/units/array/test_array_from_to_csv.py | 13 +++++ 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 12d2a822490..79591d22825 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -6,6 +6,7 @@ import pickle from abc import abstractmethod from contextlib import nullcontext +from io import StringIO from itertools import compress from typing import ( TYPE_CHECKING, @@ -361,49 +362,52 @@ def from_csv( 'unix' (for csv file generated on UNIX systems). :return: DocList """ - from docarray import DocList - if cls.doc_type == AnyDoc: raise TypeError( 'There is no document schema defined. ' 'Please specify the DocList\'s Document type using `DocList[MyDoc]`.' ) - doc_type = cls.doc_type - docs = DocList.__class_getitem__(doc_type)() - if file_path.startswith('http'): - import urllib3 + import urllib.request - http = urllib3.PoolManager() + with urllib.request.urlopen(file_path) as f: + file_as_string = StringIO(f.read().decode(encoding)) + rows = csv.DictReader(file_as_string, dialect=dialect) + docs = cls._from_csv_dict_reader(rows) + else: + with open(file_path, 'r', encoding=encoding) as fp: + rows = csv.DictReader(fp, dialect=dialect) + docs = cls._from_csv_dict_reader(rows) - response = http.request('GET', file_path) - print(f"response.data = {response.data}") - file_path = response.data + return docs - with open(file_path, 'r', encoding=encoding) as fp: - rows = csv.DictReader(fp, dialect=dialect) - field_names: List[str] = ( - [] if rows.fieldnames is None else [str(f) for f in rows.fieldnames] - ) - if field_names is None or len(field_names) == 0: - raise TypeError("No field names are given.") + @classmethod + def _from_csv_dict_reader(cls, rows: csv.DictReader) -> 'DocList': + from docarray import DocList - valid_paths = _all_access_paths_valid( - doc_type=doc_type, access_paths=field_names + doc_type = cls.doc_type + docs = DocList.__class_getitem__(doc_type)() + + field_names: List[str] = ( + [] if rows.fieldnames is None else [str(f) for f in rows.fieldnames] + ) + if field_names is None or len(field_names) == 0: + raise TypeError("No field names are given.") + + valid_paths = _all_access_paths_valid( + doc_type=doc_type, access_paths=field_names + ) + if not all(valid_paths): + raise ValueError( + f'Column names do not match the schema of the DocList\'s ' + f'document type ({cls.doc_type.__name__}): ' + f'{list(compress(field_names, [not v for v in valid_paths]))}' ) - if not all(valid_paths): - raise ValueError( - f'Column names do not match the schema of the DocList\'s ' - f'document type ({cls.doc_type.__name__}): ' - f'{list(compress(field_names, [not v for v in valid_paths]))}' - ) - for access_path2val in rows: - doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict( - access_path2val - ) - docs.append(doc_type.parse_obj(doc_dict)) + for access_path2val in rows: + doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict(access_path2val) + docs.append(doc_type.parse_obj(doc_dict)) return docs diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index 09ec98b6432..f58dcefd1cb 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -97,3 +97,16 @@ def test_from_csv_without_schema_raise_exception(): def test_from_csv_with_wrong_schema_raise_exception(nested_doc): with pytest.raises(ValueError, match='Column names do not match the schema'): DocList[nested_doc.__class__].from_csv(file_path=str(TOYDATA_DIR / 'docs.csv')) + + +def test_from_remote_csv_file(): + remote_url = 'https://github.com/docarray/docarray/blob/feat-csv-from-remote-file/tests/toydata/books.csv?raw=true' + + class Book(BaseDoc): + title: str + author: str + year: int + + books = DocList[Book].from_csv(file_path=remote_url) + + assert len(books) == 3 From 44d13262b0d36bcf2bf32fd42711f441fcff2d86 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 12 Apr 2023 09:14:38 +0200 Subject: [PATCH 3/5] fix: apply suggestion from code review Signed-off-by: anna-charlotte --- docarray/array/doc_list/io.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 79591d22825..126a17fa8f0 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -372,13 +372,12 @@ def from_csv( import urllib.request with urllib.request.urlopen(file_path) as f: - file_as_string = StringIO(f.read().decode(encoding)) - rows = csv.DictReader(file_as_string, dialect=dialect) - docs = cls._from_csv_dict_reader(rows) + file = StringIO(f.read().decode(encoding)) else: - with open(file_path, 'r', encoding=encoding) as fp: - rows = csv.DictReader(fp, dialect=dialect) - docs = cls._from_csv_dict_reader(rows) + file = open(file_path, 'r', encoding=encoding) + + rows = csv.DictReader(file, dialect=dialect) + docs = cls._from_csv_dict_reader(rows) return docs From cdcf7fb4db95bb4950679a1d8a4302a1eae29407 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 12 Apr 2023 09:48:42 +0200 Subject: [PATCH 4/5] fix: apply suggestion Signed-off-by: anna-charlotte --- docarray/array/doc_list/io.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index 126a17fa8f0..eb4b87781a6 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -6,7 +6,7 @@ import pickle from abc import abstractmethod from contextlib import nullcontext -from io import StringIO +from io import StringIO, TextIOWrapper from itertools import compress from typing import ( TYPE_CHECKING, @@ -373,18 +373,19 @@ def from_csv( with urllib.request.urlopen(file_path) as f: file = StringIO(f.read().decode(encoding)) + return cls._from_csv_dict_reader(file, dialect) else: - file = open(file_path, 'r', encoding=encoding) - - rows = csv.DictReader(file, dialect=dialect) - docs = cls._from_csv_dict_reader(rows) - - return docs + with open(file_path, 'r', encoding=encoding) as fp: + return cls._from_csv_dict_reader(fp, dialect) @classmethod - def _from_csv_dict_reader(cls, rows: csv.DictReader) -> 'DocList': + def _from_csv_dict_reader( + cls, file: Union[StringIO, TextIOWrapper], dialect: Union[str, csv.Dialect] + ) -> 'DocList': from docarray import DocList + rows = csv.DictReader(file, dialect=dialect) + doc_type = cls.doc_type docs = DocList.__class_getitem__(doc_type)() From f15df092eecc109d1f9f80b201474bfc6cd6d8f9 Mon Sep 17 00:00:00 2001 From: anna-charlotte Date: Wed, 12 Apr 2023 09:49:28 +0200 Subject: [PATCH 5/5] refactor: rename private method Signed-off-by: anna-charlotte --- docarray/array/doc_list/io.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docarray/array/doc_list/io.py b/docarray/array/doc_list/io.py index eb4b87781a6..497096b74e6 100644 --- a/docarray/array/doc_list/io.py +++ b/docarray/array/doc_list/io.py @@ -373,13 +373,13 @@ def from_csv( with urllib.request.urlopen(file_path) as f: file = StringIO(f.read().decode(encoding)) - return cls._from_csv_dict_reader(file, dialect) + return cls._from_csv_file(file, dialect) else: with open(file_path, 'r', encoding=encoding) as fp: - return cls._from_csv_dict_reader(fp, dialect) + return cls._from_csv_file(fp, dialect) @classmethod - def _from_csv_dict_reader( + def _from_csv_file( cls, file: Union[StringIO, TextIOWrapper], dialect: Union[str, csv.Dialect] ) -> 'DocList': from docarray import DocList