diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py
index f3f1ffce16..1a0ea20e55 100644
--- a/bigframes/session/__init__.py
+++ b/bigframes/session/__init__.py
@@ -231,7 +231,9 @@ def __init__(
         # Now that we're starting the session, don't allow the options to be
         # changed.
         context._session_started = True
-        self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}
+        self._df_snapshot: Dict[
+            bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]
+        ] = {}
 
     @property
     def bqclient(self):
@@ -698,16 +700,25 @@ def _get_snapshot_sql_and_primary_key(
         column(s), then return those too so that ordering generation can be
         avoided.
         """
-        # If there are primary keys defined, the query engine assumes these
-        # columns are unique, even if the constraint is not enforced. We make
-        # the same assumption and use these columns as the total ordering keys.
-        table = self.bqclient.get_table(table_ref)
+        (
+            snapshot_timestamp,
+            table,
+        ) = bigframes_io.get_snapshot_datetime_and_table_metadata(
+            self.bqclient,
+            table_ref=table_ref,
+            api_name=api_name,
+            cache=self._df_snapshot,
+            use_cache=use_cache,
+        )
 
         if table.location.casefold() != self._location.casefold():
             raise ValueError(
                 f"Current session is in {self._location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
             )
 
+        # If there are primary keys defined, the query engine assumes these
+        # columns are unique, even if the constraint is not enforced. We make
+        # the same assumption and use these columns as the total ordering keys.
         primary_keys = None
         if (
             (table_constraints := getattr(table, "table_constraints", None)) is not None
@@ -718,37 +729,6 @@ def _get_snapshot_sql_and_primary_key(
         ):
             primary_keys = columns
 
-        job_config = bigquery.QueryJobConfig()
-        job_config.labels["bigframes-api"] = api_name
-        if use_cache and table_ref in self._df_snapshot.keys():
-            snapshot_timestamp = self._df_snapshot[table_ref]
-
-            # Cache hit could be unexpected. See internal issue 329545805.
-            # Raise a warning with more information about how to avoid the
-            # problems with the cache.
-            warnings.warn(
-                f"Reading cached table from {snapshot_timestamp} to avoid "
-                "incompatibilies with previous reads of this table. To read "
-                "the latest version, set `use_cache=False` or close the "
-                "current session with Session.close() or "
-                "bigframes.pandas.close_session().",
-                # There are many layers before we get to (possibly) the user's code:
-                # pandas.read_gbq_table
-                # -> with_default_session
-                # -> Session.read_gbq_table
-                # -> _read_gbq_table
-                # -> _get_snapshot_sql_and_primary_key
-                stacklevel=6,
-            )
-        else:
-            snapshot_timestamp = list(
-                self.bqclient.query(
-                    "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
-                    job_config=job_config,
-                ).result()
-            )[0][0]
-            self._df_snapshot[table_ref] = snapshot_timestamp
-
         try:
             table_expression = self.ibis_client.sql(
                 bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
diff --git a/bigframes/session/_io/bigquery.py b/bigframes/session/_io/bigquery.py
index ac6ba4bae4..94576cfa12 100644
--- a/bigframes/session/_io/bigquery.py
+++ b/bigframes/session/_io/bigquery.py
@@ -23,6 +23,7 @@
 import types
 from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
 import uuid
+import warnings
 
 import google.api_core.exceptions
 import google.cloud.bigquery as bigquery
@@ -121,6 +122,59 @@ def table_ref_to_sql(table: bigquery.TableReference) -> str:
     return f"`{table.project}`.`{table.dataset_id}`.`{table.table_id}`"
 
 
+def get_snapshot_datetime_and_table_metadata(
+    bqclient: bigquery.Client,
+    table_ref: bigquery.TableReference,
+    *,
+    api_name: str,
+    cache: Dict[bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]],
+    use_cache: bool = True,
+) -> Tuple[datetime.datetime, bigquery.Table]:
+    cached_table = cache.get(table_ref)
+    if use_cache and cached_table is not None:
+        snapshot_timestamp, _ = cached_table
+
+        # Cache hit could be unexpected. See internal issue 329545805.
+        # Raise a warning with more information about how to avoid the
+        # problems with the cache.
+        warnings.warn(
+            f"Reading cached table from {snapshot_timestamp} to avoid "
+            "incompatibilies with previous reads of this table. To read "
+            "the latest version, set `use_cache=False` or close the "
+            "current session with Session.close() or "
+            "bigframes.pandas.close_session().",
+            # There are many layers before we get to (possibly) the user's code:
+            # pandas.read_gbq_table
+            # -> with_default_session
+            # -> Session.read_gbq_table
+            # -> _read_gbq_table
+            # -> _get_snapshot_sql_and_primary_key
+            # -> get_snapshot_datetime_and_table_metadata
+            stacklevel=7,
+        )
+        return cached_table
+
+    # TODO(swast): It's possible that the table metadata is changed between now
+    # and when we run the CURRENT_TIMESTAMP() query to see when we can time
+    # travel to. Find a way to fetch the table metadata and BQ's current time
+    # atomically.
+    table = bqclient.get_table(table_ref)
+
+    # TODO(b/336521938): Refactor to make sure we set the "bigframes-api"
+    # wherever we execute a query.
+    job_config = bigquery.QueryJobConfig()
+    job_config.labels["bigframes-api"] = api_name
+    snapshot_timestamp = list(
+        bqclient.query(
+            "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
+            job_config=job_config,
+        ).result()
+    )[0][0]
+    cached_table = (snapshot_timestamp, table)
+    cache[table_ref] = cached_table
+    return cached_table
+
+
 def create_snapshot_sql(
     table_ref: bigquery.TableReference, current_timestamp: datetime.datetime
 ) -> str:
diff --git a/tests/unit/session/test_session.py b/tests/unit/session/test_session.py
index 543196066a..4ba47190bd 100644
--- a/tests/unit/session/test_session.py
+++ b/tests/unit/session/test_session.py
@@ -42,8 +42,11 @@ def test_read_gbq_cached_table():
         google.cloud.bigquery.DatasetReference("my-project", "my_dataset"),
         "my_table",
     )
-    session._df_snapshot[table_ref] = datetime.datetime(
-        1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc
+    table = google.cloud.bigquery.Table(table_ref)
+    table._properties["location"] = session._location
+    session._df_snapshot[table_ref] = (
+        datetime.datetime(1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc),
+        table,
     )
 
     with pytest.warns(UserWarning, match=re.escape("use_cache=False")):