From ca17b39f497201a8dfad4bf707e3873ef3cde1db Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Mar 2024 13:04:05 -0600 Subject: [PATCH] feat: add predict by visit to samples/snippets/bqml_getting_started_test.py --- samples/snippets/bqml_getting_started_test.py | 100 +++++++++++++++--- 1 file changed, 83 insertions(+), 17 deletions(-) diff --git a/samples/snippets/bqml_getting_started_test.py b/samples/snippets/bqml_getting_started_test.py index 74dd5d4501..3f1a1453ec 100644 --- a/samples/snippets/bqml_getting_started_test.py +++ b/samples/snippets/bqml_getting_started_test.py @@ -14,7 +14,7 @@ def test_bqml_getting_started(random_model_id): - your_model_id = random_model_id + your_model_id = random_model_id # for example: bqml_tutorial.sample_model # [START bigquery_dataframes_bqml_getting_started_tutorial] from bigframes.ml.linear_model import LogisticRegression @@ -29,8 +29,8 @@ def test_bqml_getting_started(random_model_id): df = bpd.read_gbq_table( "bigquery-public-data.google_analytics_sample.ga_sessions_*", filters=[ - ("_table_suffix", ">=", "20170701"), - ("_table_suffix", "<=", "20170801"), + ("_table_suffix", ">=", "20160801"), + ("_table_suffix", "<=", "20170630"), ], ) @@ -68,7 +68,7 @@ def test_bqml_getting_started(random_model_id): features = bpd.DataFrame( { "os": operating_system, - "isMobile": is_mobile, + "is_mobile": is_mobile, "country": country, "pageviews": pageviews, } @@ -96,9 +96,7 @@ def test_bqml_getting_started(random_model_id): your_model_id, # For example: "bqml_tutorial.sample_model", ) - # The WHERE clause — _TABLE_SUFFIX BETWEEN '20170701' AND '20170801' — - # limits the number of tables scanned by the query. The date range scanned is - # July 1, 2017 to August 1, 2017. This is the data you're using to evaluate the predictive performance + # July 1, 2017 to August 1, 2017 is the data you're using to evaluate the predictive performance # of the model. It was collected in the month immediately following the time # period spanned by the training data. @@ -109,6 +107,7 @@ def test_bqml_getting_started(random_model_id): ("_table_suffix", "<=", "20170801"), ], ) + transactions = df["totals"].struct.field("transactions") label = transactions.notnull().map({True: 1, False: 0}) operating_system = df["device"].struct.field("operatingSystem") @@ -119,7 +118,7 @@ def test_bqml_getting_started(random_model_id): features = bpd.DataFrame( { "os": operating_system, - "isMobile": is_mobile, + "is_mobile": is_mobile, "country": country, "pageviews": pageviews, } @@ -155,7 +154,14 @@ def test_bqml_getting_started(random_model_id): # [1 rows x 6 columns] # [END bigquery_dataframes_bqml_getting_started_tutorial_evaluate] - # [START bigquery_dataframes_bqml_getting_started_tutorial_predict] + # [START bigquery_dataframes_bqml_getting_started_tutorial_predict_by_country] + import bigframes.pandas as bpd + + # Select model you'll use for training. `read_gbq_model` loads model data from a + # BigQuery, but you could also use the `model` object from the previous steps. + model = bpd.read_gbq_model( + your_model_id, # For example: "bqml_tutorial.sample_model", + ) df = bpd.read_gbq_table( "bigquery-public-data.google_analytics_sample.ga_sessions_*", filters=[ @@ -172,24 +178,84 @@ def test_bqml_getting_started(random_model_id): features = bpd.DataFrame( { "os": operating_system, - "isMobile": is_mobile, + "is_mobile": is_mobile, "country": country, "pageviews": pageviews, } ) # Use Logistic Regression predict method to, find more information here in - # [BigFrames](/bigframes/latest/bigframes.ml.linear_model.LogisticRegression#bigframes_ml_linear_model_LogisticRegression_predict) + # [BigFrames](https://cloud.google.com/python/docs/reference/bigframes/latest/bigframes.ml.linear_model.LogisticRegression#bigframes_ml_linear_model_LogisticRegression_predict) + predictions = model.predict(features) - countries = predictions.groupby(["country"])[["predicted_transactions"]].sum() + total_predicted_purchases = predictions.groupby(["country"])[ + ["predicted_label"] + ].sum() + total_predicted_purchases.sort_values(ascending=False).head(10) - countries.sort_values(ascending=False).head(10) + # country # total_predicted_purchases + # United States 220 + # Taiwan 8 + # Canada 7 + # India 2 + # Japan 2 + # Turkey 2 + # Australia 1 + # Brazil 1 + # Germany 1 + # Guyana 1 + # Name: predicted_label, dtype: Int64 - predictions = model.predict(features) + # [END bigquery_dataframes_bqml_getting_started_tutorial_predict_by_country] - total_predicted_purchases = predictions.groupby(["country"])[ - ["predicted_transactions"] + # [START bigquery_dataframes_bqml_getting_started_tutorial_predict_by_visitor] + + model = bpd.read_gbq_model( + your_model_id, # For example: "bqml_tutorial.sample_model", + ) + df = bpd.read_gbq_table( + "bigquery-public-data.google_analytics_sample.ga_sessions_*", + filters=[ + ("_table_suffix", ">=", "20170701"), + ("_table_suffix", "<=", "20170801"), + ], + ) + + operating_system = df["device"].struct.field("operatingSystem") + operating_system = operating_system.fillna("") + is_mobile = df["device"].struct.field("isMobile") + country = df["geoNetwork"].struct.field("country").fillna("") + pageviews = df["totals"].struct.field("pageviews").fillna(0) + full_visitor_id = df["fullVisitorId"] + + features = bpd.DataFrame( + { + "os": operating_system, + "is_mobile": is_mobile, + "country": country, + "pageviews": pageviews, + "fullVisitorId": full_visitor_id, + } + ) + + predictions = model.predict(features) + total_predicted_purchases = predictions.groupby(["fullVisitorId"])[ + ["predicted_label"] ].sum() total_predicted_purchases.sort_values(ascending=False).head(10) - # [END bigquery_dataframes_bqml_getting_started_tutorial_predict] + # fullVisitorId # total_predicted_purchases + # 9417857471295131045 4 + # 0376394056092189113 2 + # 0456807427403774085 2 + # 057693500927581077 2 + # 112288330928895942 2 + # 1280993661204347450 2 + # 2105122376016897629 2 + # 2158257269735455737 2 + # 2969418676126258798 2 + # 489038402765684003 2 + # Name: predicted_label, dtype: Int64 + + +# [END bigquery_dataframes_bqml_getting_started_tutorial_predict_by_visitor]