diff --git a/docarray/array/array_stacked.py b/docarray/array/array_stacked.py index 9758fc6e271..319e86c9495 100644 --- a/docarray/array/array_stacked.py +++ b/docarray/array/array_stacked.py @@ -87,7 +87,7 @@ def _from_columns( da_stacked._docs = docs return da_stacked - def to(self: T, device: str): + def to(self: T, device: str) -> T: """Move all tensors of this DocumentArrayStacked to the given device :param device: the device to move the data to @@ -106,6 +106,7 @@ def to(self: T, device: str): else: # recursive call col_docarray = cast(T, col) col_docarray.to(device) + return self @classmethod def _get_columns_schema(