diff --git a/docarray/doc_index/abstract_doc_index.py b/docarray/doc_index/abstract_doc_index.py index 352188bafae..e55096d1e60 100644 --- a/docarray/doc_index/abstract_doc_index.py +++ b/docarray/doc_index/abstract_doc_index.py @@ -653,9 +653,19 @@ def _create_columns(self, schema: Type[BaseDocument]) -> Dict[str, _ColumnInfo]: return columns def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo: - db_type = self.python_type_to_db_type(type_) - config = self._runtime_config.default_column_config[db_type].copy() custom_config = field.field_info.extra + + if 'col_type' in custom_config.keys(): + db_type = custom_config['col_type'] + custom_config.pop('col_type') + if db_type not in self._runtime_config.default_column_config.keys(): + raise ValueError( + f'The given col_type is not a valid db type: {db_type}' + ) + else: + db_type = self.python_type_to_db_type(type_) + + config = self._runtime_config.default_column_config[db_type].copy() config.update(custom_config) # parse n_dim from parametrized tensor type if ( diff --git a/docs/tutorials/add_doc_index.md b/docs/tutorials/add_doc_index.md index f93522c13b8..afc8476d87e 100644 --- a/docs/tutorials/add_doc_index.md +++ b/docs/tutorials/add_doc_index.md @@ -57,7 +57,7 @@ Make sure that you call the `super().__init__` method, which will do some basic Your backend (database or similar) should represent Documents in the following way: - Every field of a Document is a column in the database -- Column types follow a default that you define, based on the type hint of the associated field, but can also be configures by the user +- Column types follow a default that you define, based on the type hint of the associated field, but can also be configured by the user - Every row in your database thus represents a Document - **Nesting:** The most common way to handle nested Document (and the one where the `AbstractDocumentIndex` will hold your hand the most), is to flatten out nested Documents. 
But if your backend natively supports nesting representations, then feel free to leverage those! @@ -146,6 +146,13 @@ class _ColumnInfo: Again, these are automatically populated for you, so you can just use them in your implementation. +**Note:** +`_ColumnInfo.docarray_type` contains the python type as specified in `self._schema`, whereas +`_ColumnInfo.db_type` contains the data type of a particular database column. +By default, it holds that `_ColumnInfo.db_type == self.python_type_to_db_type(_ColumnInfo.docarray_type)`, as we will see later. +However, you should not rely on this, because a user can manually specify a different db_type. +Therefore, your implementation should rely on `_ColumnInfo.db_type` and not directly call `python_type_to_db_type()`. + ### Properly handle `n_dim` `_ColumnInfo.n_dim` is automatically obtained from type parametrizations of the form `NdArray[100]`; @@ -292,6 +299,18 @@ This method is slightly special, because 1) it is not exposed to the user, and 2 It is intended to do the following: It takes a type of a field in the store's schema (e.g. `NdArray` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`). The `BaseDocumentIndex` class uses this information to create and populate the `_ColumnInfo`s in `self._column_infos`. +If the user wants to change the default behaviour, one can set the db type by using the `col_type` field: + +```python +class MySchema(BaseDocument): + my_num: float = Field(col_type='float64') + my_text: str = Field(..., col_type='varchar', max_len=2048) +``` + +In this case, the db type of `my_num` will be `'float64'` and the db type of `my_text` will be `'varchar'`. +Additional information regarding the col_type, such as `max_len` for `varchar`, will be stored in `_ColumnInfo.config`. +The given col_type has to be a valid db type, meaning that it has to be described in the index's `RuntimeConfig.default_column_config`. 
+ ### The `_index()` method When indexing Documents, your implementation should behave in the following way: diff --git a/tests/doc_index/base_classes/test_base_doc_store.py b/tests/doc_index/base_classes/test_base_doc_store.py index a6d1173a36b..e7e8357ef42 100644 --- a/tests/doc_index/base_classes/test_base_doc_store.py +++ b/tests/doc_index/base_classes/test_base_doc_store.py @@ -44,7 +44,11 @@ class DummyDocIndex(BaseDocumentIndex): @dataclass class RuntimeConfig(BaseDocumentIndex.RuntimeConfig): default_column_config: Dict[Type, Dict[str, Any]] = field( - default_factory=lambda: {str: {'hi': 'there'}, np.ndarray: {'you': 'good?'}} + default_factory=lambda: { + str: {'hi': 'there'}, + np.ndarray: {'you': 'good?'}, + 'varchar': {'good': 'bye'}, + } ) @dataclass @@ -141,6 +145,35 @@ def test_create_columns(): assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'} +def test_columns_db_type_with_user_defined_mapping(tmp_path): + class MyDoc(BaseDocument): + tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray) + + store = DummyDocIndex[MyDoc](work_dir=str(tmp_path)) + + assert store._column_infos['tens'].db_type == np.ndarray + + +def test_columns_db_type_with_user_defined_mapping_additional_params(tmp_path): + class MyDoc(BaseDocument): + tens: NdArray[10] = Field(dim=1000, col_type='varchar', max_len=1024) + + store = DummyDocIndex[MyDoc](work_dir=str(tmp_path)) + + assert store._column_infos['tens'].db_type == 'varchar' + assert store._column_infos['tens'].config['max_len'] == 1024 + + +def test_columns_illegal_mapping(tmp_path): + class MyDoc(BaseDocument): + tens: NdArray[10] = Field(dim=1000, col_type='non_valid_type') + + with pytest.raises( + ValueError, match='The given col_type is not a valid db type: non_valid_type' + ): + DummyDocIndex[MyDoc](work_dir=str(tmp_path)) + + def test_is_schema_compatible(): class OtherSimpleDoc(SimpleDoc): ... 
diff --git a/tests/doc_index/hnswlib/test_index_get_del.py b/tests/doc_index/hnswlib/test_index_get_del.py index 72fecc66b9a..7b4ec2861c9 100644 --- a/tests/doc_index/hnswlib/test_index_get_del.py +++ b/tests/doc_index/hnswlib/test_index_get_del.py @@ -56,6 +56,14 @@ def test_index_simple_schema(ten_simple_docs, tmp_path, use_docarray): assert index.get_current_count() == 10 +def test_schema_with_user_defined_mapping(tmp_path): + class MyDoc(BaseDocument): + tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray) + + store = HnswDocumentIndex[MyDoc](work_dir=str(tmp_path)) + assert store._column_infos['tens'].db_type == np.ndarray + + @pytest.mark.parametrize('use_docarray', [True, False]) def test_index_flat_schema(ten_flat_docs, tmp_path, use_docarray): store = HnswDocumentIndex[FlatDoc](work_dir=str(tmp_path))