From 584bc9334727ea460d7785b22ea9534c23f8a172 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 1 May 2024 17:11:18 +0000 Subject: [PATCH 1/2] fix: include `index_col` when selecting `columns` and `filters` in `read_gbq_table` Fixes internal issue 339430305 test: refactor `read_gbq` / `read_gbq_table` tests to test with all parameters combined refactor: move query generation code to BigQuery I/O module --- bigframes/session/__init__.py | 139 +++++---------------- bigframes/session/_io/bigquery/__init__.py | 97 +++++++++++++- tests/system/small/test_session.py | 94 ++++++++++---- tests/unit/session/test_io_bigquery.py | 106 ++++++++++++++++ tests/unit/session/test_session.py | 82 ------------ 5 files changed, 306 insertions(+), 212 deletions(-) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 7c7d93541c..86d103d83e 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -20,7 +20,6 @@ import datetime import logging import os -import re import secrets import typing from typing import ( @@ -89,7 +88,7 @@ import bigframes.formatting_helpers as formatting_helpers from bigframes.functions.remote_function import read_gbq_function as bigframes_rgf from bigframes.functions.remote_function import remote_function as bigframes_rf -import bigframes.session._io.bigquery as bigframes_io +import bigframes.session._io.bigquery as bf_io_bigquery import bigframes.session._io.bigquery.read_gbq_table as bf_read_gbq_table import bigframes.session.clients import bigframes.version @@ -145,14 +144,18 @@ ) -def _is_query(query_or_table: str) -> bool: - """Determine if `query_or_table` is a table ID or a SQL string""" - return re.search(r"\s", query_or_table.strip(), re.MULTILINE) is not None +def _to_index_cols( + index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = (), +) -> List[str]: + """Convert index_col into a list of column names.""" + if isinstance(index_col, bigframes.enums.DefaultIndexKind): + index_cols: List[str] = [] + elif isinstance(index_col, str): + index_cols = [index_col] + else: + index_cols = list(index_col) - -def _is_table_with_wildcard_suffix(query_or_table: str) -> bool: - """Determine if `query_or_table` is a table and contains a wildcard suffix.""" - return not _is_query(query_or_table) and query_or_table.endswith("*") + return index_cols class Session( @@ -322,13 +325,19 @@ def read_gbq( columns = col_order filters = list(filters) - if len(filters) != 0 or _is_table_with_wildcard_suffix(query_or_table): + if len(filters) != 0 or bf_io_bigquery.is_table_with_wildcard_suffix( + query_or_table + ): # TODO(b/338111344): This appears to be missing index_cols, which # are necessary to be selected. - # TODO(b/338039517): Also, need to account for primary keys. - query_or_table = self._to_query(query_or_table, columns, filters) + # TODO(b/338039517): Refactor this to be called inside both + # _read_gbq_query and _read_gbq_table (after detecting primary keys) + # so we can make sure index_col/index_cols reflects primary keys. + query_or_table = bf_io_bigquery.to_query( + query_or_table, _to_index_cols(index_col), columns, filters + ) - if _is_query(query_or_table): + if bf_io_bigquery.is_query(query_or_table): return self._read_gbq_query( query_or_table, index_col=index_col, @@ -355,85 +364,6 @@ def read_gbq( use_cache=use_cache if use_cache is not None else True, ) - def _to_query( - self, - query_or_table: str, - columns: Iterable[str], - filters: third_party_pandas_gbq.FiltersType, - ) -> str: - """Compile query_or_table with conditions(filters, wildcards) to query.""" - filters = list(filters) - sub_query = ( - f"({query_or_table})" - if _is_query(query_or_table) - else f"`{query_or_table}`" - ) - - # TODO(b/338111344): Generate an index based on DefaultIndexKind if we - # don't have index columns specified. - select_clause = "SELECT " + ( - ", ".join(f"`{column}`" for column in columns) if columns else "*" - ) - - where_clause = "" - if filters: - valid_operators: Mapping[third_party_pandas_gbq.FilterOps, str] = { - "in": "IN", - "not in": "NOT IN", - "LIKE": "LIKE", - "==": "=", - ">": ">", - "<": "<", - ">=": ">=", - "<=": "<=", - "!=": "!=", - } - - # If single layer filter, add another pseudo layer. So the single layer represents "and" logic. - if isinstance(filters[0], tuple) and ( - len(filters[0]) == 0 or not isinstance(list(filters[0])[0], tuple) - ): - filters = typing.cast(third_party_pandas_gbq.FiltersType, [filters]) - - or_expressions = [] - for group in filters: - if not isinstance(group, Iterable): - group = [group] - - and_expressions = [] - for filter_item in group: - if not isinstance(filter_item, tuple) or (len(filter_item) != 3): - raise ValueError( - f"Filter condition should be a tuple of length 3, {filter_item} is not valid." - ) - - column, operator, value = filter_item - - if not isinstance(column, str): - raise ValueError( - f"Column name should be a string, but received '{column}' of type {type(column).__name__}." - ) - - if operator not in valid_operators: - raise ValueError(f"Operator {operator} is not valid.") - - operator_str = valid_operators[operator] - - if operator_str in ["IN", "NOT IN"]: - value_list = ", ".join([repr(v) for v in value]) - expression = f"`{column}` {operator_str} ({value_list})" - else: - expression = f"`{column}` {operator_str} {repr(value)}" - and_expressions.append(expression) - - or_expressions.append(" AND ".join(and_expressions)) - - if or_expressions: - where_clause = " WHERE " + " OR ".join(or_expressions) - - full_query = f"{select_clause} FROM {sub_query} AS sub{where_clause}" - return full_query - def _query_to_destination( self, query: str, @@ -610,12 +540,7 @@ def _read_gbq_query( True if use_cache is None else use_cache ) - if isinstance(index_col, bigframes.enums.DefaultIndexKind): - index_cols = [] - elif isinstance(index_col, str): - index_cols = [index_col] - else: - index_cols = list(index_col) + index_cols = _to_index_cols(index_col) destination, query_job = self._query_to_destination( query, @@ -682,8 +607,13 @@ def read_gbq_table( columns = col_order filters = list(filters) - if len(filters) != 0 or _is_table_with_wildcard_suffix(query): - query = self._to_query(query, columns, filters) + if len(filters) != 0 or bf_io_bigquery.is_table_with_wildcard_suffix(query): + # TODO(b/338039517): Refactor this to be called inside both + # _read_gbq_query and _read_gbq_table (after detecting primary keys) + # so we can make sure index_col/index_cols reflects primary keys. + query = bf_io_bigquery.to_query( + query, _to_index_cols(index_col), columns, filters + ) return self._read_gbq_query( query, @@ -838,12 +768,7 @@ def _read_bigquery_load_job( index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind = (), columns: Iterable[str] = (), ) -> dataframe.DataFrame: - if isinstance(index_col, bigframes.enums.DefaultIndexKind): - index_cols = [] - elif isinstance(index_col, str): - index_cols = [index_col] - else: - index_cols = list(index_col) + index_cols = _to_index_cols(index_col) if not job_config.clustering_fields and index_cols: job_config.clustering_fields = index_cols[:_MAX_CLUSTER_COLUMNS] @@ -1430,7 +1355,7 @@ def _create_empty_temp_table( datetime.datetime.now(datetime.timezone.utc) + constants.DEFAULT_EXPIRATION ) - table = bigframes_io.create_temp_table( + table = bf_io_bigquery.create_temp_table( self, expiration, schema=schema, diff --git a/bigframes/session/_io/bigquery/__init__.py b/bigframes/session/_io/bigquery/__init__.py index 79108c71a2..98e0dac1e8 100644 --- a/bigframes/session/_io/bigquery/__init__.py +++ b/bigframes/session/_io/bigquery/__init__.py @@ -19,10 +19,13 @@ import datetime import itertools import os +import re import textwrap import types -from typing import Dict, Iterable, Optional, Sequence, Tuple, Union +import typing +from typing import Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union +import bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq import google.api_core.exceptions import google.cloud.bigquery as bigquery @@ -311,3 +314,95 @@ def create_bq_dataset_reference( query_destination.project, query_destination.dataset_id, ) + + +def is_query(query_or_table: str) -> bool: + """Determine if `query_or_table` is a table ID or a SQL string""" + return re.search(r"\s", query_or_table.strip(), re.MULTILINE) is not None + + +def is_table_with_wildcard_suffix(query_or_table: str) -> bool: + """Determine if `query_or_table` is a table and contains a wildcard suffix.""" + return not is_query(query_or_table) and query_or_table.endswith("*") + + +def to_query( + query_or_table: str, + index_cols: Iterable[str], + columns: Iterable[str], + filters: third_party_pandas_gbq.FiltersType, +) -> str: + """Compile query_or_table with conditions(filters, wildcards) to query.""" + filters = list(filters) + sub_query = ( + f"({query_or_table})" if is_query(query_or_table) else f"`{query_or_table}`" + ) + + # TODO(b/338111344): Generate an index based on DefaultIndexKind if we + # don't have index columns specified. + if columns: + # We only reduce the selection if columns is set, but we always + # want to make sure index_cols is also included. + all_columns = itertools.chain(index_cols, columns) + select_clause = "SELECT " + ", ".join(f"`{column}`" for column in all_columns) + else: + select_clause = "SELECT *" + + where_clause = "" + if filters: + valid_operators: Mapping[third_party_pandas_gbq.FilterOps, str] = { + "in": "IN", + "not in": "NOT IN", + "LIKE": "LIKE", + "==": "=", + ">": ">", + "<": "<", + ">=": ">=", + "<=": "<=", + "!=": "!=", + } + + # If single layer filter, add another pseudo layer. So the single layer represents "and" logic. + if isinstance(filters[0], tuple) and ( + len(filters[0]) == 0 or not isinstance(list(filters[0])[0], tuple) + ): + filters = typing.cast(third_party_pandas_gbq.FiltersType, [filters]) + + or_expressions = [] + for group in filters: + if not isinstance(group, Iterable): + group = [group] + + and_expressions = [] + for filter_item in group: + if not isinstance(filter_item, tuple) or (len(filter_item) != 3): + raise ValueError( + f"Filter condition should be a tuple of length 3, {filter_item} is not valid." + ) + + column, operator, value = filter_item + + if not isinstance(column, str): + raise ValueError( + f"Column name should be a string, but received '{column}' of type {type(column).__name__}." + ) + + if operator not in valid_operators: + raise ValueError(f"Operator {operator} is not valid.") + + operator_str = valid_operators[operator] + + if operator_str in ["IN", "NOT IN"]: + value_list = ", ".join([repr(v) for v in value]) + expression = f"`{column}` {operator_str} ({value_list})" + else: + expression = f"`{column}` {operator_str} {repr(value)}" + and_expressions.append(expression) + + or_expressions.append(" AND ".join(and_expressions)) + + if or_expressions: + where_clause = " WHERE " + " OR ".join(or_expressions) + + full_query = f"{select_clause} FROM {sub_query} AS sub{where_clause}" + return full_query diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py index 6b2d7df50d..5daa01ad38 100644 --- a/tests/system/small/test_session.py +++ b/tests/system/small/test_session.py @@ -18,7 +18,7 @@ import textwrap import time import typing -from typing import List +from typing import List, Sequence import google import google.cloud.bigquery as bigquery @@ -338,30 +338,80 @@ def test_read_gbq_table_clustered_with_filter(session: bigframes.Session): assert "OLI_TIRS" in sensors.index -def test_read_gbq_wildcard(session: bigframes.Session): - df = session.read_gbq("bigquery-public-data.noaa_gsod.gsod193*") - assert df.shape == (348485, 32) +_GSOD_ALL_TABLES = "bigquery-public-data.noaa_gsod.gsod*" +_GSOD_1930S = "bigquery-public-data.noaa_gsod.gsod193*" -def test_read_gbq_wildcard_with_filter(session: bigframes.Session): - df = session.read_gbq( - "bigquery-public-data.noaa_gsod.gsod19*", - filters=[("_table_suffix", ">=", "30"), ("_table_suffix", "<=", "39")], # type: ignore - ) - assert df.shape == (348485, 32) - - -def test_read_gbq_table_wildcard(session: bigframes.Session): - df = session.read_gbq_table("bigquery-public-data.noaa_gsod.gsod193*") - assert df.shape == (348485, 32) - - -def test_read_gbq_table_wildcard_with_filter(session: bigframes.Session): - df = session.read_gbq_table( - "bigquery-public-data.noaa_gsod.gsod19*", - filters=[("_table_suffix", ">=", "30"), ("_table_suffix", "<=", "39")], # type: ignore +@pytest.mark.parametrize( + "api_method", + # Test that both methods work as there's a risk that read_gbq / + # read_gbq_table makes for an infinite loop. Table reads can convert to + # queries and read_gbq reads from tables. + ["read_gbq", "read_gbq_table"], +) +@pytest.mark.parametrize( + ("filters", "table_id", "index_col", "columns"), + [ + pytest.param( + [("_table_suffix", ">=", "1930"), ("_table_suffix", "<=", "1939")], + _GSOD_ALL_TABLES, + ["stn", "wban", "year", "mo", "da"], + ["temp", "max", "min"], + id="all", + ), + pytest.param( + (), # filters + _GSOD_1930S, + (), # index_col + ["temp", "max", "min"], + id="columns", + ), + pytest.param( + [("_table_suffix", ">=", "1930"), ("_table_suffix", "<=", "1939")], + _GSOD_ALL_TABLES, + (), # index_col, + (), # columns + id="filters", + ), + pytest.param( + (), # filters + _GSOD_1930S, + ["stn", "wban", "year", "mo", "da"], + (), # columns + id="index_col", + ), + ], +) +def test_read_gbq_wildcard( + session: bigframes.Session, + api_method: str, + filters, + table_id: str, + index_col: Sequence[str], + columns: Sequence[str], +): + table_metadata = session.bqclient.get_table(table_id) + method = getattr(session, api_method) + df = method(table_id, filters=filters, index_col=index_col, columns=columns) + num_rows, num_columns = df.shape + + if index_col: + assert list(df.index.names) == list(index_col) + else: + assert df.index.name is None + + expected_columns = ( + columns + if columns + else [ + field.name + for field in table_metadata.schema + if field.name not in index_col and field.name not in columns + ] ) - assert df.shape == (348485, 32) + assert list(df.columns) == expected_columns + assert num_rows > 0 + assert num_columns == len(expected_columns) @pytest.mark.parametrize( diff --git a/tests/unit/session/test_io_bigquery.py b/tests/unit/session/test_io_bigquery.py index 43865fc2c8..9da085e824 100644 --- a/tests/unit/session/test_io_bigquery.py +++ b/tests/unit/session/test_io_bigquery.py @@ -210,3 +210,109 @@ def test_create_temp_table_default_expiration(): def test_bq_schema_to_sql(schema: Iterable[bigquery.SchemaField], expected: str): sql = io_bq.bq_schema_to_sql(schema) assert sql == expected + + +@pytest.mark.parametrize( + ("query_or_table", "index_cols", "columns", "filters", "expected_output"), + [ + pytest.param( + "test_table", + [], + [], + ["date_col", ">", "2022-10-20"], + None, + marks=pytest.mark.xfail( + raises=ValueError, + ), + id="raise_error", + ), + pytest.param( + "test_table", + ["row_index"], + ["string_col"], + [ + (("rowindex", "not in", [0, 6]),), + (("string_col", "in", ["Hello, World!", "こんにちは"]),), + ], + ( + "SELECT `row_index`, `string_col` FROM `test_table` AS sub WHERE " + "`rowindex` NOT IN (0, 6) OR `string_col` IN ('Hello, World!', " + "'こんにちは')" + ), + id="table-all_params-filter_or_operation", + ), + pytest.param( + """SELECT + rowindex, + string_col, + FROM `test_table` AS t + """, + ["rowindex"], + ["string_col"], + [ + ("rowindex", "<", 4), + ("string_col", "==", "Hello, World!"), + ], + """SELECT `rowindex`, `string_col` FROM (SELECT + rowindex, + string_col, + FROM `test_table` AS t + ) AS sub WHERE `rowindex` < 4 AND `string_col` = 'Hello, World!'""", + id="subquery-all_params-filter_and_operation", + ), + pytest.param( + "test_table", + [], + ["col_a", "col_b"], + [], + "SELECT `col_a`, `col_b` FROM `test_table` AS sub", + id="table-columns", + ), + pytest.param( + "test_table", + [], + [], + [("date_col", ">", "2022-10-20")], + "SELECT * FROM `test_table` AS sub WHERE `date_col` > '2022-10-20'", + id="table-filter", + ), + pytest.param( + "test_table*", + [], + [], + [], + "SELECT * FROM `test_table*` AS sub", + id="wildcard-no_params", + ), + pytest.param( + "test_table*", + [], + [], + [("_TABLE_SUFFIX", ">", "2022-10-20")], + "SELECT * FROM `test_table*` AS sub WHERE `_TABLE_SUFFIX` > '2022-10-20'", + id="wildcard-filter", + ), + ], +) +def test_to_query(query_or_table, index_cols, columns, filters, expected_output): + query = io_bq.to_query( + query_or_table, + index_cols, + columns, + filters, + ) + assert query == expected_output + + +@pytest.mark.parametrize( + ("query_or_table", "filters", "expected_output"), + [], +) +def test_to_query_with_wildcard_table(query_or_table, filters, expected_output): + query = io_bq.to_query( + query_or_table, + (), # index_cols + (), # columns + filters, + ) + assert query == expected_output diff --git a/tests/unit/session/test_session.py b/tests/unit/session/test_session.py index a161c2df76..bea858e037 100644 --- a/tests/unit/session/test_session.py +++ b/tests/unit/session/test_session.py @@ -398,85 +398,3 @@ def test_session_init_fails_with_no_project(): credentials=mock.Mock(spec=google.auth.credentials.Credentials) ) ) - - -@pytest.mark.parametrize( - ("query_or_table", "columns", "filters", "expected_output"), - [ - pytest.param( - """SELECT - rowindex, - string_col, - FROM `test_table` AS t - """, - [], - [("rowindex", "<", 4), ("string_col", "==", "Hello, World!")], - """SELECT * FROM (SELECT - rowindex, - string_col, - FROM `test_table` AS t - ) AS sub WHERE `rowindex` < 4 AND `string_col` = 'Hello, World!'""", - id="query_input", - ), - pytest.param( - "test_table", - [], - [("date_col", ">", "2022-10-20")], - "SELECT * FROM `test_table` AS sub WHERE `date_col` > '2022-10-20'", - id="table_input", - ), - pytest.param( - "test_table", - ["row_index", "string_col"], - [ - (("rowindex", "not in", [0, 6]),), - (("string_col", "in", ["Hello, World!", "こんにちは"]),), - ], - ( - "SELECT `row_index`, `string_col` FROM `test_table` AS sub WHERE " - "`rowindex` NOT IN (0, 6) OR `string_col` IN ('Hello, World!', " - "'こんにちは')" - ), - id="or_operation", - ), - pytest.param( - "test_table", - [], - ["date_col", ">", "2022-10-20"], - None, - marks=pytest.mark.xfail( - raises=ValueError, - ), - id="raise_error", - ), - ], -) -def test_read_gbq_with_filters(query_or_table, columns, filters, expected_output): - session = resources.create_bigquery_session() - query = session._to_query(query_or_table, columns, filters) - assert query == expected_output - - -@pytest.mark.parametrize( - ("query_or_table", "columns", "filters", "expected_output"), - [ - pytest.param( - "test_table*", - [], - [], - "SELECT * FROM `test_table*` AS sub", - id="wildcard_table_input", - ), - pytest.param( - "test_table*", - [], - [("_TABLE_SUFFIX", ">", "2022-10-20")], - "SELECT * FROM `test_table*` AS sub WHERE `_TABLE_SUFFIX` > '2022-10-20'", - id="wildcard_table_input_with_filter", - ), - ], -) -def test_read_gbq_wildcard(query_or_table, columns, filters, expected_output): - session = resources.create_bigquery_session() - query = session._to_query(query_or_table, columns, filters) - assert query == expected_output From 365d66daebab955506eb42ab60756ca22aed8a80 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 8 May 2024 21:33:31 +0000 Subject: [PATCH 2/2] fix location detection --- bigframes/exceptions.py | 6 ++++++ bigframes/pandas/__init__.py | 3 ++- bigframes/session/__init__.py | 19 +++++++++++++++-- tests/system/small/test_pandas_options.py | 26 ++++++++++++++++------- 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/bigframes/exceptions.py b/bigframes/exceptions.py index 5caf2aa1df..3ca6d8e1af 100644 --- a/bigframes/exceptions.py +++ b/bigframes/exceptions.py @@ -17,6 +17,12 @@ # NOTE: This module should not depend on any others in the package. +# Uses UserWarning for backwards compatibility with warning without a category +# set. +class DefaultLocationWarning(UserWarning): + """No location was specified, so using a default one.""" + + class UnknownLocationWarning(Warning): """The location is set to an unknown value.""" diff --git a/bigframes/pandas/__init__.py b/bigframes/pandas/__init__.py index 2200fd6aa4..1d6da46fae 100644 --- a/bigframes/pandas/__init__.py +++ b/bigframes/pandas/__init__.py @@ -67,6 +67,7 @@ import bigframes.operations as ops import bigframes.series import bigframes.session +import bigframes.session._io.bigquery import bigframes.session.clients @@ -391,7 +392,7 @@ def _set_default_session_location_if_possible(query): bqclient = clients_provider.bqclient - if bigframes.session._is_query(query): + if bigframes.session._io.bigquery.is_query(query): job = bqclient.query(query, bigquery.QueryJobConfig(dry_run=True)) options.bigquery.location = job.location else: diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index 86d103d83e..89845bb842 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -85,6 +85,7 @@ import bigframes.core.tree_properties as tree_properties import bigframes.core.utils as utils import bigframes.dtypes +import bigframes.exceptions import bigframes.formatting_helpers as formatting_helpers from bigframes.functions.remote_function import read_gbq_function as bigframes_rgf from bigframes.functions.remote_function import remote_function as bigframes_rf @@ -184,12 +185,26 @@ def __init__( if context is None: context = bigquery_options.BigQueryOptions() - # TODO(swast): Get location from the environment. if context.location is None: self._location = "US" warnings.warn( f"No explicit location is set, so using location {self._location} for the session.", - stacklevel=2, + # User's code + # -> get_global_session() + # -> connect() + # -> Session() + # + # Note: We could also have: + # User's code + # -> read_gbq() + # -> with_default_session() + # -> get_global_session() + # -> connect() + # -> Session() + # but we currently have no way to disambiguate these + # situations. + stacklevel=4, + category=bigframes.exceptions.DefaultLocationWarning, ) else: self._location = context.location diff --git a/tests/system/small/test_pandas_options.py b/tests/system/small/test_pandas_options.py index afb75c65e3..c580f926c9 100644 --- a/tests/system/small/test_pandas_options.py +++ b/tests/system/small/test_pandas_options.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +import re from unittest import mock import warnings @@ -69,8 +70,12 @@ def test_read_gbq_start_sets_session_location( assert not bpd.options.bigquery.location # Starting user journey with read_gbq* should work for a table in any - # location, in this case tokyo - df = read_method(query_tokyo) + # location, in this case tokyo. + with warnings.catch_warnings(): + # Since the query refers to a specific location, no warning should be + # raised. + warnings.simplefilter("error", bigframes.exceptions.DefaultLocationWarning) + df = read_method(query_tokyo) assert df is not None # Now bigquery options location should be set to tokyo @@ -146,7 +151,11 @@ def test_read_gbq_after_session_start_must_comply_with_default_location( # Starting user journey with anything other than read_gbq*, such as # read_pandas would bind the session to default location US - df = bpd.read_pandas(scalars_pandas_df_index) + with pytest.warns( + bigframes.exceptions.DefaultLocationWarning, + match=re.escape("using location US for the session"), + ): + df = bpd.read_pandas(scalars_pandas_df_index) assert df is not None # Doing read_gbq* from a table in another location should fail @@ -262,17 +271,18 @@ def test_read_gbq_must_comply_with_set_location_non_US( def test_credentials_need_reauthentication(monkeypatch): # Use a simple test query to verify that default session works to interact - # with BQ + # with BQ. test_query = "SELECT 1" - # Confirm that default session has BQ client with valid credentials - session = bpd.get_global_session() - assert session.bqclient._credentials.valid - # Confirm that default session works as usual df = bpd.read_gbq(test_query) assert df is not None + # Call get_global_session() *after* read_gbq so that our location detection + # has a chance to work. + session = bpd.get_global_session() + assert session.bqclient._credentials.valid + with monkeypatch.context() as m: # Simulate expired credentials to trigger the credential refresh flow m.setattr(session.bqclient._credentials, "expiry", datetime.datetime.utcnow())