Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions docarray/doc_index/abstract_doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,9 +653,19 @@ def _create_columns(self, schema: Type[BaseDocument]) -> Dict[str, _ColumnInfo]:
return columns

def _create_single_column(self, field: 'ModelField', type_: Type) -> _ColumnInfo:
db_type = self.python_type_to_db_type(type_)
config = self._runtime_config.default_column_config[db_type].copy()
custom_config = field.field_info.extra

if 'col_type' in custom_config.keys():
db_type = custom_config['col_type']
custom_config.pop('col_type')
if db_type not in self._runtime_config.default_column_config.keys():
raise ValueError(
f'The given col_type is not a valid db type: {db_type}'
)
else:
db_type = self.python_type_to_db_type(type_)

config = self._runtime_config.default_column_config[db_type].copy()
config.update(custom_config)
# parse n_dim from parametrized tensor type
if (
Expand Down
21 changes: 20 additions & 1 deletion docs/tutorials/add_doc_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Make sure that you call the `super().__init__` method, which will do some basic

Your backend (database or similar) should represent Documents in the following way:
- Every field of a Document is a column in the database
- Column types follow a default that you define, based on the type hint of the associated field, but can also be configures by the user
- Column types follow a default that you define, based on the type hint of the associated field, but can also be configured by the user
- Every row in your database thus represents a Document
- **Nesting:** The most common way to handle nested Document (and the one where the `AbstractDocumentIndex` will hold your hand the most), is to flatten out nested Documents. But if your backend natively supports nesting representations, then feel free to leverage those!

Expand Down Expand Up @@ -146,6 +146,13 @@ class _ColumnInfo:

Again, these are automatically populated for you, so you can just use them in your implementation.

**Note:**
`_ColumnInfo.docarray_type` contains the python type as specified in `self._schema`, whereas
`_ColumnInfo.db_type` contains the data type of a particular database column.
By default, it holds that `_ColumnInfo.db_type == self.python_type_to_db_type(_ColumnInfo.docarray_type)`, as we will see later.
However, you should not rely on this, because a user can manually specify a different db_type.
Therefore, your implementation should rely on `_ColumnInfo.db_type` and not directly call `python_type_to_db_type()`.

### Properly handle `n_dim`

`_ColumnInfo.n_dim` is automatically obtained from type parametrizations of the form `NdArray[100]`;
Expand Down Expand Up @@ -292,6 +299,18 @@ This method is slightly special, because 1) it is not exposed to the user, and 2
It is intended to do the following: It takes a type of a field in the store's schema (e.g. `NdArray` for `tensor`), and returns the corresponding type in the database (e.g. `np.ndarray`).
The `BaseDocumentIndex` class uses this information to create and populate the `_ColumnInfo`s in `self._column_infos`.

If the user wants to change the default behaviour, they can set the db type using the `col_type` field:

```python
class MySchema(BaseDocument):
my_num: float = Field(col_type='float64')
my_text: str = Field(..., col_type='varchar', max_len=2048)
```

In this case, the db type of `my_num` will be `'float64'` and the db type of `my_text` will be `'varchar'`.
Additional information regarding the col_type, such as `max_len` for `varchar`, will be stored in the `_ColumnInfo.config`.
The given col_type has to be a valid db type, meaning that it has to be described in the index's `RuntimeConfig.default_column_config`.

### The `_index()` method

When indexing Documents, your implementation should behave in the following way:
Expand Down
35 changes: 34 additions & 1 deletion tests/doc_index/base_classes/test_base_doc_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class DummyDocIndex(BaseDocumentIndex):
@dataclass
class RuntimeConfig(BaseDocumentIndex.RuntimeConfig):
default_column_config: Dict[Type, Dict[str, Any]] = field(
default_factory=lambda: {str: {'hi': 'there'}, np.ndarray: {'you': 'good?'}}
default_factory=lambda: {
str: {'hi': 'there'},
np.ndarray: {'you': 'good?'},
'varchar': {'good': 'bye'},
}
)

@dataclass
Expand Down Expand Up @@ -141,6 +145,35 @@ def test_create_columns():
assert store._column_infos['d__tens'].config == {'dim': 1000, 'hi': 'there'}


def test_columns_db_type_with_user_defined_mapping(tmp_path):
    """The `col_type` field kwarg overrides the default python-type-to-db-type mapping."""

    class UserMappedDoc(BaseDocument):
        tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray)

    index = DummyDocIndex[UserMappedDoc](work_dir=str(tmp_path))
    tens_column = index._column_infos['tens']

    assert tens_column.db_type == np.ndarray


def test_columns_db_type_with_user_defined_mapping_additional_params(tmp_path):
    """Extra field kwargs accompanying `col_type` (here `max_len`) end up in the column config."""

    class VarcharDoc(BaseDocument):
        tens: NdArray[10] = Field(dim=1000, col_type='varchar', max_len=1024)

    index = DummyDocIndex[VarcharDoc](work_dir=str(tmp_path))
    column = index._column_infos['tens']

    assert column.db_type == 'varchar'
    assert column.config['max_len'] == 1024


def test_columns_illegal_mapping(tmp_path):
    """A `col_type` absent from `default_column_config` is rejected at index construction time."""

    class BadTypeDoc(BaseDocument):
        tens: NdArray[10] = Field(dim=1000, col_type='non_valid_type')

    expected_msg = 'The given col_type is not a valid db type: non_valid_type'
    with pytest.raises(ValueError, match=expected_msg):
        DummyDocIndex[BadTypeDoc](work_dir=str(tmp_path))


def test_is_schema_compatible():
class OtherSimpleDoc(SimpleDoc):
...
Expand Down
8 changes: 8 additions & 0 deletions tests/doc_index/hnswlib/test_index_get_del.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def test_index_simple_schema(ten_simple_docs, tmp_path, use_docarray):
assert index.get_current_count() == 10


def test_schema_with_user_defined_mapping(tmp_path):
    """HnswDocumentIndex honours a user-supplied `col_type` override in the schema."""

    class OverrideDoc(BaseDocument):
        tens: NdArray[10] = Field(dim=1000, col_type=np.ndarray)

    index = HnswDocumentIndex[OverrideDoc](work_dir=str(tmp_path))

    assert index._column_infos['tens'].db_type == np.ndarray


@pytest.mark.parametrize('use_docarray', [True, False])
def test_index_flat_schema(ten_flat_docs, tmp_path, use_docarray):
store = HnswDocumentIndex[FlatDoc](work_dir=str(tmp_path))
Expand Down