120 changes: 87 additions & 33 deletions btrdb/stream.py
@@ -17,12 +17,16 @@
import json
import logging
import re
import uuid
import uuid as uuidlib
import warnings
from collections import deque
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Dict, List

import pyarrow
import pyarrow.compute as pc

try:
import pyarrow as pa
@@ -2183,22 +2187,10 @@ def arrow_values(
)
stream_uus = [str(s.uuid) for s in self._streams]
data = list(aligned_windows_gen)
tablex = data.pop(0)
uu = stream_uus.pop(0)
tab_columns = [
c if c == "time" else uu + "/" + c for c in tablex.column_names
]
tablex = tablex.rename_columns(tab_columns)
if data:
for tab, uu in zip(data, stream_uus):
tab_columns = [
c if c == "time" else uu + "/" + c for c in tab.column_names
]
tab = tab.rename_columns(tab_columns)
tablex = tablex.join(tab, "time", join_type="full outer")
data = tablex
else:
data = tablex
table_joined = _merge_pyarrow_tables(
{uu: tab for uu, tab in zip(stream_uus, data)}
)
data = table_joined

elif self.width is not None and self.depth is not None:
# create list of stream.windows data (the windows method should
@@ -2217,22 +2209,10 @@
)
stream_uus = [str(s.uuid) for s in self._streams]
data = list(windows_gen)
tablex = data.pop(0)
uu = stream_uus.pop(0)
tab_columns = [
c if c == "time" else uu + "/" + c for c in tablex.column_names
]
tablex = tablex.rename_columns(tab_columns)
if data:
for tab, uu in zip(data, stream_uus):
tab_columns = [
c if c == "time" else uu + "/" + c for c in tab.column_names
]
tab = tab.rename_columns(tab_columns)
tablex = tablex.join(tab, "time", join_type="full outer")
data = tablex
else:
data = tablex
table_joined = _merge_pyarrow_tables(
{uu: tab for uu, tab in zip(stream_uus, data)}
)
data = table_joined
else:
sampling_freq = params.pop("sampling_frequency", 0)
period_ns = 0
@@ -2349,3 +2329,77 @@ def _coalesce_table_deque(tables: deque):
t2, "time", join_type="full outer", right_suffix=f"_{idx}"
)
return main_table


def _extract_unique_times(stream_map: Dict[uuid.UUID, pa.Table]) -> pa.Array:
"""Extracts and returns unique 'time' values from all tables."""
time_arrays = [
table.column("time").combine_chunks()
for table in stream_map.values()
if table.num_rows > 0
]
    if not time_arrays:  # every input table was empty
return pa.array([], type=pa.timestamp("ns"))

all_times = pa.concat_arrays(time_arrays)
return pc.unique(all_times).sort()
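
For reference, a minimal sketch of what this helper computes, with hypothetical table contents; it assumes `pyarrow` and `pyarrow.compute` are available under the `pa`/`pc` aliases used above:

```python
import pyarrow as pa
import pyarrow.compute as pc

# Two hypothetical per-stream tables with overlapping "time" columns
# (integers are interpreted as nanoseconds since the epoch).
t_a = pa.table({"time": pa.array([100, 110, 120], type=pa.timestamp("ns")),
                "mean": [1.0, 2.0, 3.0]})
t_b = pa.table({"time": pa.array([110, 130], type=pa.timestamp("ns")),
                "mean": [5.0, 6.0]})

# Same steps as _extract_unique_times: flatten chunks, concatenate,
# deduplicate, and sort ascending.
arrays = [t.column("time").combine_chunks() for t in (t_a, t_b)]
print(pc.unique(pa.concat_arrays(arrays)).sort())
# -> the four distinct timestamps 100, 110, 120, 130 (ns)
```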


def _build_combined_schema(
stream_map: Dict[uuid.UUID, pa.Table], unique_times: pa.Array
) -> pa.Schema:
"""Constructs a combined schema for the merged table, ensuring unique column names."""
combined_schema = [("time", unique_times.type)]
    # One "<uuid>/<column>" field per non-time column, flattened across all tables
combined_schema += [
pa.field(
f"{uu}/{col_name}",
table.column(col_name).type if table.num_rows > 0 else pa.null(),
)
for uu, table in stream_map.items()
for col_name in table.column_names
if col_name != "time"
]

return pa.schema(combined_schema)
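
As an illustration, a hedged sketch of the schema this yields for two hypothetical tables keyed by made-up UUIDs, one of them empty (which is why its field falls back to the null type):

```python
import uuid
import pyarrow as pa

uu_a, uu_b = uuid.uuid4(), uuid.uuid4()  # hypothetical stream UUIDs
t_a = pa.table({"time": pa.array([100], type=pa.timestamp("ns")),
                "mean": [1.0]})
t_b = pa.table({"time": pa.array([], type=pa.timestamp("ns")),
                "mean": pa.array([], type=pa.float64())})  # empty stream

# Mirrors _build_combined_schema: "time" first, then one
# "<uuid>/<column>" field per non-time column; empty tables get null type.
fields = [("time", pa.timestamp("ns"))]
for uu, table in {uu_a: t_a, uu_b: t_b}.items():
    for name in table.column_names:
        if name != "time":
            dtype = table.column(name).type if table.num_rows > 0 else pa.null()
            fields.append((f"{uu}/{name}", dtype))
print(pa.schema(fields))
# time: timestamp[ns]; <uu_a>/mean: double; <uu_b>/mean: null
```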


def _merge_pyarrow_tables(stream_map: Dict[uuid.UUID, pa.Table]) -> pa.Table:
"""Merges PyArrow tables based on 'time' values into a single table."""
unique_times = _extract_unique_times(stream_map)
combined_schema = _build_combined_schema(stream_map, unique_times)

none_data = [None] * len(unique_times)
preallocated_data = {
field.name: pa.array(none_data, type=field.type) for field in combined_schema
}
preallocated_data["time"] = unique_times

for uu, table in stream_map.items():
if table.num_rows > 0:
time_indices = pc.index_in(
preallocated_data["time"],
value_set=table.column("time"),
skip_nulls=True,
)
for col_name in table.column_names:
if col_name == "time":
continue
combined_col_name = f"{str(uu)}/{col_name}"
preallocated_data[combined_col_name] = pc.take(
table.column(col_name), indices=time_indices
)
else:
# For empty tables, ensure their columns are represented with all nulls
for col_name in combined_schema.names:
if (
col_name.startswith(f"{str(uu)}/")
and col_name not in preallocated_data
):
field_type = combined_schema.field(col_name).type
preallocated_data[col_name] = pa.array(
[None] * preallocated_data["time"].length(), type=field_type
)

arrays = [preallocated_data[col] for col in combined_schema.names]
return pa.Table.from_arrays(arrays=arrays, schema=combined_schema)
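
Putting the three helpers together, a hedged usage sketch (the UUIDs, timestamps, and values are hypothetical): rows absent from one stream come back as nulls, preserving the full-outer-join semantics of the `Table.join` loop this replaces, while `pc.index_in`/`pc.take` fill each stream's columns in one pass instead of N-1 pairwise joins.

```python
import uuid
import pyarrow as pa

# Hypothetical per-stream tables, as arrow_values would collect them.
uu1, uu2 = uuid.uuid4(), uuid.uuid4()
t1 = pa.table({"time": pa.array([100, 110], type=pa.timestamp("ns")),
               "mean": [1.0, 2.0]})
t2 = pa.table({"time": pa.array([110, 120], type=pa.timestamp("ns")),
               "mean": [7.0, 8.0]})

merged = _merge_pyarrow_tables({uu1: t1, uu2: t2})
print(merged.column_names)  # ['time', '<uu1>/mean', '<uu2>/mean']
print(merged.to_pydict())
# time=100 -> (1.0, None); time=110 -> (2.0, 7.0); time=120 -> (None, 8.0)
```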
41 changes: 40 additions & 1 deletion tests/btrdb_integration/test_streamset.py
@@ -209,7 +209,46 @@ def test_streamset_arrow_aligned_windows_vs_aligned_windows(
ss = (
btrdb.stream.StreamSet([s1, s2, s3])
.filter(start=100, end=121)
.windows(width=btrdb.utils.general.pointwidth.from_nanoseconds(10))
.aligned_windows(pointwidth=btrdb.utils.general.pointwidth.from_nanoseconds(10))
)
values_arrow = ss.arrow_to_dataframe(name_callable=name_callable)
values_arrow.index = pd.DatetimeIndex(values_arrow.index)
values_prev = ss.to_dataframe(
name_callable=name_callable
) # .convert_dtypes(dtype_backend='pyarrow')
values_prev = values_prev.apply(lambda x: x.astype(str(x.dtype) + "[pyarrow]"))
values_prev = values_prev.apply(
lambda x: x.astype("uint64[pyarrow]") if "count" in x.name else x
)
values_prev.index = pd.DatetimeIndex(values_prev.index, tz="UTC")
col_map = {old_col: old_col + "/mean" for old_col in values_prev.columns}
values_prev = values_prev.rename(columns=col_map)
assert values_arrow.equals(values_prev)


@pytest.mark.parametrize(
"name_callable",
[(None), (lambda s: str(s.uuid)), (lambda s: s.name + "/" + s.collection)],
ids=["empty", "uu_as_str", "name_collection"],
)
def test_streamset_arrow_aligned_windows_join_logic(
conn, tmp_collection, name_callable
):
s1 = conn.create(new_uuid(), tmp_collection, tags={"name": "s1"})
s2 = conn.create(new_uuid(), tmp_collection, tags={"name": "s2"})
s3 = conn.create(new_uuid(), tmp_collection, tags={"name": "s3"})
t1 = [100, 105, 110, 115, 120]
t2 = [101, 106, 110, 132, 140]
d1 = [0.0, 1.0, 2.0, 3.0, 4.0]
d2 = [5.0, 6.0, 7.0, 8.0, 9.0]
d3 = [1.0, 9.0, 44.0, 8.0, 9.0]
s1.insert(list(zip(t1, d1)))
s2.insert(list(zip(t2, d2)))
s3.insert(list(zip(t2, d3)))
ss = (
btrdb.stream.StreamSet([s1, s2, s3])
.filter(start=100, end=141)
.aligned_windows(pointwidth=btrdb.utils.general.pointwidth.from_nanoseconds(8))
)
values_arrow = ss.arrow_to_dataframe(name_callable=name_callable)
values_arrow.index = pd.DatetimeIndex(values_arrow.index)