Before going into the list of PEFT methods, let’s explain how we traditionally adapted our pretrained models for a downstream task. The process involves taking a pre-trained model, which has already learned useful features from a large-scale dataset, and further training it on a specific task using a smaller, task-specific dataset.
The process of traditional fine-tuning is as follows:
In recent years, large language models have revolutionized the field of natural language processing (NLP). These models, such as GPT-3 and BLOOM, can generate human-like text and capture the context of language. However, fine-tuning them the traditional way, which updates every parameter, is computationally expensive and requires a lot of memory because of the sheer number of parameters involved. To tackle this problem, the research community has seen an upsurge of work on efficient ways to fine-tune pretrained models for downstream tasks. These methods, both those already found and those yet to be found, are collectively called “Parameter-Efficient Fine-Tuning” methods, or PEFT for short.
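To put “a lot of memory” into numbers, here is a back-of-the-envelope sketch. It assumes a common rule of thumb (an assumption, not a figure from this post) of about 16 bytes per trainable parameter for fine-tuning with Adam in mixed precision, ignores activations, and uses a hypothetical 0.1% trainable fraction to stand in for a PEFT-style setup in which the pretrained weights stay frozen.

```python
# Back-of-the-envelope memory estimate; activations and framework overhead are ignored.
# Assumption: mixed-precision Adam needs ~16 bytes per *trainable* parameter:
#   2 (fp16 weight) + 2 (fp16 gradient) + 4 (fp32 master weight) + 4 (Adam m) + 4 (Adam v)
BYTES_PER_TRAINABLE_PARAM = 2 + 2 + 4 + 4 + 4
# Assumption: frozen parameters only need their fp16 weights (2 bytes) in memory.
BYTES_PER_FROZEN_PARAM = 2


def finetune_memory_gb(total_params: float, trainable_params: float) -> float:
    """Approximate memory in GB (1 GB = 1e9 bytes) for weights, gradients and optimizer state."""
    frozen_params = total_params - trainable_params
    total_bytes = (trainable_params * BYTES_PER_TRAINABLE_PARAM
                   + frozen_params * BYTES_PER_FROZEN_PARAM)
    return total_bytes / 1e9


GPT3_PARAMS = 175e9  # GPT-3 has about 175 billion parameters

full = finetune_memory_gb(GPT3_PARAMS, trainable_params=GPT3_PARAMS)
# Hypothetical PEFT-style setup: freeze everything except 0.1% of the parameters.
peft = finetune_memory_gb(GPT3_PARAMS, trainable_params=GPT3_PARAMS * 0.001)
print(f"Full fine-tuning: ~{full:,.0f} GB")  # ~2,800 GB
print(f"0.1% trainable:   ~{peft:,.0f} GB")  # ~352 GB
```

Under these assumptions, almost all of the saving comes from not keeping gradients and optimizer state for the frozen weights.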
Listing every PEFT method available online would take forever. Here, I list some popular ones and give a quick explanation of each. Most of them are taken from the survey paper by Lialin et al., 2023, but I’ve also included some that the paper does not list or that were published after it. The survey categorizes PEFT methods by the conceptual framework underlying each approach.
In conclusion, the pursuit of parameter efficiency in fine-tuning LLMs is a critical aspect of contemporary machine learning research. By successfully leveraging the existing parameters of a pre-trained model and minimizing the addition or modification of new parameters, PEFT methods offer a promising solution to the computational and memory challenges associated with fine-tuning large models. As we continue to push the boundaries of what’s possible with machine learning and artificial intelligence, these methods will undoubtedly play a pivotal role in shaping the future of the field.
import numpy as np
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
np.inner(x, y)
32
np.dot(x, y)
32
Both np.inner and np.dot give the same result. When popular libraries evaluate both functions to the same value, it is tempting to think of them as the same. Moreover, many books and experts, particularly in machine learning, use the terms “inner product” and “dot product” interchangeably, assuming they refer to the same mathematical operation. I used them the same way until I read the book “Mathematics for Machine Learning” by M. P. Deisenroth, A. A. Faisal, and C. S. Ong. In this blog post, I aim to clear up this misconception and shed light on the differences between these fundamental products in linear algebra.
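A quick aside: the agreement above holds for 1-D inputs. For 2-D arrays, np.dot computes the ordinary matrix product, while np.inner sums over the last axis of both arguments (that is, A @ B.T), so even the two library functions are not interchangeable. The matrices below are arbitrary examples.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# np.dot on 2-D arrays is the ordinary matrix product A @ B.
print(np.dot(A, B))    # [[19 22]
                       #  [43 50]]
# np.inner contracts the *last* axis of both arrays, i.e. A @ B.T.
print(np.inner(A, B))  # [[17 23]
                       #  [39 53]]
```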
Let’s start with the definition of the dot product, often loosely called the “inner product”, that most of us know. The dot product of vectors \(\mathbf{x}, \mathbf{y} \in \mathbb{R}^{n}\) is defined as:
\[\mathbf{x}\cdotp\mathbf{y} = \mathbf{x}^{\top}\mathbf{y}=\sum_{i=1}^{n}x_iy_i\]where \(x_i\) and \(y_i\) are the \(i\)th elements of vectors \(\mathbf{x}\) and \(\mathbf{y}\), respectively. This is the dot product we use to calculate the length of a vector and the angle or similarity between two vectors. The NumPy functions np.dot and np.inner performed exactly this operation in the example above. However, the inner product is much more than this definition of the dot product. Spoiler alert: the dot product is a specific instance of an inner product. We’ll see how shortly.
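As a small illustration of those uses, the sketch below computes a vector’s length as \(\sqrt{\mathbf{x}\cdotp\mathbf{x}}\) and the angle between two vectors from \(\cos\theta = \frac{\mathbf{x}\cdotp\mathbf{y}}{\|\mathbf{x}\|\|\mathbf{y}\|}\). The helper names length and angle are just for illustration.

```python
import numpy as np

def length(x):
    """Euclidean length of x, computed from the dot product of x with itself."""
    return np.sqrt(np.dot(x, x))

def angle(x, y):
    """Angle in radians between x and y, from cos(theta) = x.y / (|x||y|)."""
    cos_theta = np.dot(x, y) / (length(x) * length(y))
    # Clip to guard against tiny floating-point overshoots outside [-1, 1].
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))

print(length(np.array([3.0, 4.0])))                                    # 5.0
print(np.degrees(angle(np.array([1.0, 0.0]), np.array([1.0, 1.0]))))  # ~45.0
```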
Please bear with me while I give you the formal (mathematical) definition of the inner product using fundamental concepts of linear algebra. Let’s start with the linear mapping:
Consider two vector spaces \(V\) and \(W\). A mapping \(\Phi:V\rightarrow W\) is called a linear mapping if
\[\forall \mathbf{x}, \mathbf{y} \in V \forall\lambda,\varphi \in \mathbb{R}: \Phi(\lambda\mathbf{x} + \varphi\mathbf{y})=\lambda\Phi(\mathbf{x})+\varphi\Phi(\mathbf{y})\]In simple, intuitive terms, a linear mapping is a transformation between two vector spaces that preserves the origin, collinearity, and parallelism.
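The definition can be probed numerically. This sketch is a spot check on random inputs, not a proof: it tests the linearity condition for a matrix map \(\Phi(\mathbf{x})=\mathbf{M}\mathbf{x}\), which passes, and for a translation \(\Phi(\mathbf{x})=\mathbf{x}+\mathbf{1}\), which fails because it moves the origin. The matrix M and the helper is_linear are illustrative choices.

```python
import numpy as np

def is_linear(phi, dim, trials=100, seed=0):
    """Spot-check phi(lam*x + vphi*y) == lam*phi(x) + vphi*phi(y) on random inputs."""
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        x, y = rng.normal(size=dim), rng.normal(size=dim)
        lam, vphi = rng.normal(size=2)
        if not np.allclose(phi(lam * x + vphi * y), lam * phi(x) + vphi * phi(y)):
            return False
    return True

M = np.array([[2.0, -1.0], [0.5, 3.0]])
print(is_linear(lambda v: M @ v, dim=2))    # True: matrix maps are linear
print(is_linear(lambda v: v + 1.0, dim=2))  # False: a translation does not fix the origin
```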
Let’s expand the linear mapping to two arguments function called bilinear mapping:
Consider two vector spaces \(V\) and \(W\). A mapping with two arguments \(\Omega: V\times V\rightarrow W\) is called a bilinear mapping if \(\forall \mathbf{x}, \mathbf{y}, \mathbf{z} \in V \forall\lambda,\varphi \in \mathbb{R}:\)
\[\Omega(\lambda\mathbf{x}+\varphi\mathbf{y}, \mathbf{z})=\lambda\Omega(\mathbf{x}, \mathbf{z})+\varphi\Omega(\mathbf{y}, \mathbf{z})\] \[\Omega(\mathbf{x}, \lambda\mathbf{y}+\varphi\mathbf{z})=\lambda\Omega(\mathbf{x}, \mathbf{y})+\varphi\Omega(\mathbf{x}, \mathbf{z})\]With these two definitions in hand, we can define the inner product formally as:
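Consider a vector space \(V\). A bilinear mapping \(\Omega:V\times V\rightarrow \mathbb{R}\) is called an inner product on \(V\) if it satisfies the following two conditions:
Symmetric: \(\forall \mathbf{x}, \mathbf{y} \in V: \Omega(\mathbf{x}, \mathbf{y})=\Omega(\mathbf{y}, \mathbf{x})\)
Positive definite: \(\forall \mathbf{x} \in V\setminus\{\mathbf{0}\}: \Omega(\mathbf{x}, \mathbf{x})>0\) and \(\Omega(\mathbf{0}, \mathbf{0})=0\)
From the definition, we can simply say that an inner product is any function that takes two vectors as arguments, outputs a real number, is linear in each argument, is symmetric (we can interchange the arguments without changing the result), and evaluates to a positive real number when both arguments are the same nonzero vector.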
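To see why bilinearity alone is not enough, the sketch below checks both bilinearity conditions for \(\Omega(\mathbf{x}, \mathbf{y})=\mathbf{x}^\top\mathbf{A}\mathbf{y}\) with a deliberately non-symmetric matrix \(\mathbf{A}\) (an illustrative choice). Both conditions hold, yet swapping the arguments changes the result, so this \(\Omega\) is a bilinear mapping but not an inner product.

```python
import numpy as np

def omega(x, y, A):
    """Bilinear form x^T A y."""
    return x @ A @ y

A = np.array([[1.0, 2.0],
              [0.0, 1.0]])  # deliberately not symmetric
x, y, z = np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])
lam, vphi = 2.0, -3.0

# Linear in the first argument
print(np.isclose(omega(lam * x + vphi * y, z, A), lam * omega(x, z, A) + vphi * omega(y, z, A)))  # True
# Linear in the second argument
print(np.isclose(omega(x, lam * y + vphi * z, A), lam * omega(x, y, A) + vphi * omega(x, z, A)))  # True
# But not symmetric: x^T A y = 2.0 while y^T A x = 0.0
print(omega(x, y, A), omega(y, x, A))
```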
To find the relationship between dot product and inner product, let’s expand the definition of inner product in terms of basis of the vector space and coordinate representation of vectors.
Let \(B=(\mathbf{b}_1, \dots, \mathbf{b}_n)\) be a basis of the \(n\)-dimensional vector space \(V\), and consider two vectors \(\mathbf{x}, \mathbf{y}\in V\) with coordinate vectors \(\hat{\mathbf{x}}, \hat{\mathbf{y}}\), respectively, with respect to the basis \(B\).
\[\mathbf{x}=\sum_{i=1}^{n}\varphi_i\mathbf{b}_i, \space \mathbf{y}=\sum_{j=1}^{n}\lambda_j\mathbf{b}_j\] \[\hat{\mathbf{x}}=\begin{bmatrix}\varphi_1\\\vdots\\\varphi_n\end{bmatrix}, \space \hat{\mathbf{y}}=\begin{bmatrix}\lambda_1\\\vdots\\\lambda_n\end{bmatrix}\]where \(\varphi_i\) and \(\lambda_j\) are the \(i\)th and \(j\)th elements of the coordinate vectors \(\hat{\mathbf{x}}\) and \(\hat{\mathbf{y}}\), respectively.
Now, we can write the inner product \(\Omega\), typically represented as \(\left \langle \cdot, \cdot \right \rangle\), as the following:
\[\left \langle \mathbf{x}, \mathbf{y} \right \rangle=\left \langle \sum_{i=1}^{n}\varphi_i\mathbf{b}_i, \sum_{j=1}^{n}\lambda_j\mathbf{b}_j \right \rangle\]Using the property of bilinearity defined above:
\[\left \langle \mathbf{x}, \mathbf{y} \right \rangle=\sum_{i=1}^{n}\sum_{j=1}^{n}\varphi_i\left \langle \mathbf{b}_i, \mathbf{b}_j \right \rangle\lambda_j=\hat{\mathbf{x}}^\top\mathbf{A}\hat{\mathbf{y}}\]where \(A_{ij}=\left \langle \mathbf{b}_i, \mathbf{b}_j \right \rangle\), i.e., the inner product of the basis vectors \(\mathbf{b}_i\) and \(\mathbf{b}_j\).
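From here, we can see that an inner product on \(V\) is fully determined by the square matrix \(\mathbf{A}\in\mathbb{R}^{n\times n}\) of inner products between basis vectors. This matrix is symmetric, because \(A_{ij}=\left \langle \mathbf{b}_i, \mathbf{b}_j \right \rangle=\left \langle \mathbf{b}_j, \mathbf{b}_i \right \rangle=A_{ji}\), and positive definite, because \(\hat{\mathbf{x}}^\top\mathbf{A}\hat{\mathbf{x}}=\left \langle \mathbf{x}, \mathbf{x} \right \rangle>0\) for every \(\mathbf{x}\neq\mathbf{0}\). Conversely, any symmetric, positive definite matrix \(\mathbf{A}\) defines an inner product via \(\left \langle \mathbf{x}, \mathbf{y} \right \rangle=\hat{\mathbf{x}}^\top\mathbf{A}\hat{\mathbf{y}}\).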
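A quick way to test whether a candidate matrix can govern an inner product is to check that it is symmetric and that all of its eigenvalues are strictly positive, which for a symmetric matrix is equivalent to positive definiteness. The helper is_spd below is a sketch with an arbitrary tolerance; the first test matrix is the one used in the NumPy example later in this post.

```python
import numpy as np

def is_spd(A, tol=1e-10):
    """Return True if A is square, symmetric and has only strictly positive eigenvalues."""
    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.shape[0] != A.shape[1]:
        return False
    if not np.allclose(A, A.T):
        return False
    # eigvalsh assumes a symmetric matrix, which we checked above.
    return bool(np.all(np.linalg.eigvalsh(A) > tol))

print(is_spd([[9, 6], [6, 5]]))  # True: eigenvalues are about 13.32 and 0.68
print(is_spd([[1, 2], [2, 1]]))  # False: eigenvalues are 3 and -1
print(is_spd([[1, 2], [0, 1]]))  # False: not symmetric
```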
Now, we can finally relate the dot product to the inner product. When the symmetric, positive definite matrix that governs the inner product is the identity matrix, the inner product is the dot product. Formally, the dot product on an \(n\)-dimensional vector space \(V\) is the bilinear mapping \(\Omega: V\times V\rightarrow \mathbb{R}\) where:
\[\forall \mathbf{x},\mathbf{y}\in V: \Omega(\mathbf{x},\mathbf{y})=\mathbf{x}^\top\mathbf{I}_n\mathbf{y}=\mathbf{x}^\top\mathbf{y}\]where \(\mathbf{I}_n\) is the \(n\times n\) identity matrix.
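The role of \(\mathbf{A}\) is easy to see numerically. As an illustrative assumption, take \(V=\mathbb{R}^2\) with the ordinary dot product as its inner product, but describe vectors in the non-orthonormal basis \(\mathbf{b}_1=[1, 0]^\top\), \(\mathbf{b}_2=[1, 1]^\top\). Then \(\hat{\mathbf{x}}^\top\mathbf{A}\hat{\mathbf{y}}\) with \(A_{ij}=\left \langle \mathbf{b}_i, \mathbf{b}_j \right \rangle\) reproduces the inner product of the actual vectors, while the plain dot product of the coordinate vectors does not, because here \(\mathbf{A}\neq\mathbf{I}\).

```python
import numpy as np

# Rows are the basis vectors b_1 = [1, 0] and b_2 = [1, 1] (not orthonormal).
B = np.array([[1.0, 0.0],
              [1.0, 1.0]])
# A_ij = <b_i, b_j>, using the ordinary dot product as the inner product on R^2.
A = B @ B.T                      # [[1, 1], [1, 2]]

x_hat = np.array([1.0, 2.0])     # coordinates of x with respect to B
y_hat = np.array([3.0, -1.0])    # coordinates of y with respect to B
x = B.T @ x_hat                  # x = 1*b_1 + 2*b_2 = [3, 2]
y = B.T @ y_hat                  # y = 3*b_1 - 1*b_2 = [2, -1]

print(x @ y)                     # 4.0: inner product of the actual vectors
print(x_hat @ A @ y_hat)         # 4.0: same value from the coordinates and A
print(x_hat @ y_hat)             # 1.0: plain dot product of the coordinates differs
```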
Let’s implement a full-fledged inner product using NumPy:
def inner_product(x, y, A):
    """Returns inner product of x and y governed by matrix A
    Arguments:
    x = 1-D array of shape (n,)
    y = 1-D array of shape (n,)
    A = 2-D array of shape (n, n)
    """
    x_vec, y_vec = x[:, np.newaxis], y[:, np.newaxis]
    xTA = np.matmul(x_vec.T, A)
    in_prod = np.matmul(xTA, y_vec)
    return in_prod
Define example vectors and matrices:
x = np.array([1, 2])
y = np.array([3, 4])
# An example symmetric positive definite matrix
A = np.array(
[[9, 6],
[6, 5]]
)
# An identity matrix
I = np.array(
[[1, 0],
[0, 1]]
)
Calculate inner product governed by A
inner_product(x, y, A)
array([[127]])
Calculate inner product governed by I. We’ll get the dot product.
inner_product(x, y, I)
array([[11]])
Let’s verify with NumPy’s built-in dot product.
np.dot(x, y)
11
You’ve come to the end of the blog post. This post has illuminated the distinction and relationship between the inner product and the dot product in linear algebra. While most of us use them interchangeably, it’s crucial to understand that the inner product is a broader concept encompassing a bilinear mapping governed by a symmetric, positive definite matrix. On the other hand, the dot product is a specific instance of the inner product when the governing matrix is the identity matrix.
Before we get started, let’s set up our environment. We’ll be using Python along with some essential libraries.
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
The first block of code imports the necessary libraries. numpy and pandas are widely used for numerical computations and data manipulation, while load_iris is a convenient function for loading the Iris dataset.
Next, for demonstration purposes, we load the Iris dataset and convert it into a pandas DataFrame.
iris = load_iris()
df = pd.DataFrame(iris["data"], columns=iris["feature_names"])
df["target"] = iris["target"]
df.head()
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
Here, we’ve converted the Iris data into a structured DataFrame, making it easier to work with. This DataFrame includes features like sepal length, sepal width, petal length, petal width, and a target variable.
Now, let’s upload our DataFrame to AWS S3.
import boto3
# You can use your custom bucket and prefixes
bucket, key = "sagemaker-lal-stage", "iris.csv"
s3_client = boto3.client("s3")
s3_client.put_object(
Bucket=bucket,
Key=key,
Body=df.to_csv(index=False),
ContentType="text/csv",
)
By executing this code, we place the data in an S3 bucket, where it is accessible for further processing. You can upload data of any size to S3 using any method; just note the bucket name and the key prefix of the data files.
The next step involves connecting to Snowflake using Snowpark, a powerful tool for data processing in Snowflake.
from snowflake.snowpark import Session
connection_parameters = {
"account": "####SF_ACCOUNT_NAME######",
"user": "#####USER#######",
"password": "#####PASSWORD#######",
"warehouse": "#####WAREHOUSE########",
"database": "#####DATABASE########",
"schema": "#####SCHEMA(PUBLIC)########",
}
session = Session.builder.configs(connection_parameters).create()
We establish a connection by creating a session with the specified parameters. This session lets us interact with the Snowflake database.
Now, let’s create a temporary table in Snowflake where we’ll load our data from S3.
table_name = "iris_dataset"
session.sql(f"""create temporary table {table_name} (
SEPAL_LENGTH float,
SEPAL_WIDTH float,
PETAL_LENGTH float,
PETAL_WIDTH float,
TARGET integer
)"""
).collect()
This code creates a temporary table in Snowflake with the same structure as our DataFrame; this is where we’ll load our data. Temporary tables are dropped automatically when the session ends. In my experience, temporary tables are noticeably faster to work with than the standard or transient tables available in Snowflake, so they are preferable when the data does not need to persist in Snowflake.
With our temporary table in place, let’s copy the data from S3 into Snowflake.
session.sql(f"""copy into {table_name}
from 's3://{bucket}/{key}'
credentials=( AWS_KEY_ID='#######AWS_KEY_ID#######' AWS_SECRET_KEY='#######AWS_SECRET_KEY#######')
file_format=(TYPE=CSV COMPRESSION=NONE SKIP_HEADER=1)
"""
).collect()
This command copies the data from S3 into our Snowflake table using the specified credentials (which have permission to get, put, and delete objects in AWS S3). COPY INTO processes all files in bucket {bucket} whose path starts with the prefix {key}. Since we uploaded the data with a header row, we tell Snowflake to skip it with SKIP_HEADER=1. The uploaded data is also uncompressed (COMPRESSION=NONE); if your data is compressed, specify its compression type instead.
Now that our data is in Snowflake, we can perform any operation we like. We’ll be using Snowpark’s DataFrame capabilities for this.
sdf = session.table(table_name)
This line creates a Snowpark DataFrame from our Snowflake table.
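We can display the table by converting the Snowpark DataFrame to a pandas DataFrame. Note that to_pandas() pulls all rows into local memory, so it is best kept for small tables like this one.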
sdf.to_pandas()
| | SEPAL_LENGTH | SEPAL_WIDTH | PETAL_LENGTH | PETAL_WIDTH | TARGET |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
| ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
For a demonstration, let’s compute the Pearson correlation between sepal length and sepal width. For reference, the formula for the Pearson correlation coefficient is: \(r=\frac{\sum ( x_{i} -\overline{x})( y_{i} -\overline{y})}{\sqrt{\sum ( x_{i} -\overline{x})^{2}\sum ( y_{i} -\overline{y})^{2}}}\)
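Before moving the computation into Snowflake, it can help to see the formula on its own. This NumPy sketch implements it term by term (subtract the means, multiply the deviations, normalize) on the first five rows of the Iris table shown earlier, and cross-checks the result against np.corrcoef. The function name pearson is illustrative, and a five-row value will of course differ from the full 150-row result.

```python
import numpy as np

def pearson(x, y):
    """Pearson correlation, written term by term like the formula above."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    x_diff = x - x.mean()          # x_i - x_bar
    y_diff = y - y.mean()          # y_i - y_bar
    numerator = np.sum(x_diff * y_diff)
    denominator = np.sqrt(np.sum(x_diff ** 2) * np.sum(y_diff ** 2))
    return numerator / denominator

# First five rows of sepal length and sepal width from the Iris table above
sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0]
sepal_width = [3.5, 3.0, 3.2, 3.1, 3.6]

r = pearson(sepal_length, sepal_width)
print(r, np.corrcoef(sepal_length, sepal_width)[0, 1])  # the two values agree
```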
from snowflake.snowpark import DataFrame, Column
from snowflake.snowpark import functions as spf
def pair_correlation(df: DataFrame, x: str, y: str) -> DataFrame:
    # broadcast mean using `over`
    x_diff = df[x] - spf.mean(df[x]).over()
    y_diff = df[y] - spf.mean(df[y]).over()
    # Store results in columns
    df = df.with_columns(
        ["x_diff", "y_diff"],
        [x_diff, y_diff],
    )
    numerator = spf.sum(df["x_diff"]*df["y_diff"])
    denominator = spf.sqrt(spf.sum(spf.pow(df["x_diff"], 2))*spf.sum(spf.pow(df["y_diff"], 2)))
    # prepare dataframe with pair values in first two columns and correlation value in last column
    return df.select(
        spf.lit(x).alias("FEAT1"),
        spf.lit(y).alias("FEAT2"),
        (numerator/denominator).alias("VALUE"),
    )
Here, we define a function pair_correlation that accepts a Snowpark DataFrame and the names of the two columns whose correlation we want, and returns a new DataFrame with the result. The means are broadcast to every row with a window (over()) and the sums are aggregates, so the whole computation is translated to SQL and runs inside Snowflake rather than on the client.
pair_correlation(sdf, "SEPAL_LENGTH", "SEPAL_WIDTH").to_pandas()
| | FEAT1 | FEAT2 | VALUE |
|---|---|---|---|
| 0 | SEPAL_LENGTH | SEPAL_WIDTH | -0.11757 |
You can continue with other analysis as you require.
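If you want every pair of features rather than just one, and want to sanity-check the Snowflake numbers on data small enough to fit in memory, a local pandas sketch of the same computation could look like this. It is a cross-check, not a replacement for the Snowpark version, which keeps the work inside Snowflake; all_pair_correlations is a hypothetical helper that returns the same FEAT1/FEAT2/VALUE layout.

```python
from itertools import combinations

import pandas as pd

def all_pair_correlations(df: pd.DataFrame, columns) -> pd.DataFrame:
    """Pearson correlation for every unordered pair of columns, in FEAT1/FEAT2/VALUE layout."""
    rows = []
    for x, y in combinations(columns, 2):
        x_diff = df[x] - df[x].mean()
        y_diff = df[y] - df[y].mean()
        value = (x_diff * y_diff).sum() / ((x_diff ** 2).sum() * (y_diff ** 2).sum()) ** 0.5
        rows.append({"FEAT1": x, "FEAT2": y, "VALUE": value})
    return pd.DataFrame(rows, columns=["FEAT1", "FEAT2", "VALUE"])

# Usage on a local copy, e.g. local_df = sdf.to_pandas():
# all_pair_correlations(local_df, ["SEPAL_LENGTH", "SEPAL_WIDTH", "PETAL_LENGTH", "PETAL_WIDTH"])
```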
With the analysis complete, we close our Snowflake session.
session.close()
Additionally, we clean up our S3 bucket by deleting the uploaded data.
s3_client.delete_object(
Bucket=bucket,
Key=key,
)
And there you have it! We’ve walked through loading data into Snowflake from AWS S3, computing a correlation with Snowpark, and finally cleaning up our environment.
This powerful combination of Snowpark and Snowflake opens up a world of possibilities for handling large tabular datasets. Whether you’re a data scientist or a machine learning engineer, having these tools in your arsenal can significantly enhance your data processing capabilities. Experiment with your own datasets and unlock valuable insights!
Happy coding!
In the ever-growing realm of machine learning, managing complex workflows efficiently is crucial. One such project I recently worked on is a comprehensive, scalable SageMaker pipeline for training, deploying, and monitoring multiple models, each trained on its own dataset, behind a single endpoint. Built on the Amazon SageMaker platform, this project offers an end-to-end solution for defining and managing machine learning workflows, covering everything from data preprocessing, model training, and tuning to the final deployment phase.
The beauty of this pipeline lies in its scalability and cost-effectiveness. It employs multi-model endpoints, which use a shared serving container and the same fleet of resources to host all models. This approach significantly cuts down hosting costs by improving endpoint utilization compared to using single-model endpoints. The pipeline is meticulously designed to tune, train, and deploy multiple models on individual datasets, providing users with the flexibility to work with the data and infrastructures they prefer.
Another aspect of this project is that all the resources involved in implementing the pipeline are orchestrated with a Terraform script. In the upcoming sections of this blog post, we will delve into the implementation details of this project, exploring its architecture, code, and resources. Buckle up to gain a deeper understanding of this fascinating SageMaker pipeline project!
Now that we have a high-level understanding of the project, let’s dive deeper into the specific architecture of the pipeline. The detailed architecture diagram, presented next, provides a visual representation of the entire system, showcasing the relationships between different components and steps involved in the pipeline. This diagram will help us comprehend the structure of the system and the sequential order of events in the pipeline.

The pipeline of this project is divided into two sections: the Modeling Pipeline and the Deployment Pipeline. Let’s delve deeper into the steps involved in each of these sections:
The Modeling Pipeline is the initial phase where the model is trained and prepared for deployment. Here’s what happens in each step:
Preprocessing: The raw data is cleaned and transformed into a suitable format for model training. This step is crucial as quality data is a prerequisite for training an effective model. It involves handling missing values, removing outliers, encoding categorical variables, and so on.
Hyperparameter Tuning: This step involves searching for the best hyperparameters for the model. Hyperparameters are the configuration variables that govern the training process of a model. For instance, the learning rate in a neural network is a hyperparameter. A hyperparameter tuning algorithm, like grid search or random search, is used to explore different combinations of hyperparameters to find the optimal set that minimizes the loss function.
Refit Best Model: After the best hyperparameters are found, the model is trained again using these hyperparameters. This step ensures that the model is the best version of itself before it is evaluated and potentially deployed.
Evaluate Best Model: The performance of the best model is evaluated in this step. This is done using a holdout validation set that the model has never seen before. Evaluation metrics like accuracy, precision, recall, or AUC-ROC (for classification tasks), or MSE, MAE, R2 score (for regression tasks) are computed.
Registration Metric Check: The model’s performance metrics are checked against a predefined threshold or previous models’ performance to decide whether to register the model in the registry. This step ensures that only models that meet the quality standards are registered for deployment.
Model Package Registration Step: If the model passes the registration metric check, it is registered to the SageMaker Model Registry. This registry serves as a repository where trained models are stored before they are deployed.
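The Hyperparameter Tuning step above can be illustrated with a minimal random search, as a rough sketch of the idea rather than what SageMaker runs internally: sample hyperparameter combinations from a search space, evaluate a loss for each, and keep the combination with the lowest loss. The toy loss function and the search space are hypothetical stand-ins for a real validation loss; a grid search would instead loop over every combination of fixed value lists, for example with itertools.product.

```python
import random

def random_search(loss_fn, space, n_trials=50, seed=0):
    """Sample n_trials configurations from `space` and return the one with the lowest loss."""
    rng = random.Random(seed)
    best_params, best_loss = None, float("inf")
    for _ in range(n_trials):
        # Each entry of `space` is a function that draws one value for that hyperparameter.
        params = {name: sample(rng) for name, sample in space.items()}
        loss = loss_fn(params)
        if loss < best_loss:
            best_params, best_loss = params, loss
    return best_params, best_loss

# Hypothetical search space: log-uniform learning rate and integer tree depth.
space = {
    "learning_rate": lambda rng: 10 ** rng.uniform(-4, 0),
    "max_depth": lambda rng: rng.randint(2, 10),
}

def toy_loss(p):
    """Toy stand-in for a validation loss, smallest near learning_rate=0.1, max_depth=6."""
    return (p["learning_rate"] - 0.1) ** 2 + 0.01 * (p["max_depth"] - 6) ** 2

best, loss = random_search(toy_loss, space)
print(best, loss)
```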
The Deployment Pipeline is the second phase where the registered models are deployed for serving predictions.
The pipeline listens to approval/rejection events in the SageMaker Model Registry via AWS EventBridge. An approval event triggers the deployment of the approved model.
The approved models are deployed to an endpoint’s multi-model artifact location in S3 using a Lambda function. AWS Lambda is a serverless compute service that lets you run your code without provisioning or managing servers.
Another AWS Lambda function with a Function URL is used to interact with the SageMaker endpoint. This function can be used to send data to the endpoint for inference and receive predictions.
The scalability of the SageMaker endpoint is managed by AWS Application Auto Scaling. This service can automatically adjust the capacity of the endpoint to maintain steady, predictable performance at the lowest possible cost.
Overall, these steps ensure a streamlined process from data preprocessing to model deployment, providing an efficient and scalable solution for machine learning workflows.
Now that we’ve understood the architecture of the pipeline, it’s time to delve into the implementation details. In the following section, we will walk through implementing the pipeline to train separate models for two different datasets: the Breast Cancer dataset and the Bank Note Authentication dataset. The pipeline runs a hyperparameter search across multiple scikit-learn, XGBoost, and LightGBM classifiers and deploys the best model found for each dataset to a single scalable multi-model SageMaker endpoint. As you go through the post, you’ll see how to adjust the scripts to include more datasets and custom preprocessing and training steps. Our implementation keeps every step of the pipeline, from beginning to end, scalable.
In this post, I’ll describe the main code snippets and components of the implementation. The complete implementation can be found in the following repo; this post is meant as a complement to that GitHub repo, so please read its README and USAGE_GUIDELINE as well.
Docker images are the backbone of SageMaker: whatever processing we do in each step of the pipeline runs inside a container. In this section, we will discuss how to extend a prebuilt SageMaker SKLearn image. The main purpose is to add libraries that are not included in the prebuilt image, in this case LightGBM and XGBoost. You can also build a custom Docker image from scratch if you prefer.
The Dockerfile for extending the SageMaker SKLearn Image can be written as follows:
From this Dockerfile, we then build a Docker image and push it to Amazon ECR using the following script. The script requires the AWS account ID, AWS region, and image name to be passed to it when it runs. Don’t worry about running this script for now; we’ll run it using Terraform.
Now that we have a Docker container on which we can run our modeling jobs, let’s create the scripts for those jobs, beginning with a Python script that defines the pipeline with the SageMaker Python SDK.
We’ll create a SageMaker pipeline definition file called pipeline.py, which defines all the steps of the Modeling Pipeline. For now, let’s import the necessary dependencies and create the arguments and SageMaker sessions the pipeline needs. Further down, we’ll discuss the implementation of each step.
We’ll prepare a Python script for downloading, splitting, and saving two specific datasets: the Breast Cancer and the Banknote Authentication datasets. The main block of the script uses argparse to parse command-line arguments for selecting the dataset and specifying the test size, random state, stratification, target column name, and file name. It then calls the appropriate function to get the requested dataset.
Finally, the script creates directories for the training and testing datasets and saves them as CSV files. You can modify the script to include other datasets and preprocessing steps; just make sure the script processes only the dataset named by the dataset argument.
We’ll call this script from the preprocessing step of the pipeline. In pipeline.py, we define the Preprocessing step using a SageMaker ScriptProcessor, which runs the above script. We also define the parameters that the user passes during pipeline execution. Note that we specify where the output artifacts of the preprocessing step, such as the train/test split, should be written, so that SageMaker uploads them to S3.
Hyperparameter tuning is essentially running multiple training jobs with different combinations of model hyperparameters and selecting the combination that gives the best result. When we tune with k-fold cross-validation, we select the model with the highest average performance across the folds. Tuning with cross-validation differs from training in that tuning only measures the performance metric, whereas training fits the model on the whole training set. We’ll prepare a script that can do both, depending on the cross-validation flag passed to it. Note also that we’re not looking for the best hyperparameters of a single model in this pipeline: models including RandomForest, Naive Bayes, NeuralNet, LogisticRegression, XGBoost, and LightGBM compete with each other during tuning. So, for each model, we create a script that does both cross-validation and training on the whole dataset. For example, the following is the script for LightGBM:
You can find the scripts for the other estimators in my GitHub repo.
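The selection logic described above (score every candidate with k-fold cross-validation, keep the one with the highest mean score, then refit it on the whole training set) can be sketched in plain NumPy. The two toy classifiers and the synthetic data are hypothetical stand-ins for the real estimators and datasets, and accuracy is used as the score.

```python
import numpy as np

def kfold_indices(n, k, seed=0):
    """Shuffle row indices and split them into k roughly equal folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n), k)

def cv_accuracy(fit_predict, X, y, k=5):
    """Mean accuracy of fit_predict over k folds; each fold is held out once."""
    folds = kfold_indices(len(y), k)
    scores = []
    for i, val_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        preds = fit_predict(X[train_idx], y[train_idx], X[val_idx])
        scores.append(np.mean(preds == y[val_idx]))
    return float(np.mean(scores))

def majority_class(X_train, y_train, X_test):
    """Toy baseline: always predict the most frequent training label."""
    values, counts = np.unique(y_train, return_counts=True)
    return np.full(len(X_test), values[counts.argmax()])

def nearest_centroid(X_train, y_train, X_test):
    """Toy classifier: predict the class whose mean training point is closest."""
    classes = np.unique(y_train)
    centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in classes])
    dists = ((X_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return classes[dists.argmin(axis=1)]

# Synthetic, well-separated two-class data
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0.0, 0.5, size=(50, 2)), rng.normal(5.0, 0.5, size=(50, 2))])
y = np.array([0] * 50 + [1] * 50)

candidates = {"majority_class": majority_class, "nearest_centroid": nearest_centroid}
cv_scores = {name: cv_accuracy(fn, X, y) for name, fn in candidates.items()}
best_name = max(cv_scores, key=cv_scores.get)
print(cv_scores, "->", best_name)
# In the pipeline, the winning configuration would then be refit on the whole training set.
```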
Now that the scripts are ready (we’ll upload them to S3 with Terraform later), we can implement the tuning step of the pipeline. First, we create a dictionary of estimators whose keys are estimator names matching the training script names and whose values are SageMaker estimators. Along with the parameters required for the tuning step, we define the hyperparameter search space for each estimator. Please have a look at the following code snippet.
With those estimator and hyperparameters definition, we can create tuning step of the pipeline in the following way.
After the tuning job completes, we need to retrain the model with the best-found hyperparameters on the whole training dataset. For that, we’ll create a script called refit.py. This Python script is essentially a command-line interface for launching SageMaker training jobs: it extracts the hyperparameters from the specified best training job, disables cross-validation, and runs the Python script associated with the chosen algorithm using those hyperparameters. Have a look at the script.
We can use this script to create the best estimator in SageMaker and add a Training Step to the pipeline, which acts as the Refit Best Model step.
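The core of such a refit step can be sketched with two boto3 SageMaker calls that do exist: describe_hyper_parameter_tuning_job, whose response includes BestTrainingJob, and describe_training_job, whose response includes that job’s HyperParameters as a string-to-string map. Everything else here is an assumption: that SageMaker-internal entries start with sagemaker_ or an underscore (such as _tuning_objective_metric), that values may be JSON-encoded, and the flag name cross_validation, which is hypothetical. The client is passed in, so the logic can be exercised with a stub.

```python
import json

def best_hyperparameters(sm_client, tuning_job_name, cv_flag="cross_validation"):
    """Fetch the best training job's hyperparameters and switch cross-validation off.

    sm_client is expected to behave like boto3.client("sagemaker").
    cv_flag is a hypothetical name for the training script's cross-validation argument.
    """
    tuning = sm_client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=tuning_job_name
    )
    best_job = tuning["BestTrainingJob"]["TrainingJobName"]
    raw = sm_client.describe_training_job(TrainingJobName=best_job)["HyperParameters"]

    params = {}
    for key, value in raw.items():
        # Assumption: SageMaker-internal keys look like "sagemaker_*" or "_*".
        if key.startswith("sagemaker_") or key.startswith("_"):
            continue
        try:
            params[key] = json.loads(value)  # e.g. '"gbdt"' -> 'gbdt', '0.1' -> 0.1
        except json.JSONDecodeError:
            params[key] = value              # keep plain strings as they are
    params[cv_flag] = False  # refit on the whole training set, no cross-validation
    return best_job, params
```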
Now that we’ve trained the model with the best-found hyperparameters, we can measure its performance on the test split we created in the preprocessing step. Let’s create an evaluation script. The provided Python script evaluates the performance of a binary classification model. It begins by parsing command-line arguments for the test data file, the metric to register, the features, and the target feature. The script then extracts the model from a tarball, loads the test data from a CSV file, and prepares it for evaluation. The model makes predictions on the test data, from which the script calculates several evaluation metrics, including precision, recall, accuracy, F1 score, ROC AUC score, and the confusion matrix. Finally, it saves a report with all the calculated metrics to a JSON file.
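For reference, the threshold-based metrics listed above all come from the four confusion-matrix counts, and ROC AUC can be read as the probability that a randomly chosen positive is scored higher than a randomly chosen negative. The sketch below computes them from scratch on made-up labels and scores and collects them in a JSON-serializable report; the report layout is an assumption, not the format the pipeline’s evaluation script writes. In practice, sklearn.metrics provides tested implementations of all of these.

```python
import json

import numpy as np

def binary_report(y_true, y_score, threshold=0.5):
    """Confusion-matrix metrics plus a pairwise ROC AUC for binary labels in {0, 1}."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    y_pred = (y_score >= threshold).astype(int)

    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # ROC AUC: share of (positive, negative) pairs ranked correctly; ties count as half.
    pos, neg = y_score[y_true == 1], y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    roc_auc = float(np.mean((diff > 0) + 0.5 * (diff == 0)))

    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "roc_auc": roc_auc,
        "confusion_matrix": [[tn, fp], [fn, tp]],  # rows: true 0/1, columns: predicted 0/1
    }

report = binary_report([1, 0, 1, 1, 0, 1], [0.9, 0.1, 0.4, 0.8, 0.6, 0.7])
print(json.dumps(report, indent=2))
# accuracy 4/6, precision 0.75, recall 0.75, f1 0.75, roc_auc 0.875
```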
Next, we add the evaluation step to the pipeline, as shown in the following code snippet. The snippet starts by defining the parameters for model registration, including the evaluation metric, the minimum threshold for that metric, and the approval status. It then sets up an evaluation step for the best model using a ScriptProcessor object, which runs a Python script with the specified arguments, inputs, and outputs. The evaluation step is wrapped in a ProcessingStep object, which represents this step in the SageMaker pipeline.
After the evaluation completes, if the model’s performance exceeds a predefined threshold, we register the model in the model registry. The following Python code snippet uses the Amazon SageMaker Python SDK to create and register the model. It first defines the model metrics and creates a model instance using the SKLearnModel class. The model is then registered with its register method, which specifies important parameters such as the content types for requests and responses, the instance types for inference and transformation, and the model metrics. A condition check verifies whether the model’s evaluation metric value meets the required threshold; depending on the outcome, either the model registration step or a failure step is executed.
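Stripped of the SageMaker machinery, the registration decision boils down to a simple comparison. A minimal sketch, assuming a higher-is-better metric and an optional previously registered value (both assumptions for illustration; should_register is a hypothetical helper, and the pipeline’s actual condition lives in its own code):

```python
def should_register(metric_value, threshold, previous_best=None):
    """Register only if the metric meets the threshold and, if given, beats the previous model."""
    # Assumption: higher metric values are better (e.g. accuracy or ROC AUC).
    if metric_value < threshold:
        return False
    if previous_best is not None and metric_value <= previous_best:
        return False
    return True

print(should_register(0.95, threshold=0.9))                      # True
print(should_register(0.85, threshold=0.9))                      # False: below threshold
print(should_register(0.93, threshold=0.9, previous_best=0.94))  # False: no improvement
```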
This concludes the modeling pipeline. Now, we’ll delve into the deployment pipeline.
The Deployment Pipeline is the phase where approved machine learning models, registered in the SageMaker Model Registry, are deployed to serve predictions. The pipeline uses AWS Lambda functions to deploy the models to an S3 location and to interact with the SageMaker endpoint for sending data and receiving predictions. AWS Application Auto Scaling manages the scalability of the endpoint, aiming for steady performance at the lowest cost. Let’s begin the pipeline definition with the Lambda function that handles events in the SageMaker Model Registry, such as model approval or rejection, and updates the model being served for each dataset.
The Python script below manages multiple machine learning models in Amazon Web Services (AWS), specifically Amazon SageMaker and Amazon S3, using Boto3, the AWS SDK for Python. The script fetches the latest approved model package from a model group, stores a map of group names (one per dataset) to their latest models in an S3 bucket, updates the model in the endpoint’s model artifact location in S3, and handles the EventBridge events triggered when models are approved or rejected in the SageMaker Model Registry.
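The event-handling part of such a deployer can be sketched as follows. The sketch assumes that the EventBridge event for a model package state change carries ModelPackageGroupName in its detail, and uses the boto3 SageMaker call list_model_packages (filtered to approved packages, newest first) to find the latest approved package. The lookup is injected as a function so the logic can be tested without AWS, and the in-memory dict stands in for the group-to-model map the post stores in S3; copying the artifact to the endpoint’s S3 location is left out.

```python
def latest_approved_package(sm_client, group_name):
    """Return the ARN of the newest approved package in a group, or None.

    sm_client is expected to behave like boto3.client("sagemaker").
    """
    response = sm_client.list_model_packages(
        ModelPackageGroupName=group_name,
        ModelApprovalStatus="Approved",
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    summaries = response["ModelPackageSummaryList"]
    return summaries[0]["ModelPackageArn"] if summaries else None


def handle_registry_event(event, group_map, find_latest):
    """Update the group -> model map after an approval or rejection event.

    group_map stands in for the JSON map kept in S3; find_latest(group) returns the
    latest approved package ARN or None (e.g. a wrapper around the function above).
    Re-querying after a rejection means an older approved version is served again.
    """
    group = event["detail"]["ModelPackageGroupName"]
    latest = find_latest(group)
    if latest is None:
        # Nothing approved any more (e.g. the only approved model was rejected).
        group_map.pop(group, None)
    else:
        group_map[group] = latest
    return group_map
```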
The following Python script is an AWS Lambda function that interacts with the SageMaker endpoint. It imports the necessary libraries, sets up AWS clients, and defines two functions. The function get_groupname2model_map retrieves a JSON object from an S3 bucket that maps group names (datasets) to their latest model names. The main function, lambda_handler, retrieves this group-to-model map and checks whether the requested model name exists in it. If it does, the function invokes the SageMaker endpoint with that model as the target model, passing in the data from the event body, and returns the result with a 200 status code. If the model name doesn’t exist, it returns a 404 status code with an error message. Essentially, this script serves as an interface between HTTP requests and the multiple machine learning models hosted on a SageMaker endpoint.
We’ll need a SageMaker endpoint, along with a model and an endpoint configuration in multi-model mode. These resources are created with Terraform, as discussed in the following section.
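A stripped-down version of that request flow could look like the sketch below. The request shape (a JSON body with hypothetical model_name and data fields) and the text/csv content type are assumptions; the boto3 sagemaker-runtime call invoke_endpoint, with its TargetModel parameter for addressing one model on a multi-model endpoint, is the real API. Both the map loader and the invoke function are injected so the routing logic can be tested without AWS.

```python
import json

def make_handler(get_group_map, invoke):
    """Build a Lambda handler that routes a request to the latest model of a group."""
    def lambda_handler(event, context=None):
        payload = json.loads(event.get("body") or "{}")
        group = payload.get("model_name")  # hypothetical request field
        group_map = get_group_map()        # e.g. the JSON map loaded from S3
        if group not in group_map:
            return {"statusCode": 404,
                    "body": json.dumps({"error": f"Model '{group}' not found"})}
        result = invoke(target_model=group_map[group], body=payload.get("data", ""))
        return {"statusCode": 200, "body": json.dumps({"prediction": result})}
    return lambda_handler


def sagemaker_invoker(endpoint_name, content_type="text/csv"):
    """Real invocation via boto3; boto3 is only imported when this is actually used."""
    import boto3
    runtime = boto3.client("sagemaker-runtime")

    def invoke(target_model, body):
        response = runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType=content_type,
            Body=body,
            TargetModel=target_model,  # artifact path relative to the endpoint's model prefix
        )
        return response["Body"].read().decode("utf-8")
    return invoke
```

In a deployed function, the handler would be built once at module level, e.g. with sagemaker_invoker and a loader that reads the map from S3.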
We’ll create all the resources, including the AWS infrastructure, Docker images, and pipeline definition, using the following Terraform file. Here is a description of the resources used in this Terraform configuration.
The Docker image is built and pushed by running the build_and_push.sh script we created at the beginning, using Terraform’s null_resource.
The pipeline.py file is run again with Terraform’s null_resource.
A SageMaker model in MultiModel mode is created. It points to a location in the S3 bucket where all the models for the endpoint are stored, and our deployer Lambda function updates the models in this location.
We’ve come to the end of the blog post. Here, we created a scalable end-to-end ML pipeline. For a usage description, please refer to my GitHub repo.
Welcome to my latest blog post! Today, I am excited to share my recent exploration into the fascinating world of reinforcement learning, specifically focusing on the multi-armed bandit problem and its various solutions. As a foundation for my implementation, I closely followed the insightful book, Reinforcement Learning: An Introduction (second edition) by Richard S. Sutton and Andrew G. Barto. In this post, I will walk you through my journey, discussing key concepts and algorithms presented in Chapter 2 of the book, while also providing you with code examples and explanations to help you grasp these intriguing topics. So, let’s dive into the world of multi-armed bandits and reinforcement learning together!
The Multi-Armed Bandit problem is a classic reinforcement learning challenge that exemplifies the exploration-exploitation tradeoff. Imagine a gambler (an agent, in RL terminology) in front of a row of slot machines, also known as "one-armed bandits." The gambler needs to decide which machines to play, how many times to play each machine, in which order to play them, and whether to continue with the current machine or try a different one. Each machine provides a random reward from a probability distribution specific to that machine, which is unknown to the gambler. The objective is to maximize the sum of rewards earned through a sequence of lever pulls.
The critical tradeoff the gambler faces at each trial is between “exploitation” of the machine that has the highest expected payoff and “exploration” to gather more information about the expected payoffs of the other machines. In practice, multi-armed bandits have been used to model problems such as managing research projects in a large organization or optimizing marketing strategies.
In the following sections, I will delve into the implementation of the Multi-Armed Bandit problem. This implementation will follow the order of problems and solutions presented in the book by Sutton and Barto. Throughout this process, I will make an effort to compare the performance of various algorithms with each other, allowing us to evaluate their effectiveness in addressing the exploration-exploitation tradeoff.
Let’s begin by importing the libraries and setting them up!
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# For each multiarmed bandit experiment we'll have two plots, displayed horizontally
matplotlib.rcParams['figure.figsize'] = [20, 5]
matplotlib.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=['g', 'b', 'r', "y"])
The simplest setting for the multi-armed bandit problem is one in which the reward distribution of each arm remains constant over time. In this scenario, each arm has a fixed probability distribution for its rewards, and the gambler's objective remains the same: to maximize the cumulative reward over a series of trials.
class StationaryMultiArmedBandit:
    """
    Represents a stationary multi-armed bandit problem.
    Attributes:
        k (int): Number of arms.
        runs (int): Number of independent runs.
        random_state (int, optional): Random seed for reproducibility.
    """
    def __init__(
        self,
        k,
        runs,
        random_state=None,
    ):
        self.k = k
        self.runs = runs
        self.random_state = random_state
        self.setup()
    def setup(self):
        """Set up the seed for reproducibility and reward distribution"""
        self.nprandom = np.random.RandomState(self.random_state)
        self.q_star = self.nprandom.normal(
            loc=0.0,
            scale=1.0,
            size=(self.runs, self.k),
        )
    def get_reward(self, action):
        """Given the action, return the reward"""
        reward = self.nprandom.normal(
            loc=self.q_star[np.arange(self.runs), action],
            scale=1.0,
        )
        return reward
    def get_correct_action(self):
        """
        Get the correct action for each run.
        Correct action for each run is the one with highest mean reward
        """
        return self.q_star.argmax(axis=1)
    def plot_reward_distribution(self, run=0):
        """Plot the reward distribution for the given run."""
        samples = self.nprandom.normal(
            loc=self.q_star[run],
            scale=1.0,
            size=(10_000, self.k),
        )
        plt.violinplot(samples, showmeans=True)
        plt.xlabel('Action')
        plt.ylabel('Reward Distribution')
        plt.show()
For all experiments, we’ll aggregate the performance across 2000 independent 10-armed bandit runs. Let’s create variables to store these numbers.
runs = 2000
k = 10
Now let's create an instance of the stationary multi-armed bandit problem and look at its reward distribution. This distribution is unknown to the agent; an ideal agent should learn which action gives the highest reward.
st_bandit = StationaryMultiArmedBandit(k=k, runs=runs)
st_bandit.plot_reward_distribution(run=0)
print(f"Correct action for run 0: {st_bandit.get_correct_action()[0] + 1}")

Correct action for run 0: 5
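An agent in reinforcement learning is an entity that learns which actions yield the highest reward in the long run. In the bandit setting, the agent has two primary roles: estimating the value of each action from the rewards it receives, and selecting which action to take at each step.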
One of the challenges faced by the agent is the exploration-exploitation dilemma. In this dilemma, the agent must decide whether to explore new actions to gain more knowledge about the environment or exploit its current knowledge to maximize immediate rewards. Striking a balance between exploration and exploitation is critical for the agent’s success in the long run, as excessive exploration may lead to suboptimal rewards, while excessive exploitation may prevent the agent from discovering better actions. Various types of agents can be developed based on how they handle this exploration-exploitation trade-off, and in our implementation, we will compare different methods to understand their effectiveness in addressing this challenge.
The epsilon-greedy sample average method addresses the exploration-exploitation dilemma by choosing between exploration and exploitation randomly. With a probability of epsilon, the agent selects a random action for exploration, while with a probability of 1-epsilon, the agent exploits the current best action based on its estimated action values.
In comparison to the pure greedy algorithm, which always selects the action with the highest estimated value, the epsilon-greedy algorithm introduces a level of exploration, allowing the agent to discover potentially better actions and improve its long-term performance.
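Because the random action in ε-greedy is drawn uniformly from all k actions, including the greedy one, the greedy action is actually selected with probability \(1 - \epsilon + \epsilon/k\), not just \(1 - \epsilon\). A quick standalone sketch checks this by simulation, assuming a single greedy action (no ties) and using NumPy's `default_rng` rather than the `RandomState` used by the classes in this post:

```python
import numpy as np


def epsilon_greedy(q_estimates, epsilon, rng):
    """Pick a uniformly random action with probability epsilon, otherwise the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_estimates)))
    return int(np.argmax(q_estimates))


def greedy_selection_probability(epsilon, k):
    """Exact probability that epsilon-greedy picks the greedy action."""
    return 1 - epsilon + epsilon / k


rng = np.random.default_rng(0)
q = np.array([0.1, 0.9, 0.3, 0.2, 0.5, 0.0, 0.4, 0.6, 0.7, 0.8])  # greedy action is index 1
picks = np.array([epsilon_greedy(q, epsilon=0.1, rng=rng) for _ in range(100_000)])
empirical = np.mean(picks == 1)
# Theoretical value is 0.91; the empirical fraction should be close to it.
print(greedy_selection_probability(0.1, 10), empirical)
```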
The sample average estimate is a technique used to update the estimated action values. For each action, the agent maintains a running average of the rewards it has received when selecting that action. When a new reward is received, the agent updates its estimate by incorporating the new reward into the running average. This way, the agent continuously refines its estimates based on its experience, which in turn helps it make better decisions in the exploration-exploitation trade-off.
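The running average does not require storing past rewards: with \(N(a)\) counting how often action \(a\) was chosen, the update \(Q \leftarrow Q + (R - Q)/N\) reproduces the arithmetic mean exactly. A small standalone check, independent of the classes in this post:

```python
import numpy as np


def incremental_mean(rewards):
    """Return the running sample-average estimate after each reward."""
    q, n, history = 0.0, 0, []
    for r in rewards:
        n += 1
        q += (r - q) / n  # step size 1/n turns the update into an exact running mean
        history.append(q)
    return history


rewards = [1.0, 0.0, 2.0, 5.0]
print(incremental_mean(rewards))  # [1.0, 0.5, 1.0, 2.0]; the last value equals np.mean(rewards)
```

Because the first step size is \(1/1\), the initial estimate is overwritten by the first reward, so sample averages carry no initial bias.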
class EpsilonGreedySampleAverageAgent:
    """
    An epsilon-greedy agent using sample-average method for action value estimation.
    Attributes
    ----------
    k : int
        Number of actions.
    runs : int
        Number of independent runs.
    epsilon : float, optional
        Probability of choosing a random action (exploration), default is 0.1.
    random_state : int, optional
        The random number generator seed to be used, default is None.
    """
    def __init__(
        self,
        k,
        runs,
        epsilon=0.1,
        random_state=None,
    ):
        self.k = k
        self.runs = runs
        self.epsilon = epsilon
        self.random_state = random_state
        self.setup()
    def setup(self):
        """Initialize the Q and N arrays for action value estimation and action counts."""
        self.nprandom = np.random.RandomState(self.random_state)
        self.Q = np.zeros((self.runs, self.k))
        self.N = np.zeros((self.runs, self.k))
    def get_action(self):
        """Choose an action based on epsilon-greedy policy."""
        greedy_action = np.argmax(
            self.nprandom.random(self.Q.shape) * (self.Q==self.Q.max(axis=1, keepdims=True)), # breaking ties randomly
            axis=1
        )
        random_action = self.nprandom.randint(0, self.k, size=(self.runs, ))
        action = np.where(
            self.nprandom.random((self.runs, )) < self.epsilon,
            random_action,
            greedy_action,
        )
        return action
    def get_step_size(self, action):
        """Calculate the step size for updating action value estimates.
        For the sample average method, return 1 / (number of times the action has been chosen so far)."""
        return 1/self.N[np.arange(self.runs), action]
    def update(self, action, reward):
        """Update the action value estimates based on the chosen action and received reward."""
        self.N[np.arange(self.runs), action] += 1
        step_size = self.get_step_size(action)
        self.Q[np.arange(self.runs), action] += (reward - self.Q[np.arange(self.runs), action])*step_size
Now we'll create a testbed that runs an episode of interaction between the agent and the k-armed bandit environment for the given number of steps. We'll also include a function that plots the average reward the agent receives and the percentage of optimal actions it takes at each step.
class MultiArmedBanditTestBed:
    """A test bed for running experiments with multi-armed bandits and agents.
    Attributes:
        bandit (object): A multi-armed bandit object.
        agent (object): An agent object.
        steps (int): The number of steps for the experiment.
    """
    def __init__(
        self,
        bandit,
        agent,
        steps,
    ):
        self.bandit = bandit
        self.agent = agent
        self.steps = steps
    def run_experiment(self):
        """Runs the experiment for the given number of steps and returns the average rewards and optimal actions.
        Returns:
            tuple: Two lists: the average reward and the fraction of runs taking the optimal action at each step.
        """
        avg_reward = []
        avg_optimal_action = []
        for _ in range(self.steps):
            action = self.agent.get_action()
            reward = self.bandit.get_reward(action)
            self.agent.update(action, reward)
            correct = action == self.bandit.get_correct_action()
            avg_reward.append(reward.mean())
            avg_optimal_action.append(correct.mean())
        return avg_reward, avg_optimal_action
    @classmethod
    def run_and_plot_experiments(cls, steps, exp_bandit_agent_dict):
        """Runs multiple experiments and plots the results.
        Args:
            steps (int): The number of steps for the experiments.
            exp_bandit_agent_dict (dict): A dictionary with labels as keys and (bandit, agent) tuples as values.
        """
        fig, (ax_reward, ax_optimal_action) = plt.subplots(nrows=1, ncols=2)
        for label, (bandit, agent) in exp_bandit_agent_dict.items():
            test_bed = cls(bandit, agent, steps)
            avg_reward, avg_optimal_action = test_bed.run_experiment()
            ax_reward.plot(avg_reward, label=label)
            # Convert the fraction of optimal actions to a percentage to match the axis label
            ax_optimal_action.plot(np.array(avg_optimal_action) * 100, label=label)
        ax_reward.set_ylabel("Average reward")
        ax_reward.set_xlabel("Steps")
        ax_optimal_action.set_ylabel("% Optimal Action")
        ax_optimal_action.set_xlabel("Steps")
        ax_reward.legend()
        ax_optimal_action.legend()
        plt.show()
Let’s run three agents: greedy, \(\epsilon=0.1\)-greedy, and \(\epsilon=0.01\)-greedy agent on the stationary 10-arm bandit environment.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "greedy": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0)
        ),
        "epsilon=0.1": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.1)
        ),
        "epsilon=0.01": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
    },
)

In the plots above, the epsilon-greedy agents clearly outperform the pure greedy agent, because the greedy agent does no exploration and gets stuck on a suboptimal action. Comparing the two epsilon-greedy agents, the agent with the higher epsilon explores more and performs better in the early steps, but the agent with the lower epsilon outperforms it in the long run. This clearly shows the challenge of balancing exploration and exploitation.
Let's shift the multi-armed bandit problem a bit toward the more realistic, full reinforcement learning paradigm. The bandit implemented earlier was stationary: its reward distribution never changed during its lifetime. Let's implement a bandit whose reward distribution changes with every step the agent takes.
class NonStationaryMultiArmedBandit(StationaryMultiArmedBandit):
    def setup(self):
        """For the non-stationary bandit, we start with the same mean reward for all actions: zero."""
        self.nprandom = np.random.RandomState(self.random_state)
        self.q_star = np.zeros((self.runs, self.k))
    def get_reward(self, action):
        """Before returning the reward for the action taken in the current step,
        shift every mean reward by a drift sampled from a normal distribution with mean 0 and std 0.01."""
        self.q_star += self.nprandom.normal(loc=0.0, scale=0.01, size=(self.runs, self.k))
        return super(NonStationaryMultiArmedBandit, self).get_reward(action)
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "stationary": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
        "nonstationary": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
    },
)

On the non-stationary bandit problem, the sample average method clearly falls significantly behind.
The sample average method gives equal weight to every reward, regardless of the step at which it was obtained. In a non-stationary setting, however, it makes sense to give more weight to recent rewards than to long-past ones. One of the most popular ways of doing this is to use a constant step-size parameter.
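With a constant step size \(\alpha\), unrolling the update \(Q_{n+1} = Q_n + \alpha(R_n - Q_n)\) gives \(Q_{n+1} = (1-\alpha)^n Q_1 + \sum_{i=1}^{n} \alpha(1-\alpha)^{n-i} R_i\), so the weight on a reward decays exponentially with its age (Sutton and Barto call this an exponential recency-weighted average). The standalone sketch below checks the incremental form against the explicit weighted sum:

```python
import numpy as np


def constant_step_estimate(q1, rewards, alpha):
    """Incremental update with a constant step size alpha."""
    q = q1
    for r in rewards:
        q += alpha * (r - q)
    return q


def weighted_sum_estimate(q1, rewards, alpha):
    """The same estimate written as an explicit exponentially weighted sum."""
    n = len(rewards)
    # Reward i (0-indexed) gets weight alpha * (1 - alpha)^(n - 1 - i); the newest gets alpha.
    weights = alpha * (1 - alpha) ** (n - 1 - np.arange(n))
    return (1 - alpha) ** n * q1 + np.dot(weights, rewards)


rewards = np.array([1.0, 3.0, -2.0, 0.5, 4.0])
print(constant_step_estimate(0.0, rewards, alpha=0.1))
print(weighted_sum_estimate(0.0, rewards, alpha=0.1))  # matches the line above up to rounding
```

Note that \(Q_1\) keeps a weight of \((1-\alpha)^n\), which is the initial bias that the optimistic-initial-value experiments below exploit.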
class EpsilonGreedyAgent(EpsilonGreedySampleAverageAgent):
    """Epsilon greedy agent with constant step size."""
    def __init__(self, k, runs, alpha=0.1, epsilon=0.1, random_state=None):
        super(EpsilonGreedyAgent, self).__init__(k, runs, epsilon, random_state)
        self.alpha = alpha
    def get_step_size(self, action):
        """Instead of returning 1 / (number of times the action has been chosen),
        return the constant step size `alpha` provided when the agent was instantiated."""
        return self.alpha
Let's run experiments with all combinations of stationary/non-stationary bandits and sample-average/constant step-size action value estimation methods.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "stationarysampleaverage": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
        "stationaryconstantstepsize": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
        "nonstationarysampleaverage": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
        "nonstationaryconstantstepsize": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

From the figure, we can see that the constant step-size and sample average methods perform comparably in the stationary setting. In the non-stationary setting, however, the constant step-size method performs significantly better than the sample average method, although its performance is lower than in the stationary setting.
Optimistic initial values is a technique used in the multi-armed bandit problem to encourage exploration in the early stages of learning. Instead of starting with initial action values set to zero or a neutral value, this approach sets the initial action values to a high, optimistic value, sometimes even higher than the maximum possible reward.
The optimistic initial values encourage exploration because the agent is initially biased to believe that all actions have high rewards. As the agent selects actions and receives actual rewards, it updates its estimates, and the optimistic values gradually fall towards their true values. This process continues until the value estimates of suboptimal actions fall below the estimates of the optimal action, and the agent starts exploiting the optimal action more frequently.
While optimistic initial values can be effective in stationary problems, the approach has some limitations. For instance, it is not well suited to non-stationary problems, as its drive for exploration is temporary and may not adapt to changes in the environment. Additionally, the technique relies heavily on the initial conditions, and finding good initial values may require some trial and error or domain knowledge.
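To see why optimism forces early exploration, here is a tiny deterministic sketch. Rewards are fixed numbers rather than noisy samples, an assumption made only to keep the trace predictable. A purely greedy learner with initial estimates of 5 and \(\alpha = 0.1\) is "disappointed" by every arm it tries, so it cycles through all of them before settling on the best one.

```python
import numpy as np


def optimistic_greedy_trace(true_values, init_q=5.0, alpha=0.1, steps=500):
    """Run a purely greedy learner with optimistic initial estimates on fixed rewards."""
    q = np.full(len(true_values), init_q)
    actions = []
    for _ in range(steps):
        a = int(np.argmax(q))  # greedy choice; ties go to the lowest index
        q[a] += alpha * (true_values[a] - q[a])  # every real reward is below the optimistic estimate
        actions.append(a)
    return actions


actions = optimistic_greedy_trace([0.2, 0.8, 0.5])
print(actions[:3])          # [0, 1, 2]: each arm is tried once before any is repeated
print(set(actions[-100:]))  # {1}: eventually only the best arm is chosen
```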
class OptimisticEpsilonGreedyAgent(EpsilonGreedyAgent):
    def __init__(self, k, runs, alpha=0.1, init_q=5, epsilon=0.1, random_state=None):
        """This behaves like the epsilon-greedy agent,
        but it starts with high, optimistic action values."""
        super(OptimisticEpsilonGreedyAgent, self).__init__(k, runs, alpha, epsilon, random_state)
        self.init_q = init_q
        self.Q += self.init_q
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000, # To make spike clear
    exp_bandit_agent_dict={
        "optimisticgreedy": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0)
        ),
        "realisticepsilongreedy": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

We can see that even with \(\epsilon=0\) (pure greedy), the agent with optimistic initial values outperforms the \(\epsilon\)-greedy agent in the long run. Initially it underperforms, because the optimistic values force it to explore. Note the spike at around the 10th step. Because every estimate starts well above its true value, each arm the agent tries returns a "disappointing" reward and its estimate drops, so the greedy agent cycles through all 10 arms during the first 10 steps. At the next step, the arm whose estimate dropped the least, which is often the optimal arm, becomes the greedy choice, producing the spike. Pulling it again lowers its estimate below the others, so the agent goes back to exploring and performance dips before it settles on the optimal action.
The experiment above with optimistic initial action values was run in the stationary setting. Let's run it in the non-stationary setting and see what happens.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "optimisticgreedy": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0)
        ),
        "realisticepsilongreedy": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

This experiment clearly shows the limitation of using optimistic initial values to force exploration. The trick is not well suited to non-stationary problems, because its drive for exploration is inherently temporary, while a non-stationary task creates a renewed need for exploration at every step.
Let's run an experiment comparing different initial action value estimates.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=5000,
    exp_bandit_agent_dict={
        "optimistic_egreedy_initq=500": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=500, epsilon=0.01)
        ),
        "optimistic_egreedy_initq=5": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0.01)
        ),
    },
)

The figure shows that the choice of initial action values affects the early steps, and that this effect fades as the number of steps grows. To quote the book: “Indeed, any method that focuses on the initial conditions in any special way is unlikely to help with the general nonstationary case. The beginning of time occurs only once, and thus we should not focus on it too much.” So, in the long run, one shouldn't worry much about the chosen initial action values. Still, there are tricks to remove their effect; let's explore one.
This trick, from Exercise 2.7 of Sutton and Barto's book, avoids the effect of the initial action values. It uses the step size \(\beta_n \doteq \alpha / \bar{o}_n\), where \(\bar{o}_n \doteq \bar{o}_{n-1} + \alpha(1 - \bar{o}_{n-1})\) for \(n > 0\), with \(\bar{o}_0 \doteq 0\).
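Working the recursion by hand: \(\bar{o}_1 = \alpha\), so \(\beta_1 = \alpha/\alpha = 1\), and by induction \(\bar{o}_n = 1 - (1-\alpha)^n\), which approaches 1, so \(\beta_n\) approaches \(\alpha\). The first update therefore overwrites the initial estimate completely, while later updates behave like an ordinary constant step size. A short standalone check for a single action:

```python
import numpy as np


def unbiased_step_sizes(alpha, n_steps):
    """Step sizes beta_n = alpha / o_n, with the trace o_n = o_{n-1} + alpha * (1 - o_{n-1})."""
    o, betas = 0.0, []
    for _ in range(n_steps):
        o += alpha * (1 - o)
        betas.append(alpha / o)
    return np.array(betas)


def estimate_after_rewards(init_q, rewards, alpha):
    """Apply the unbiased constant-step-size update to one action's estimate."""
    q, o = init_q, 0.0
    for r in rewards:
        o += alpha * (1 - o)
        q += (alpha / o) * (r - q)
    return q


betas = unbiased_step_sizes(alpha=0.1, n_steps=200)
print(betas[0], betas[-1])  # 1.0 for the first update, then close to alpha = 0.1
rewards = [0.3, 1.2, -0.4]
# The two estimates agree up to floating-point rounding: the initial value has no effect.
print(estimate_after_rewards(500.0, rewards, 0.1), estimate_after_rewards(5.0, rewards, 0.1))
```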
class UnbiasedEpsilonGreedyAgent(OptimisticEpsilonGreedyAgent):
    """Unbiased Constant Step-Size Agent.
    Inherits from `OptimisticEpsilonGreedyAgent` so that initial action values can be set."""
    def __init__(self, k, runs, alpha=0.1, init_q=5, epsilon=0.1, random_state=None):
        super(UnbiasedEpsilonGreedyAgent, self).__init__(k, runs, alpha, init_q, epsilon, random_state)
        self.step_trace = np.zeros((self.runs, self.k))
    def get_step_size(self, action):
        """Calculate the step size for the given action using trace."""
        self.step_trace[np.arange(self.runs), action] += self.alpha*(1 - self.step_trace[np.arange(self.runs), action])
        return self.alpha / self.step_trace[np.arange(self.runs), action]
Let's run the unbiased constant step-size agent with different initial values.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10000,
    exp_bandit_agent_dict={
        "unbiased_initq=500": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            UnbiasedEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=500, epsilon=0.01)
        ),
        "unbiased_initq=5": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            UnbiasedEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0.01)
        ),
    },
)

With this trick, the effect of the initial action values disappears. This is because, with the unbiased constant step size, the trace after the first update equals \(\alpha\), so the first step size is \(\alpha/\alpha = 1\): the agent discards its initial action value estimate and sets the estimate to the first reward it receives.
In the sections above, we focused on action value estimation methods: sample average, constant step size, and unbiased constant step size. Remember that an agent also has to decide which action to choose at each step. The greedy method always exploits the action with the highest action value estimate so far, i.e., it does no exploration. The epsilon-greedy method selects a random action at each step with some probability as a way of exploring. When it does so, it picks uniformly among all actions, giving no preference to those that are nearly greedy or particularly uncertain.
Upper Confidence Bound (UCB) action selection aims to balance exploration and exploitation based on confidence bounds assigned to each action. It is rooted in the principle of optimism in the face of uncertainty: the more uncertain we are about an action's value, the more important it becomes to explore that action. UCB selects the action that maximizes the sum of its estimated value and an exploration bonus, so the agent explores uncertain actions while still exploiting actions with high estimated rewards, allowing it to learn the optimal action over time. Let's implement it as described in Section 2.7 of the book.
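In the book, UCB selects \(A_t \doteq \arg\max_a \left[ Q_t(a) + c\sqrt{\frac{\ln t}{N_t(a)}} \right]\), where \(c > 0\) controls the degree of exploration and an action with \(N_t(a) = 0\) is treated as maximizing. The standalone sketch below scores one bandit's actions this way; unlike the agent class that follows, it gives untried actions an explicit infinite bonus instead of adding a small constant to the denominator.

```python
import numpy as np


def ucb_scores(q, n, t, c=2.0):
    """Upper-confidence scores Q(a) + c * sqrt(ln t / N(a)); untried actions score +inf."""
    q = np.asarray(q, dtype=float)
    n = np.asarray(n, dtype=float)
    bonus = np.full_like(q, np.inf)
    tried = n > 0
    bonus[tried] = c * np.sqrt(np.log(t) / n[tried])
    return q + bonus


# Action 1 has the best estimate, but action 2 has been tried far less often.
scores = ucb_scores(q=[0.5, 1.0, 0.8], n=[10, 30, 2], t=43, c=1.0)
print(scores.argmax())  # 2: the uncertainty bonus outweighs the lower estimate
print(ucb_scores([0.0, 5.0], [0, 10], t=11).argmax())  # 0: untried actions are selected first
```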
class UCBActionAgent(EpsilonGreedyAgent):
    def __init__(self, k, runs, alpha=0.1, confidence=1, random_state=None):
        self.k = k
        self.runs = runs
        self.alpha = alpha
        self.confidence = confidence
        self.random_state = random_state
        self.setup()
    def get_action(self):
        current_step = self.N.sum(axis=1, keepdims=True) + 1
        return np.argmax(
            self.Q + self.confidence * np.sqrt(
                np.log(current_step)/(self.N + 1e-5) # To avoid divide by zero error
            ),
            axis=1
        )
Let's run an experiment comparing epsilon-greedy with UCB at various confidence levels.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000, # To make spike clear
    exp_bandit_agent_dict={
        "UCB C=1": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            UCBActionAgent(k=k, runs=runs, alpha=0.1, confidence=1)
        ),
        "UCB C=2": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            UCBActionAgent(k=k, runs=runs, alpha=0.1, confidence=2)
        ),
        "e-greedy, e=0.1": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.1)
        ),
    },
)

UCB with \(C=1\) performs better than epsilon-greedy, but we don't see that improvement with \(C=2\). This is because \(C\) controls the degree of exploration: the higher the confidence level, the more the agent explores, and with \(C=2\) it likely explores more than this problem needs.
Let's run the experiment above in the non-stationary setting.
# e-greedy is better than UCB in nonstationary setting
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "UCB C=1": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            UCBActionAgent(k=k, runs=runs, alpha=0.1, confidence=1)
        ),
        "e-greedy, e=0.1": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.1)
        ),
    },
)

This experiment shows the limitation of UCB in the non-stationary setting.
Instead of choosing exploratory actions indiscriminately or based on uncertainty estimates, a more sophisticated approach is to learn a preference for each action. An agent can do so with gradient bandit algorithms. Unlike other bandit algorithms, which maintain an action-value estimate for each action, gradient bandit algorithms maintain a preference value for each action and use a soft-max distribution to derive the probability of selecting each action.
The core idea of gradient bandit algorithms is to update the preferences based on the received rewards and a baseline reward value, which can be the average of all observed rewards so far. The update rule is designed to increase the preference for actions that yield higher rewards than the baseline and decrease the preference for actions with lower rewards.
In each iteration, the agent selects an action according to the soft-max distribution derived from the action preferences and updates the preferences based on the received reward and the baseline. This process continues until the agent converges towards the optimal action or a stopping criterion is met. Gradient bandit algorithms can adapt to changing environments and provide a good balance between exploration and exploitation.
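Concretely, Sutton and Barto define \(\pi_t(a) = e^{H_t(a)}/\sum_b e^{H_t(b)}\) and update \(H_{t+1}(a) = H_t(a) + \alpha (R_t - \bar{R}_t)(\mathbb{1}_{a=A_t} - \pi_t(a))\) for every action \(a\). Here is one such step for a single bandit, as a standalone sketch; the preferences' total never changes, because the terms \(\mathbb{1}_{a=A_t} - \pi_t(a)\) sum to zero.

```python
import numpy as np


def softmax(h):
    """Soft-max over preferences (shifted by the max to avoid overflow; probabilities are unchanged)."""
    z = np.exp(h - h.max())
    return z / z.sum()


def gradient_bandit_step(h, action, reward, baseline, alpha=0.1):
    """One preference update: raise the chosen action if the reward beats the baseline, lower the rest."""
    pi = softmax(h)
    one_hot = np.zeros_like(h)
    one_hot[action] = 1.0
    return h + alpha * (reward - baseline) * (one_hot - pi)


h = np.zeros(4)  # equal preferences give a uniform policy
h_new = gradient_bandit_step(h, action=2, reward=1.0, baseline=0.0)
print(softmax(h_new))  # the probability of action 2 rises above 0.25
print(h_new.sum())     # stays (numerically) zero
```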
class GradientAgent(EpsilonGreedyAgent):
    """Gradient Bandit Algorithm"""
    def __init__(self, k, runs, alpha=0.1, random_state=None):
        self.k = k
        self.runs = runs
        self.alpha = alpha
        self.random_state = random_state
        self.setup()
    def setup(self):
        """Set up the initial preference and average reward for the GradientAgent"""
        self.nprandom = np.random.RandomState(self.random_state)
        # initial preference is same for all actions
        self.H = np.zeros((self.runs, self.k))
        self.avg_R = np.zeros((self.runs,))
    def action_proba(self):
        """Calculate the probability of each action."""
        exp = np.exp(self.H)
        prob = exp/exp.sum(axis=1, keepdims=True)
        # Persist for update method
        self.current_action_prob = prob
        return prob
    def get_action(self):
        """Get the action to take based on the action probabilities."""
        prob = self.action_proba()
        # TODO find better way to vectorize the following.
        return np.apply_along_axis(
            lambda row: self.nprandom.choice(np.arange(self.k), p=row),
            arr=prob,
            axis=1,
        )
    def update(self, action, reward):
        """Update the preferences and average reward based on the given action and reward."""
        step_size = self.get_step_size(action)
        # Get already calculated action probability if available
        prob = self.current_action_prob if hasattr(self, "current_action_prob") else self.action_proba()
        prob = -prob
        prob[np.arange(self.runs), action] = 1 + prob[np.arange(self.runs), action]
        deltaR = reward - self.avg_R
        self.H += step_size*prob*deltaR[:, np.newaxis]
        self.avg_R += step_size*deltaR
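The TODO in `get_action` asks for a vectorized alternative to `np.apply_along_axis`. One common approach, sketched here as an assumption rather than the post's implementation, is inverse-CDF sampling: draw one uniform number per run and count how many cumulative probabilities it meets or exceeds.

```python
import numpy as np


def sample_actions(prob, rng):
    """Draw one action per row of `prob` (shape: runs x k) in a single vectorized pass."""
    cdf = prob.cumsum(axis=1)
    u = rng.random((prob.shape[0], 1))  # one uniform draw in [0, 1) per run
    # Action i is chosen when cdf[i-1] <= u < cdf[i]; the clip guards against cdf[-1] rounding below 1.
    return np.minimum((u >= cdf).sum(axis=1), prob.shape[1] - 1)


rng = np.random.default_rng(0)
deterministic = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
print(sample_actions(deterministic, rng))  # [0 2]

probs = np.tile([0.2, 0.3, 0.5], (100_000, 1))
freqs = np.bincount(sample_actions(probs, rng), minlength=3) / len(probs)
print(freqs)  # close to [0.2, 0.3, 0.5]
```

Since `RandomState.random` also accepts a shape, `get_action` could return `sample_actions(prob, self.nprandom)` with the agent's existing generator.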
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000,
    exp_bandit_agent_dict={
        "gradient": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            GradientAgent(k=k, runs=runs, alpha=0.1)
        ),
        "e-greedy, e=0.01": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

We can see an almost 100% improvement in optimal action selection with the gradient bandit algorithm.
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000,
    exp_bandit_agent_dict={
        "gradient": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            GradientAgent(k=k, runs=runs, alpha=0.1)
        ),
        "e-greedy, e=0.01": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

The gradient bandit algorithm also struggles in the non-stationary setting.
This concludes the blog on multi-armed bandit problems. We implemented and compared various action value estimation and action selection algorithms in stationary and non-stationary settings. All experiments showed the inherent challenge of the exploration-exploitation dilemma in reinforcement learning.
The paper “Why do tree-based models still outperform deep learning on typical tabular data?” by Grinsztajn et al. [2022] establishes a new standard set of datasets with clear characteristics of tabular data and benchmarks various tree-based and deep learning models on these datasets. The results debunk common myths about the performance of neural networks on tabular data, highlighting the importance of understanding inductive biases and the impact of uninformative features and irregular functions on model performance.
By discussing the paper's key strengths and weaknesses, the approach taken by the authors, and the clarity of the writing, I aim to provide a comprehensive understanding of why tree-based models continue to outshine deep learning models on typical tabular data. This analysis should be particularly useful for industrial machine learning practitioners, who are often puzzled by the seemingly inferior performance of deep learning models on tabular data.
In the PDF attached below, I have included annotations for the critical analysis, made while I was reading the paper.
Overall, the significance of the paper's contributions and strengths outweighs its weaknesses, which I think is why it was accepted at NeurIPS 2022. The paper also provides explanations for practitioners perplexed by the inferior performance of deep learning models on tabular data.
This blog/tutorial on fine-tuning XLS-R on OpenSLR's Nepali ASR dataset is adapted from Hugging Face's blog post “Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers.” The existing XLS-R model for Nepali was fine-tuned on OpenSLR's Nepali Text to Speech dataset, which contains high-quality recordings from only one speaker. It is therefore doubtful that this model would work on utterances made in the wild, the way we normally speak. To avoid this problem, we should use speech recorded in real life with natural characteristics such as noise and pauses, and the Large Nepali ASR training data set is the one such dataset available for the Nepali language. However, a lack of computational resources is keeping me from completing the fine-tuning and hyperparameter optimization of this remarkable model on my mother tongue; so far I get a 21% word error rate on the test split I created. Anyone who would like to contribute to finishing this training is heartily welcome. Let's begin the implementation journey, though.
Wav2Vec2 is a pretrained model for Automatic Speech Recognition (ASR) and was released in September 2020 by Alexei Baevski, Michael Auli, and Alex Conneau. Soon after the superior performance of Wav2Vec2 was demonstrated on one of the most popular English datasets for ASR, called LibriSpeech, Facebook AI presented a multi-lingual version of Wav2Vec2, called XLSR. XLSR stands for cross-lingual speech representations and refers to the model's ability to learn speech representations that are useful across multiple languages.
XLSR's successor, simply called XLS-R (referring to “XLM-R for Speech”), was released in November 2021 by Arun Babu, Changhan Wang, Andros Tjandra, et al. XLS-R used almost half a million hours of audio data in 128 languages for self-supervised pre-training and comes in sizes ranging from 300 million up to two billion parameters. The pretrained checkpoints are available on the 🤗 Hub.
Similar to BERT's masked language modeling objective, XLS-R learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network during self-supervised pre-training (see the diagram on the left below).
For fine-tuning, a single linear layer is added on top of the pre-trained network to train the model on labeled data of audio downstream tasks such as speech recognition, speech translation, and audio classification (see the diagram on the right below).

XLS-R shows impressive improvements over previous state-of-the-art results on speech recognition, speech translation, and speaker/language identification; cf. Tables 3-6, Tables 7-10, and Tables 11-12, respectively, of the official paper.
In this notebook, we give a detailed explanation of how XLS-R, more specifically the pre-trained checkpoint Wav2Vec2-XLS-R-300M, can be fine-tuned for ASR.
XLS-R is fine-tuned using Connectionist Temporal Classification (CTC), which is an algorithm that is used to train neural networks for sequence-to-sequence problems, such as ASR and handwriting recognition.
I highly recommend reading the well-written blog post Sequence Modeling with CTC (2017) by Awni Hannun.
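To make the CTC idea concrete, here is a small pure-Python sketch (the function names are our own, and token id 0 is assumed to be the blank, as __PAD__ will be in this notebook's vocabulary). CTC scores a transcription by summing, over every frame-level path that collapses to it (merge consecutive repeats, then drop blanks), the product of the per-frame probabilities. `ctc_prob_bruteforce` enumerates those paths directly, which is only feasible for toy inputs; `ctc_prob_forward` computes the same sum with the standard CTC forward (alpha) recursion over the label sequence with blanks inserted. It is an illustration only, not how PyTorch or 🤗 Transformers implement the loss (PyTorch's CTC loss, for instance, operates on log-probabilities for numerical stability).

```python
import itertools
from typing import List, Sequence

BLANK = 0  # assumption: id 0 is the CTC blank, like __PAD__ in this notebook


def collapse(path: Sequence[int], blank: int = BLANK) -> List[int]:
    """CTC's many-to-one map: merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for tok in path:
        if tok != prev and tok != blank:
            out.append(tok)
        prev = tok
    return out


def ctc_prob_bruteforce(probs, target, blank=BLANK):
    """Sum the probability of every path that collapses to `target` (exponential time)."""
    num_frames, vocab_size = len(probs), len(probs[0])
    total = 0.0
    for path in itertools.product(range(vocab_size), repeat=num_frames):
        if collapse(path, blank) == list(target):
            p = 1.0
            for t, tok in enumerate(path):
                p *= probs[t][tok]
            total += p
    return total


def ctc_prob_forward(probs, target, blank=BLANK):
    """Same quantity via the CTC forward (alpha) recursion in O(T * S) time."""
    # Extended label sequence: a blank before, between, and after every target token.
    ext = [blank]
    for tok in target:
        ext += [tok, blank]
    S = len(ext)
    # alpha[s]: total probability of path prefixes ending in ext[s] at the current frame.
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]
    for t in range(1, len(probs)):
        new_alpha = [0.0] * S
        for s in range(S):
            acc = alpha[s]  # stay on the same symbol
            if s >= 1:
                acc += alpha[s - 1]  # advance by one symbol
            # Skipping a blank is only allowed between two *different* labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                acc += alpha[s - 2]
            new_alpha[s] = acc * probs[t][ext[s]]
        alpha = new_alpha
    # Valid paths end either on the last label or on the trailing blank.
    return alpha[-1] + (alpha[-2] if S > 1 else 0.0)


uniform = [[0.5, 0.5], [0.5, 0.5]]  # 2 frames, vocab = {blank, "a"}
print(ctc_prob_forward(uniform, [1]))  # 0.75: paths (blank,a), (a,blank), (a,a)
```

The CTC loss minimized during fine-tuning is the negative log of this probability, averaged over the batch when `ctc_loss_reduction="mean"`.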
First, let’s try to get a good GPU in Colab! With Google Colab’s free version, it is sadly becoming much harder to get access to a good GPU. With Google Colab Pro, however, one should easily get either a V100 or P100 GPU.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)
Sun Oct 23 02:39:42 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 38C P8 11W / 70W | 0MiB / 15109MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Before we start, let’s install mutually compatible versions of datasets, transformers, and PyTorch. We also need torchaudio to load audio files, evaluate to compute metrics, and jiwer to evaluate our fine-tuned model using the word error rate (WER) metric \({}^1\).
!pip --no-cache-dir install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
!pip --no-cache-dir install transformers==4.23.1
!pip --no-cache-dir install datasets==2.6.0
!pip --no-cache-dir install evaluate==0.3.0
!pip --no-cache-dir install jiwer
If you are training in an environment like Colab, where storage space is limited and the environment is temporary, it is strongly recommended to save your model checkpoints somewhere else. Since we’re working with Hugging Face Transformers, the Hugging Face Hub is a natural choice. The 🤗 Hub has integrated version control, so you can be sure that no model checkpoint gets lost during training.
To do so, you have to store your authentication token from the Hugging Face website (sign up here if you haven’t already!).
from huggingface_hub import notebook_login
notebook_login()
Login successful
Your token has been saved to /root/.huggingface/token
The Hugging Face Hub requires Git LFS (Large File Storage) to upload large model files.
!apt install git-lfs
\({}^1\) In the paper, the model was evaluated using the phoneme error rate (PER), but by far the most common metric in ASR is the word error rate (WER). To keep this notebook as general as possible we decided to evaluate the model using WER.
ASR models transcribe speech to text, which means that we need both a feature extractor that processes the speech signal into the model’s input format, e.g. a feature vector, and a tokenizer that processes the model’s output format into text.
In 🤗 Transformers, the XLS-R model is thus accompanied by both a tokenizer, called Wav2Vec2CTCTokenizer, and a feature extractor, called Wav2Vec2FeatureExtractor.
Let’s start by creating the tokenizer to decode the predicted output classes to the output transcription.
Wav2Vec2CTCTokenizer
A pre-trained XLS-R model maps the speech signal to a sequence of context representations as illustrated in the figure above. However, for speech recognition the model has to map this sequence of context representations to its corresponding transcription, which means that a linear layer has to be added on top of the transformer block (shown in yellow in the diagram above). This linear layer is used to classify each context representation into a token class, analogous to how, e.g., after pretraining a linear layer is added on top of BERT’s embeddings for further classification; see the ‘BERT’ section of this blog post.
The output size of this layer corresponds to the number of tokens in the vocabulary, which does not depend on XLS-R’s pretraining task, but only on the labeled dataset used for fine-tuning. So, in the first step, we will take a look at the chosen OpenSLR Nepali ASR dataset and define a vocabulary based on its transcriptions.
First, let’s go to the official OpenSLR website for Nepali ASR. This data set contains transcribed audio data for Nepali. It consists of zip files containing FLAC (an audio file format) files, and a TSV file. The file utt_spk_text.tsv contains a FileID, an anonymized UserID, and the transcription of the audio in the file. The data set has been manually quality checked, but there might still be errors.
Since downloading, extracting, and preprocessing take a lot of work and time, I’ve uploaded the preprocessed dataset along with the original dataset to the Hugging Face Hub so that we can use the Hugging Face Datasets API. In summary, these are the steps I’ve taken to prepare the dataset:
import torchaudio
# The pretrained Wav2Vec2 model was trained on speech sampled at 16 kHz
SAMPLING_RATE = 16000
def process_audio_file(orig_path, new_path):
    """Read and process the file in `orig_path` and save it to `new_path`"""
    waveform, sampling_rate = torchaudio.load(orig_path)
    if sampling_rate != SAMPLING_RATE:
        # Resample to 16 kHz if the audio originally has a different sampling rate.
        waveform = torchaudio.functional.resample(waveform, sampling_rate, SAMPLING_RATE)
    # Though ASR models should be resilient to silences at the ends of audio,
    # leading silence is removed with the Voice Activity Detection (VAD)
    # implemented in torchaudio, using default parameters, to reduce the demand
    # for computational resources. Note that torchaudio's vad only trims from the
    # front; trimming trailing silence too would require reversing the waveform,
    # applying vad again, and reversing it back.
    waveform = torchaudio.functional.vad(waveform, sample_rate=SAMPLING_RATE)
    # save the processed audio file to the new location
    torchaudio.save(new_path, waveform, sample_rate=SAMPLING_RATE)
Now, we load the dataset with the datasets API. Since the dataset is not split into train/validation/test sets, the whole dataset is downloaded and split into train and test sets later. When I fine-tuned the Wav2Vec2 model on the preprocessed, cleaned dataset, the model was not learning (WER was always 1 on the validation/test set). Anyone who would like to debug the preprocessed dataset and fine-tuning on it is heartily welcome.
from datasets import load_dataset
DATASET_TYPE = 'original' # change to `original` or `cleaned` for downloading original or cleaned version of openslr dataset
dataset = load_dataset("spktsagar/openslr-nepali-asr-cleaned", name=DATASET_TYPE, split='train')
dataset
Dataset({
features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
num_rows: 157905
})
As the info shows, the dataset contains the following fields: utterance_id, speaker_id, utterance, transcription, and num_frames.
For a description of each field, please read the dataset card here.
Although the transcriptions are fairly clean, some contain non-Nepali characters (English letters). We will remove those examples from our dataset.
import string
def check_english_chars(text):
    """Return True if the text contains any English (ASCII) letters"""
    return any([c in text for c in string.ascii_letters])
# Use dataset filter to remove examples with the above function
dataset = dataset.filter(
    lambda ex: not check_english_chars(ex),
    input_columns=['transcription',],
    with_indices=False, batched=False, batch_size=0,
)
dataset
Dataset({
features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
num_rows: 157904
})
Let’s see the list of all the characters we have now in the dataset.
''.join(sorted(set([c for s in dataset['transcription'] for c in s])))
' !%.;?\\\xa0ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ािीुूृेैॉॊोौ्ॐ॑ॠ।०१२३४५६७८९॰\u200c\u200d\u200e\u200f“'
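You can see there are some characters and symbols that are not spoken and are not needed for transcription: punctuation marks (including the Devanagari danda ।), a backslash, a non-breaking space, zero-width and directional formatting characters (\u200c to \u200f), and a stray quotation mark. We will remove those from the transcriptions.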
remove_chars = ['!', '%', '.', ';', '?', '\\', '।', '\xa0', '\u200c', '\u200d', '\u200e', '\u200f', '“']
def remove_special_characters(row):
    row['transcription'] = ''.join(
        [c for c in row['transcription'] if c not in remove_chars]
    ).strip()
    return row
dataset = dataset.map(remove_special_characters)
''.join(sorted(set([c for s in dataset['transcription'] for c in s])))
' ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ािीुूृेैॉॊोौ्ॐ॑ॠ०१२३४५६७८९॰'
In CTC, it is common to classify speech chunks into letters, so we will do the same here. Let’s extract all distinct letters of the training and test data and build our vocabulary from this set of letters.
We write a mapping function that concatenates all transcriptions into one long transcription and then transforms the string into a set of chars.
It is important to pass the argument batched=True to the map(...) function so that the mapping function has access to all transcriptions at once.
def extract_all_chars(batch):
    all_text = " ".join(batch["transcription"])
    vocab = list(set(all_text))
    return {"vocab": [vocab]}
vocab_all = dataset.map(extract_all_chars, batched=True,
                        batch_size=-1, keep_in_memory=True,
                        remove_columns=dataset.column_names)
vocab_list = sorted(list(set(vocab_all["vocab"][0])))
Finally, we also add a padding token that corresponds to CTC’s “blank token”. The “blank token” is a core component of the CTC algorithm. For more information, please take a look at the “Alignment” section here.
UNK_TOKEN = '__UNK__'
PAD_TOKEN = '__PAD__'
vocab_list = [PAD_TOKEN, UNK_TOKEN, *vocab_list]
Now, we create an enumerated dictionary so that we have token to id mapping.
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
# for printing vocab in single line
', '.join([f"{k}: {v}" for k, v in (vocab_dict.items())])
'__PAD__: 0, __UNK__: 1, : 2, ँ: 3, ं: 4, ः: 5, अ: 6, आ: 7, इ: 8, ई: 9, उ: 10, ऊ: 11, ऋ: 12, ए: 13, ऐ: 14, ओ: 15, औ: 16, क: 17, ख: 18, ग: 19, घ: 20, ङ: 21, च: 22, छ: 23, ज: 24, झ: 25, ञ: 26, ट: 27, ठ: 28, ड: 29, ढ: 30, ण: 31, त: 32, थ: 33, द: 34, ध: 35, न: 36, प: 37, फ: 38, ब: 39, भ: 40, म: 41, य: 42, र: 43, ऱ: 44, ल: 45, व: 46, श: 47, ष: 48, स: 49, ह: 50, ़: 51, ा: 52, ि: 53, ी: 54, ु: 55, ू: 56, ृ: 57, े: 58, ै: 59, ॉ: 60, ॊ: 61, ो: 62, ौ: 63, ्: 64, ॐ: 65, ॑: 66, ॠ: 67, ०: 68, १: 69, २: 70, ३: 71, ४: 72, ५: 73, ६: 74, ७: 75, ८: 76, ९: 77, ॰: 78'
To make it clearer that " " has its own token class, we give it a more visible character, |. We already added an “unknown” token above so that the model can later deal with characters not encountered in the training set.
WORD_DELIMITER = '|'
vocab_dict[WORD_DELIMITER] = vocab_dict[" "]
del vocab_dict[" "]
len(vocab_dict)
79
Cool, now our vocabulary is complete and consists of 79 tokens, which means that the linear layer that we will add on top of the pretrained XLS-R checkpoint will have an output dimension of 79.
Let’s now save the vocabulary as a json file.
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)
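The vocabulary steps above can be condensed into a small, dependency-free sketch: sorted distinct characters, `__PAD__` at id 0, `__UNK__` at id 1, and the space renamed to `|` while keeping its id. The `encode` helper is hypothetical; it maps one character to one id, roughly what the CTC tokenizer does for single characters, but it is not `Wav2Vec2CTCTokenizer`'s code.

```python
from typing import Dict, Iterable, List

PAD_TOKEN, UNK_TOKEN, WORD_DELIMITER = "__PAD__", "__UNK__", "|"


def build_vocab(transcriptions: Iterable[str]) -> Dict[str, int]:
    """Mirror the notebook: sorted distinct chars, PAD=0, UNK=1, ' ' renamed to '|'."""
    chars = sorted(set(" ".join(transcriptions)))
    vocab = {tok: idx for idx, tok in enumerate([PAD_TOKEN, UNK_TOKEN, *chars])}
    if " " in vocab:
        # Keep the space's id but give it a visible name.
        vocab[WORD_DELIMITER] = vocab.pop(" ")
    return vocab


def encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Hypothetical helper: one id per character, spaces -> '|', unseen chars -> UNK."""
    return [vocab.get(WORD_DELIMITER if c == " " else c, vocab[UNK_TOKEN]) for c in text]


vocab = build_vocab(["पनि एक", "एक"])
print(encode("एक पनि", vocab))  # [3, 4, 2, 6, 5, 7]
```

Note that "पनि" is three code points (प, न, ि), which is why it becomes three ids: the vocabulary works on Unicode characters, not on visual syllables.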
In a final step, we use the json file to load the vocabulary into an instance of the Wav2Vec2CTCTokenizer class.
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token=UNK_TOKEN, pad_token=PAD_TOKEN, word_delimiter_token=WORD_DELIMITER)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Wav2Vec2FeatureExtractor
Speech is a continuous signal, and to be treated by computers, it first has to be discretized, which is usually called sampling. The sampling rate plays an important role here, in that it defines how many data points of the speech signal are measured per second. Therefore, sampling with a higher sampling rate results in a better approximation of the real speech signal but also necessitates more values per second.
A pretrained checkpoint expects its input data to have been sampled more or less from the same distribution as the data it was trained on. The same speech signal sampled at two different rates has a very different distribution, e.g., doubling the sampling rate doubles the number of data points per second. Thus, before fine-tuning a pretrained checkpoint of an ASR model, it is crucial to verify that the sampling rate of the data that was used to pretrain the model matches the sampling rate of the dataset used to fine-tune the model.
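A small NumPy illustration of why the sampling rate matters (the 440 Hz tone and the helper names are our own choices): the same one-second tone sampled at 8 kHz and at 16 kHz yields 8,000 versus 16,000 values, and if the 8 kHz samples are read as if they were 16 kHz, the signal appears to last half as long and its pitch appears to double.

```python
import numpy as np


def sample_tone(freq_hz: float, duration_s: float, sampling_rate: int) -> np.ndarray:
    """Sample a pure sine tone at `sampling_rate` samples per second."""
    t = np.arange(int(duration_s * sampling_rate)) / sampling_rate
    return np.sin(2 * np.pi * freq_hz * t)


def dominant_frequency(signal: np.ndarray, assumed_rate: int) -> float:
    """Frequency (Hz) of the strongest FFT bin, given the rate we *believe* the signal has."""
    spectrum = np.abs(np.fft.rfft(signal))
    freqs = np.fft.rfftfreq(len(signal), d=1.0 / assumed_rate)
    return float(freqs[np.argmax(spectrum)])


tone_8k = sample_tone(440, 1.0, 8000)
tone_16k = sample_tone(440, 1.0, 16000)
print(len(tone_8k), len(tone_16k))                 # 8000 16000
print(round(dominant_frequency(tone_8k, 8000), 3))   # 440.0: rate interpreted correctly
print(round(dominant_frequency(tone_8k, 16000), 3))  # 880.0: same samples read as 16 kHz
```

This is the kind of mismatch a 16 kHz checkpoint would see if it were fed audio at another rate without resampling.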
XLS-R was pretrained on audio data of Babel, Multilingual LibriSpeech (MLS), Common Voice, VoxPopuli, and VoxLingua107 at a sampling rate of 16 kHz. As stated earlier, the OpenSLR Nepali ASR dataset already has a sampling rate of 16 kHz.
# Define a global variable to store our sampling rate
SPEECH_SAMPLING_RATE = 16000
Long input sequences require a lot of memory. Because XLS-R is based on self-attention, its memory requirement scales quadratically with the input length for long input sequences (see this Reddit post). In case this demo crashes with an “Out-of-memory” error for you, you might want to use the following code to filter out all sequences that are longer than 5 seconds from the dataset.
MAX_FRAMES = SPEECH_SAMPLING_RATE*5 # 5 sec
dataset = dataset.filter(
lambda ex: ex < MAX_FRAMES,
input_columns=['num_frames',],
with_indices=False, batched=False, batch_size=0,
)
dataset
Dataset({
features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
num_rows: 143974
})
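To get a feel for why long inputs are expensive, we can compute how many frames the transformer actually sees. This sketch assumes the default `Wav2Vec2Config` feature-encoder settings, `conv_kernel=(10, 3, 3, 3, 3, 2, 2)` and `conv_stride=(5, 2, 2, 2, 2, 2, 2)` (my understanding is that the XLS-R checkpoints use the same values, but check the checkpoint's config.json), together with the standard unpadded convolution length formula. Self-attention then scores every frame against every other frame, so the score matrix grows with the square of the number of frames.

```python
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)  # assumption: default Wav2Vec2Config.conv_kernel
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)   # assumption: default Wav2Vec2Config.conv_stride


def num_encoder_frames(num_samples: int) -> int:
    """Output length of the conv feature encoder (no padding, standard conv arithmetic)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


SAMPLING_RATE = 16000
for seconds in (5, 10):
    frames = num_encoder_frames(seconds * SAMPLING_RATE)
    # Each attention head in each layer scores every frame against every other frame.
    print(f"{seconds:>2} s -> {frames} frames -> {frames * frames:,} attention scores per head per layer")
```

Under these assumptions, 5 s of audio becomes 249 frames (62,001 scores per head per layer) and 10 s becomes 499 frames (249,001 scores), about four times as many for twice the audio.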
The length filter worked. Let’s listen to a couple of audio files to better understand the dataset and verify that the audio was correctly loaded.
Note: You can click the following cell a couple of times to listen to different speech samples.
import random
import IPython.display as ipd
sample_idx = random.randint(0, len(dataset) - 1)  # randint includes both endpoints
print(dataset[sample_idx]['transcription'])
ipd.Audio(dataset[sample_idx]['utterance']["array"], autoplay=True, rate=SPEECH_SAMPLING_RATE)
सुवेदी पार्टीको जिल्ला
A Wav2Vec2FeatureExtractor object requires the following parameters to be instantiated:
feature_size: Speech models take a sequence of feature vectors as an input. While the length of this sequence obviously varies, the feature size should not. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal \({}^2\).
sampling_rate: The sampling rate at which the model was trained.
padding_value: For batched inference, shorter inputs need to be padded with a specific value.
do_normalize: Whether the input should be zero-mean-unit-variance normalized or not. Usually, speech models perform better when the input is normalized.
return_attention_mask: Whether the model should make use of an attention_mask for batched inference. In general, XLS-R model checkpoints should always use the attention_mask.
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=SPEECH_SAMPLING_RATE,
padding_value=0.0, do_normalize=True,
return_attention_mask=True)
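What do_normalize, padding_value, and return_attention_mask amount to can be sketched in NumPy. This is a simplified illustration under our own assumptions (each utterance is normalized on its own before padding, with a small epsilon of 1e-7 chosen here for numerical stability); the helper names are hypothetical and this is not the feature extractor's source code.

```python
from typing import List, Tuple

import numpy as np


def zero_mean_unit_var(x, eps: float = 1e-7) -> np.ndarray:
    """Normalize one utterance to zero mean and (approximately) unit variance."""
    x = np.asarray(x, dtype=np.float32)
    return (x - x.mean()) / np.sqrt(x.var() + eps)


def pad_with_mask(arrays: List[np.ndarray], padding_value: float = 0.0) -> Tuple[np.ndarray, np.ndarray]:
    """Pad to the longest utterance; the mask is 1 for real samples, 0 for padding."""
    max_len = max(len(a) for a in arrays)
    values = np.full((len(arrays), max_len), padding_value, dtype=np.float32)
    mask = np.zeros((len(arrays), max_len), dtype=np.int64)
    for i, a in enumerate(arrays):
        values[i, : len(a)] = a
        mask[i, : len(a)] = 1
    return values, mask


batch = [zero_mean_unit_var([0.1, 0.3, -0.2, 0.4]), zero_mean_unit_var([0.5, 0.5, 0.7])]
input_values, attention_mask = pad_with_mask(batch)
print(input_values.shape)       # (2, 4)
print(attention_mask.tolist())  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

The mask is what lets the model tell padded zeros apart from real samples whose normalized value happens to be close to zero.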
Great, XLS-R’s feature extraction pipeline is thereby fully defined!
For improved user-friendliness, the feature extractor and tokenizer are wrapped into a single Wav2Vec2Processor class so that one only needs a model and processor object.
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(
feature_extractor=feature_extractor,
tokenizer=tokenizer
)
If one wants to re-use the just created processor, including the tokenizer and feature extractor, with the fine-tuned model of this notebook, it is strongly advised to upload the processor to the 🤗 Hub. Let’s call the repo to which we will upload the files "wav2vec2-large-xls-r-300m-nepali-openslr":
repo_name = "wav2vec2-large-xls-r-300m-nepali-openslr"
Then upload the processor (tokenizer and feature extractor) to the 🤗 Hub.
processor.push_to_hub(repo_name)
Great, you can see the just created repository under https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-nepali-openslr
So far, we have not looked at the actual values of the speech signal, only at the transcriptions. In addition to transcription, our dataset includes the columns utterance_id, speaker_id, utterance, and num_frames. The utterance column has three fields: path, array, and sampling_rate. path states the absolute path of the audio file, array is the audio decoded as a NumPy array, and sampling_rate is its sampling rate. Let’s take a look.
dataset[45]
{'utterance_id': 'a176fcb0d8',
'speaker_id': '6a6d1',
'utterance': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/6bc81d0b078b9cc240efb7e2885d7c845ea238b71125a16d73b19c06621b39f2/asr_nepali/data/a1/a176fcb0d8.flac',
'array': array([ 0.00030518, 0.00039673, 0.00039673, ..., -0.00057983,
-0.00036621, -0.00027466], dtype=float32),
'sampling_rate': 16000},
'transcription': '० पनि एक',
'num_frames': 36800}
We will wrap the Hugging Face Dataset in a PyTorch dataset so that audio files are loaded and processed lazily, only when an item is requested, since we are restricted by available disk space and memory.
import torch
class NepaliASRProcessedDataset(torch.utils.data.Dataset):
    """Takes an HF dataset and a processor, and processes the audio files
    and transcriptions with the processor only when items are requested
    """
    def __init__(
        self,
        dataset,
        processor,
    ):
        self.dataset = dataset
        self.processor = processor
    def __len__(self):
        """Length of dataset"""
        return len(self.dataset)
    def __getitem__(self, idx):
        """Return processed data at `idx` index."""
        example = self.dataset[idx]
        # Return dict
        return_dict = {}
        # first, process the audio with the Wav2Vec2 feature extractor
        return_dict['input_values'] = self.processor(
            audio=example['utterance']['array'],
            sampling_rate=example['utterance']['sampling_rate'],
            return_attention_mask=False,  # will be calculated during batching
        )['input_values'][0]
        # add the length of the extracted audio features
        return_dict['input_length'] = len(return_dict['input_values'])
        # second, process the transcription with the Wav2Vec2 tokenizer
        return_dict['labels'] = self.processor(
            text=example['transcription'],
            return_attention_mask=False,  # will be calculated during batching
        )['input_ids']
        return return_dict
Since our dataset has no separate split for training and evaluation, we will create one manually, splitting the dataset into a 15% test set and an 85% train set.
test_size = 0.15
dataset = dataset.sort('utterance_id')
split_dict = dataset.train_test_split(test_size=test_size, seed=42)
train_dataset, test_dataset = split_dict['train'], split_dict['test']
train_dataset, test_dataset
(Dataset({
features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
num_rows: 122377
}), Dataset({
features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
num_rows: 21597
}))
Now wrap the Hugging Face train/test splits in the PyTorch dataset class defined above.
train_dataset = NepaliASRProcessedDataset(train_dataset, processor)
test_dataset = NepaliASRProcessedDataset(test_dataset, processor)
The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗’s Trainer for which we essentially need to do the following:
Define a data collator. In contrast to most NLP models, XLS-R has a much larger input length than output length. E.g., a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLS-R requires a special padding data collator, which we will define below.
Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a compute_metrics function accordingly.
Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.
Define the training configuration.
After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.
Let’s start by defining the data collator. The code for the data collator was copied from this example.
Without going into too many details, in contrast to the common data collators, this data collator treats the input_values and labels differently and thus applies two separate padding functions to them. This is necessary because in speech, the input and output are of different modalities, so they should not be treated by the same padding function.
Analogous to the common data collators, the padding tokens in the labels are replaced with -100 so that those tokens are not taken into account when computing the loss.
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import Wav2Vec2Processor
# -100 is the label index that is ignored when computing the loss
LARGE_NEG = -100
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
    """
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths,
        # and need different padding methods
        batch = {}
        input_features = [{"input_values": feature["input_values"]} for feature in features if 'input_values' in feature]
        label_features = [{"input_ids": feature["labels"]} for feature in features if 'labels' in feature]
        if input_features:
            batch.update(self.processor.pad(
                input_features,
                padding=self.padding,
                return_tensors="pt",
            ))
        if label_features:
            labels_batch = self.processor.tokenizer.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )
            # replace padding with -100 (LARGE_NEG) so padded positions are ignored by the loss
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), LARGE_NEG)
            batch["labels"] = labels
        return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
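Stripped of the processor, the label side of the collator boils down to this: pad each id sequence to the longest one in the batch, then overwrite the padded positions with -100 using the attention mask. A NumPy sketch with made-up label ids (the helper name is hypothetical). Using the mask rather than comparing values against the pad id means only positions added by padding are replaced, even if a label sequence legitimately contains id 0.

```python
from typing import List

import numpy as np

IGNORE_INDEX = -100  # same value as LARGE_NEG above


def pad_labels(label_ids: List[List[int]], pad_token_id: int = 0) -> np.ndarray:
    """Pad label sequences to the longest one, then mark padding with IGNORE_INDEX."""
    max_len = max(len(ids) for ids in label_ids)
    padded = np.full((len(label_ids), max_len), pad_token_id, dtype=np.int64)
    attention_mask = np.zeros_like(padded)
    for i, ids in enumerate(label_ids):
        padded[i, : len(ids)] = ids
        attention_mask[i, : len(ids)] = 1
    # NumPy equivalent of masked_fill(attention_mask.ne(1), -100) in the collator.
    return np.where(attention_mask == 1, padded, IGNORE_INDEX)


print(pad_labels([[17, 18, 2, 19], [41, 52]]).tolist())
# [[17, 18, 2, 19], [41, 52, -100, -100]]
```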
Next, the evaluation metric is defined. As mentioned earlier, the predominant metric in ASR is the word error rate (WER), hence we will use it in this notebook as well.
import evaluate
import numpy as np
wer_metric = evaluate.load("wer")
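WER is the word-level edit distance (substitutions + deletions + insertions) between prediction and reference, divided by the number of reference words. The pure-Python sketch below (our own helper names) shows what the metric measures, using two reference/prediction pairs that appear in the results at the end of this notebook. To my knowledge, `evaluate`'s `wer` also pools errors and reference words over all pairs rather than averaging per-sentence scores, and this sketch does the same; the notebook itself keeps using `wer_metric`.

```python
from typing import List


def word_edit_distance(reference: List[str], hypothesis: List[str]) -> int:
    """Minimum substitutions + deletions + insertions turning reference into hypothesis."""
    prev = list(range(len(hypothesis) + 1))  # distances from the empty reference prefix
    for i, ref_word in enumerate(reference, start=1):
        curr = [i] + [0] * len(hypothesis)
        for j, hyp_word in enumerate(hypothesis, start=1):
            curr[j] = min(
                prev[j] + 1,  # deletion of ref_word
                curr[j - 1] + 1,  # insertion of hyp_word
                prev[j - 1] + (ref_word != hyp_word),  # substitution (or match, cost 0)
            )
        prev = curr
    return prev[-1]


def corpus_wer(references: List[str], predictions: List[str]) -> float:
    """Total word errors divided by total reference words, over all pairs."""
    errors = sum(word_edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    words = sum(len(r.split()) for r in references)
    return errors / words


refs = ["प्रतिशतले वृद्धि", "उनीहरू जहाँ जुन"]
preds = ["प्रतिशदले वृद्धि", "उनीहरू जाँ जुन"]
print(corpus_wer(refs, preds))  # 0.4 (2 substituted words out of 5 reference words)
```

A single wrong character makes the whole word count as an error, which is why WER can look harsh for a character-level model whose outputs are nearly right.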
The model will return a sequence of logit vectors \(\mathbf{y}_1, \ldots, \mathbf{y}_m\) with \(\mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0]\) and \(n \gg m\).
Each logit vector \(\mathbf{y}_i\) contains the unnormalized log-probabilities (logits) for each token in the vocabulary we defined earlier, thus \(\text{len}(\mathbf{y}_i) =\) config.vocab_size. We are interested in the most likely prediction of the model and thus take the argmax(...) of the logits. Also, we transform the encoded labels back to the original string by replacing LARGE_NEG with the pad_token_id and decoding the ids while making sure that consecutive tokens are not grouped into the same token in CTC style \({}^1\).
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == LARGE_NEG] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
Now, we can load the pretrained checkpoint of Wav2Vec2-XLS-R-300M. The tokenizer’s pad_token_id must be used to define the model’s pad_token_id, which in the case of Wav2Vec2ForCTC is also CTC’s blank token \({}^2\). To save GPU memory, we enable PyTorch’s gradient checkpointing and also set the loss reduction to “mean”.
Because the dataset is quite large (~100h of data) and this ASR dataset is quite noisy, fine-tuning Facebook’s wav2vec2-xls-r-300m checkpoint seems to require some hyper-parameter tuning. Therefore, one has to play around a bit with different values for dropout, SpecAugment’s masking probability, layer dropout, and the learning rate until training seems stable enough.
Note: Since I was not able to run hyperparameter optimization on Colab, I’m not sure whether the current set of hyperparameters is the best. Feel free to adapt those parameters and let me know. I’ve used the default ones in wav2vec2 models.
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-xls-r-300m",
attention_dropout=0.1,
hidden_dropout=0.1,
feat_proj_dropout=0.0,
mask_time_prob=0.075,
layerdrop=0.1,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer),
)
The first component of XLS-R consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the paper does not need to be fine-tuned anymore.
Thus, we can set the requires_grad to False for all parameters of the feature extraction part.
model.freeze_feature_encoder()
In a final step, we define all parameters related to training. To give more explanation on some of the parameters:
group_by_length makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training by heavily reducing the overall number of useless padding tokens that are passed through the model.
learning_rate and weight_decay were heuristically tuned until fine-tuning became stable. Note that those parameters strongly depend on the dataset used and might be suboptimal for other speech datasets.
For more explanations on other parameters, one can take a look at the docs.
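The effect of group_by_length can be illustrated with a toy experiment on synthetic lengths (not this dataset's statistics): count how many padded positions random batches need compared with batches of similar lengths. `length_grouped_batches` is a deliberately simple stand-in (sort, chunk, shuffle the batch order); the Trainer's own length-grouped sampler is more elaborate and keeps more randomness, so treat this only as an illustration of the idea.

```python
import random
from typing import List


def padding_cost(lengths: List[int], batches: List[List[int]]) -> int:
    """Padded positions needed if every batch is padded to its longest item."""
    return sum(max(lengths[i] for i in b) * len(b) - sum(lengths[i] for i in b) for b in batches)


def random_batches(n: int, batch_size: int, rng: random.Random) -> List[List[int]]:
    """Shuffle all indices and cut them into consecutive batches."""
    idx = list(range(n))
    rng.shuffle(idx)
    return [idx[k:k + batch_size] for k in range(0, n, batch_size)]


def length_grouped_batches(lengths: List[int], batch_size: int, rng: random.Random) -> List[List[int]]:
    """Toy grouping: sort indices by length, cut into batches, shuffle the batch order."""
    idx = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [idx[k:k + batch_size] for k in range(0, len(idx), batch_size)]
    rng.shuffle(batches)  # batches still arrive in random order during training
    return batches


rng = random.Random(0)
# Synthetic lengths in samples (1 s to 5 s at 16 kHz), not real dataset statistics.
lengths = [rng.randint(16000, 80000) for _ in range(1000)]
print("random batches :", padding_cost(lengths, random_batches(len(lengths), 16, rng)))
print("grouped batches:", padding_cost(lengths, length_grouped_batches(lengths, 16, rng)))
```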
During training, a checkpoint will be uploaded asynchronously to the Hub every 800 training steps (save_steps=800 with hub_strategy='checkpoint'). This allows you to play around with the demo widget even while your model is still training.
Note: If one does not want to upload the model checkpoints to the hub, simply set push_to_hub=False.
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=repo_name,
group_by_length=True,
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
evaluation_strategy="steps",
num_train_epochs=10,
gradient_checkpointing=True,
fp16=True,
save_steps=800,
eval_steps=800,
logging_steps=800,
learning_rate=3e-4,
warmup_steps=500,
save_total_limit=2,
push_to_hub=True,
hub_strategy='checkpoint',
resume_from_checkpoint='last-checkpoint',
)
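A quick back-of-the-envelope check of what these arguments imply, assuming a single GPU, the 122,377-example training split created above, and one optimizer step per gradient_accumulation_steps batches, with any leftover partial accumulation not counted. The helper is hypothetical and only approximates the Trainer's bookkeeping.

```python
import math


def training_schedule(num_examples: int, per_device_batch: int, grad_accum: int,
                      epochs: int, save_steps: int, num_devices: int = 1) -> dict:
    """Rough optimizer-step arithmetic for a Trainer-style loop (assumptions in the text)."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = epochs * steps_per_epoch
    return {
        "effective_batch_size": per_device_batch * grad_accum * num_devices,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "checkpoints_saved": total_steps // save_steps,
    }


print(training_schedule(122377, per_device_batch=16, grad_accum=2, epochs=10, save_steps=800))
# {'effective_batch_size': 32, 'steps_per_epoch': 3824, 'total_steps': 38240, 'checkpoints_saved': 47}
```

Under these assumptions, 10 epochs is about 38,240 optimizer steps and roughly 47 saves, each of which triggers a Hub push with these settings. The same helper shows that halving per_device_train_batch_size to 8 while doubling gradient_accumulation_steps to 4, a common response to out-of-memory errors, leaves the effective batch size and the step count unchanged.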
Now, all instances can be passed to Trainer and we are ready to start training!
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=processor.feature_extractor,
)
\({}^1\) To allow models to become independent of the speaking rate, in CTC, consecutive identical tokens are simply grouped into a single token. However, the encoded labels should not be grouped when decoding, since they don’t correspond to the predicted tokens of the model, which is why the group_tokens=False parameter has to be passed. If we didn’t pass this parameter, a word like "hello" would be incorrectly decoded as "helo".
\({}^2\) The blank token allows the model to predict a word such as "hello" by forcing it to insert the blank token between the two l’s. A CTC-conform prediction of "hello" by our model would be [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD].
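Both footnotes can be checked in a few lines of plain Python. The sketch assumes a hypothetical character-level vocabulary in which "[PAD]" is the blank; `ctc_greedy_decode` (our own name) merges runs of identical tokens and then drops blanks, which is what decoding argmax ids with grouping enabled amounts to. It simplifies what batch_decode does and is not its implementation.

```python
from itertools import groupby
from typing import List

BLANK = "[PAD]"  # assumption: the padding token doubles as the CTC blank, as in this notebook


def ctc_greedy_decode(frame_tokens: List[str], blank: str = BLANK) -> str:
    """Collapse runs of identical tokens, then remove blanks."""
    return "".join(tok for tok, _ in groupby(frame_tokens) if tok != blank)


# Frame-level prediction from footnote 2.
prediction = [BLANK, BLANK, "h", "e", "e", "l", "l", BLANK, "l", "o", "o", BLANK]
print(ctc_greedy_decode(prediction))  # hello

# Labels are plain characters, not frame predictions; grouping them anyway (footnote 1) loses a letter.
labels = list("hello")
print(ctc_greedy_decode(labels))  # helo
print("".join(labels))            # hello, which is what group_tokens=False preserves
```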
Training will take multiple hours depending on the GPU allocated to this notebook.
In case you want to use this google colab to fine-tune your model, you should make sure that your training doesn’t stop due to inactivity. A simple hack to prevent this is to paste the following code into the console of this tab (right mouse click -> inspect -> Console tab and insert code).
function ConnectButton(){
    console.log("Connect pushed");
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click()
}
setInterval(ConnectButton,60000);
Depending on which GPU was allocated to your Colab session, you might see an "out-of-memory" error here. In this case, it’s probably best to reduce per_device_train_batch_size to 8 or even less and increase gradient_accumulation_steps accordingly.
trainer.train(
resume_from_checkpoint=True, # resume from the last checkpoint in output_dir; set to False on the first run (no checkpoint yet) or to start from scratch
)
If the training loss and validation WER go down nicely, you can now upload the result of the training to the 🤗 Hub; just execute:
trainer.push_to_hub()
You can now share this model with all your friends, family, favorite pets: they can all load it with the identifier “your-username/the-name-you-picked” so for instance:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("spktsagar/wav2vec2-large-xls-r-300m-nepali-openslr")
processor = Wav2Vec2Processor.from_pretrained("spktsagar/wav2vec2-large-xls-r-300m-nepali-openslr")
For more examples of how XLS-R can be fine-tuned, please take a look at the official speech recognition examples.
As a final check, let’s load the model and verify that it indeed has learned to transcribe Nepali speech.
Let’s first load the pretrained checkpoint.
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)
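Now, we take five random examples from the test set, run them through the model, and take the argmax(...) of the logits to retrieve the predicted token ids. Those token ids are then decoded into transcriptions. Note that trainer.predict uses the trainer’s own model, i.e. the one fine-tuned above, rather than the copy just loaded.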
# only take 5 random examples from the test set
pred = trainer.predict(
torch.utils.data.Subset(
test_dataset,
random.sample(list(range(len(test_dataset))), 5)
)
)
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
pred.label_ids[pred.label_ids == LARGE_NEG] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
Now, let’s look at the reference and predicted transcriptions side by side.
list(zip(label_str, pred_str))[:5]
[('उपस्थित गराए', 'उपस्थित गराए'),
('प्रतिशतले वृद्धि', 'प्रतिशदले वृद्धि'),
('उनीहरू जहाँ जुन', 'उनीहरू जाँ जुन'),
('टेम्प्लेटहरू पनि', 'टेम्प्लेटहरू पनि'),
('रूपमा खाइन्छ', 'रूपमा खाइन्छ')]
Alright! The transcription can definitely be recognized from our prediction, but it is not perfect yet. Training the model a bit longer, spending more time on the data preprocessing, and especially using a language model for decoding would certainly improve the model’s overall performance.
You can play with and use the model I trained from here. Thank you!!
Humans can recognize objects from a few examples. Having seen many animal images before, we can easily recognize novel animals from very few images. Deep learning models, by contrast, usually have to be trained from scratch to learn a new task. Transfer learning and fine-tuning are techniques to adapt trained models to a new task. The problem with them is that such models are trained on only a single task; adapting to a completely new task requires manually verifying the similarity between the tasks. One recent approach to this problem is meta-learning. Meta-learning schemes aim to share information between models trained on similar tasks, using adaptation strategies to extract patterns that are useful for more than one task.
Learning from a small number of samples presents another difficulty for machine learning. Few-shot and zero-shot learning frameworks teach models to generalize to new datasets from relatively few samples. A K-shot classification problem, for instance, requires the model to generalize from just K labeled examples per class. In the extreme case of zero-shot learning, the model must generalize from zero examples.
When the few-shot setting is added to a meta-learning scheme, the model must adapt to new tasks with few examples and few training iterations. Consider a model trained on handwritten character recognition tasks in various languages. Given a completely new language, only a few handwritten examples per character, and very few training iterations, it must generalize to that language. To achieve this, the model or learner (a character recognition model) is trained on a series of tasks (e.g., character recognition in different languages) during the meta-learning phase. Instead of forcing the learner to focus on one specific task, we let it acquire intrinsic features that are broadly applicable to all tasks in the task distribution \(P(\mathrm{T})\). Our goal is to find model parameters (meta-learning phase) that are sensitive to task changes: small changes in the parameters (adaptation phase) should produce large improvements in the loss of any task drawn from \(P(\mathrm{T})\).
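The two phases can be seen in a minimal NumPy sketch of (second-order) MAML on a toy problem. The sketch assumes each task \(T\) has the quadratic loss \(L_T(\theta) = (\theta - c_T)^2\) with a task-specific target \(c_T\). The toy loss, step sizes, and function names are illustrative assumptions, not part of the MetaGAN setup. The point is only to show an inner adaptation step followed by a meta-update that differentiates through that step.

```python
import numpy as np


def task_grad(theta, c):
    """Gradient of the toy task loss L_T(theta) = (theta - c)^2."""
    return 2.0 * (theta - c)


def maml_toy(task_targets, alpha=0.1, beta=0.05, meta_steps=500, seed=0):
    """Second-order MAML on scalar quadratic tasks (illustrative only)."""
    rng = np.random.default_rng(seed)
    theta = rng.normal()  # meta-parameters: a single scalar here
    for _ in range(meta_steps):
        meta_grad = 0.0
        for c in task_targets:
            # Adaptation phase: one inner gradient step on task T
            theta_adapted = theta - alpha * task_grad(theta, c)
            # Meta-learning phase: gradient of the post-adaptation loss
            # w.r.t. the original theta; d(theta_adapted)/d(theta) = 1 - 2*alpha
            meta_grad += task_grad(theta_adapted, c) * (1.0 - 2.0 * alpha)
        theta -= beta * meta_grad / len(task_targets)
    return theta


targets = [-1.0, 0.5, 2.0, 3.5]
theta_star = maml_toy(targets)
print(round(theta_star, 3))  # close to the mean target, 1.25
```

Every toy task here has the same curvature, so the meta-optimum is simply the mean target. With real networks, the meta-gradient is taken through the adaptation steps in the same way. That is what makes the learned parameters "responsive to task changes".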
Fig 1: Diagram showing the meta-learning and adaptation phases. Source: MAML Paper
MetaGAN is a simple and general framework for few-shot learning problems. Consider a K-shot (number of samples per class in a training task), N-way (number of classes in a task) classification problem. Here, a conditional task generator produces samples that are indistinguishable from true data samples drawn from the task used to condition it. We then train the classifier (the GAN discriminator, but with N output units for N-way classification) and the generator in an adversarial setup.
What do we gain by using a GAN in few-shot meta-learning? In a few-shot classification problem, the model tries to find a decision boundary for each task with just a few samples per class. With so few samples, many decision boundaries are possible, but most of them will not generalize well. Meta-learning tries to mitigate this by learning a strategy, shared across tasks, for forming a decision boundary from a few samples, in the hope that the strategy generalizes well to new tasks. Although this is plausible, some problems remain. For example, some objects look more similar than others: it may be easier to form a decision boundary between a Chinese character and an English letter than between a Chinese character and a Korean character. If the training data contains no tasks that separate Chinese characters from Korean characters, the learner may struggle to extract the right features to separate these two classes. On the other hand, expecting every kind of class combination to appear during training leads to a combinatorial explosion. This is where MetaGAN helps. The generator in MetaGAN generates fake data, which forces the classifier (discriminator) to learn a sharper decision boundary. Instead of only learning to separate Chinese and Korean characters, the classifier must also distinguish real from fake Chinese and Korean characters, as shown in the figure below. Moreover, we don’t need the generator to produce data that exactly matches the true data. It is actually better if the generated samples lie slightly off the true data manifold.
Fig 2: Decision boundary with MetaGAN (left) and without MetaGAN (right). Colors represent different classes: gray denotes fake classes, and green and blue can be real character images from different languages. + and - denote real and fake samples. Source: MetaGAN Paper
The objective of this blog is to show you how to implement MetaGAN. Some familiarity with GANs is expected before you delve into this implementation, but you will learn about meta-learning and semi-supervised classification along the way. After reading this blog, you’ll see that GANs can serve purposes other than generative modeling. In vanilla GAN training, our only goal is to generate samples that are plausible under the real data distribution, so the discriminator is discarded once training is finished. When the discriminator is extended to output class labels, however, we can use it for supervised and semi-supervised classification. This lets us combine unlabeled data with very few labeled samples to improve performance. With meta-learning on top of that, you will see that our classifier reaches strong performance after very few training iterations.
Start with basic library imports and environment setup.
```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

tf.__version__
```

```
'2.8.2'
```

```python
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
```
We will use the Omniglot dataset, which contains 1623 different handwritten characters from 50 different alphabets. Each character has 20 samples, so the dataset contains 32460 samples in total. This dataset is widely used for few-shot learning problems and is available in TensorFlow Datasets. Let’s load it.
```python
import tensorflow_datasets as tfds

omniglot, info = tfds.load('omniglot', with_info=True)
info  # see the info about the omniglot dataset here
```

```
tfds.core.DatasetInfo(
    name='omniglot',
    full_name='omniglot/3.0.0',
    description="""
    Omniglot data set for one-shot learning. This dataset contains 1623 different
    handwritten characters from 50 different alphabets.
    """,
    homepage='https://github.com/brendenlake/omniglot/',
    data_path='~/tensorflow_datasets/omniglot/3.0.0',
    file_format=tfrecord,
    download_size=17.95 MiB,
    dataset_size=12.29 MiB,
    features=FeaturesDict({
        'alphabet': ClassLabel(shape=(), dtype=tf.int64, num_classes=50),
        'alphabet_char_id': tf.int64,
        'image': Image(shape=(105, 105, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=1623),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'small1': <SplitInfo num_examples=2720, num_shards=1>,
        'small2': <SplitInfo num_examples=3120, num_shards=1>,
        'test': <SplitInfo num_examples=13180, num_shards=1>,
        'train': <SplitInfo num_examples=19280, num_shards=1>,
    },
    citation="""@article{lake2015human,
      title={Human-level concept learning through probabilistic program induction},
      author={Lake, Brenden M and Salakhutdinov, Ruslan and Tenenbaum, Joshua B},
      journal={Science},
      volume={350},
      number={6266},
      pages={1332--1338},
      year={2015},
      publisher={American Association for the Advancement of Science}
    }""",
)
```
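The split sizes in the `info` output are consistent with 20 samples per character. Here is a quick sanity check in plain Python. The per-split class counts are derived by division, not read from the dataset.

```python
samples_per_char = 20
total_chars = 1623
splits = {"train": 19280, "test": 13180}

assert total_chars * samples_per_char == 32460
assert sum(splits.values()) == 32460  # train + test cover every sample

# Number of distinct characters (classes) in each split
classes_per_split = {name: n // samples_per_char for name, n in splits.items()}
print(classes_per_split)  # {'train': 964, 'test': 659}
```

With `k_support + k_query = 20`, each training character forms exactly one label group. That is consistent with the task-level shuffle buffer of 964 used in the training pipeline.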
The following few-shot learning problem statement (taken from the MetaGAN paper) helps in preparing the dataset for MetaGAN.
Given a distribution of tasks \(P(T)\), a sample task \(T\) from \(P(T)\) is given by a joint distribution \(P^T_{X \times Y}(x, y)\), where the task is to predict \(y\) given \(x\). We have a set of training sample tasks \(\{T_i\}^N_{i=1}\). Each training sample task \(T\) is a tuple \(T = (S_T, Q_T)\), where the support set is denoted as \(S_T = S^s_T \cup S^u_T\), and the query set is denoted as \(Q_T = Q^s_T \cup Q^u_T\). The supervised support set \(S^s_T = \{(x_1, y_1), (x_2, y_2), \cdots, (x_{N \times K}, y_{N \times K})\}\) contains \(K\) labeled samples from each of the \(N\) classes (this is usually known as \(K\)-shot \(N\)-way classification). The optional unlabeled support set \(S^u_T = \{x^u_1, x^u_2, \cdots, x^u_M\}\) contains unlabeled samples from the same set of \(N\) classes and can be empty in purely supervised cases. \(Q^s_T = \{(x_1, y_1), (x_2, y_2), \cdots, (x_T, y_T)\}\) is the supervised query dataset. \(Q^u_T = \{x_1, x_2, \cdots, x_P\}\) is the optional unlabeled query dataset. The objective of the model is to minimize the loss of its predictions on a query set, given the support set as input.
Simply put, for each character (class) in the Omniglot dataset, we keep K-shot samples (`k_support` in code) in the support set and `k_query` samples in the query set. Since the Omniglot dataset has no unlabeled samples, we don’t need to prepare unlabeled support and query sets. Whenever we require an unlabeled set, we use the corresponding labeled support set or labeled query set in its place. A task is prepared by selecting N classes, each with `k_support` support samples and `k_query` query samples. The support set is used to fine-tune the learner (adaptation), and the query set is used to evaluate the adapted learner. The evaluation losses accumulated on the query sets of a batch of tasks then update the learner. This moves it toward a point from which very few gradient updates suffice to adapt it to a new task.
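The task construction described above can be sketched without TensorFlow. The function below, `sample_task`, is a hypothetical helper that is not used elsewhere in this post. It draws `n_way` classes, shuffles each class's samples, splits them into `k_support` support and `k_query` query samples, and relabels the classes as `0` to `n_way - 1`. Strings stand in for images.

```python
import random


def sample_task(samples_by_class, n_way, k_support, k_query, rng=random):
    """Build one N-way task as (support, query) lists of (sample, new_label)."""
    eligible = [c for c, xs in samples_by_class.items()
                if len(xs) >= k_support + k_query]
    classes = rng.sample(eligible, n_way)  # n_way distinct classes
    support, query = [], []
    for new_label, c in enumerate(classes):  # relabel to 0 .. n_way - 1
        # random.sample returns a shuffled selection without replacement
        xs = rng.sample(samples_by_class[c], k_support + k_query)
        support += [(x, new_label) for x in xs[:k_support]]
        query += [(x, new_label) for x in xs[k_support:]]
    return support, query


# Toy data: 10 classes with 20 "images" each
data = {c: [f"char{c}_img{i}" for i in range(20)] for c in range(10)}
support, query = sample_task(data, n_way=5, k_support=5, k_query=15,
                             rng=random.Random(0))
print(len(support), len(query))  # 25 75
```

Support and query samples of the same class never overlap, because both come from one draw without replacement.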
We will do \(5\)-way, \(5\)-shot, \(15\)-query meta-learning with 32 tasks per meta-update. The Omniglot images are \(105 \times 105\) with \(3\) channels, and we resize them to \(28 \times 28\). All three channels of an Omniglot image are identical, so a single channel is enough. We create a task by randomly selecting \(5\) labels from the training set, regardless of their alphabets. This means a single task can contain characters from different languages.
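Concretely, this preprocessing keeps one channel and maps pixel values from `[0, 255]` to `[-1, 1]` with `(x - 127.5) / 127.5`, the same scaling `resize_and_normalize` applies in this post. The NumPy sketch below uses a hypothetical random image and omits the 28x28 resize, which the real pipeline leaves to `tf.image.resize`. It also counts the images per task and per meta-update for the chosen hyperparameters.

```python
import numpy as np

n_way, k_support, k_query, task_batch = 5, 5, 15, 32

# Hypothetical 105x105 RGB image whose three channels are identical,
# like the Omniglot images described above
rng = np.random.default_rng(0)
gray = rng.integers(0, 256, size=(105, 105, 1)).astype(np.float32)
image = np.repeat(gray, 3, axis=-1)

# Keep only the first channel and map [0, 255] -> [-1, 1]
single = (image[..., :1] - 127.5) / 127.5
print(single.shape, single.min(), single.max())

# Number of images per task and per meta-update
support_per_task = n_way * k_support  # 25
query_per_task = n_way * k_query      # 75
print(support_per_task * task_batch, query_per_task * task_batch)  # 800 2400
```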
```python
n_way = 5
k_support = 5  # alias of K in K-shot
k_query = 15
task_batch = 32  # number of tasks for a single meta-update
image_size = [28, 28]
num_of_channels = 1
noise_dim = 100  # dimension of the latent space from which noise is sampled
```
The dataset preparation steps are:

1. Group the samples by label into groups of `(k_support + k_query)` samples, and filter out groups with fewer than `(k_support + k_query)` samples. Before that, shuffle the dataset so that samples filtered out in one iteration get a chance to be included in the next.
2. Take `n_way` label groups and batch them to form a single task. Ignore tasks with fewer than `n_way` classes.
3. Relabel the `n_way` different classes in a task as `0` to `n_way - 1`.
4. Split each `n_way` task into a support set and a query set.
5. Take `task_batch` tasks and batch them into one. One task batch is used for one meta-learning step.

```python
def get_images_and_labels(sample):
    """Returns the image and corresponding label from an Omniglot sample.

    An Omniglot sample is a dictionary with the following structure:
    `{alphabet: (), alphabet_char_id: (), image: (105, 105, 3), label: ()}`

    Parameters
    ----------
    sample : `dict` of Omniglot sample

    Returns
    ----------
    image : `Tensor` of dtype `tf.float32`
        Image tensor shaped [105, 105, 3] in `sample` dictionary
    label : `Tensor` of dtype `tf.int64`
        Scalar label tensor in `sample` dictionary
    """
    image = tf.cast(sample['image'], tf.float32)
    label = tf.cast(sample['label'], tf.int64)
    return image, label


def get_label_group_func(dataset):
    """Groups Omniglot samples by label and reduces each group to a batch
    of size `k_support + k_query`.

    This function is itself a `tf.data.Dataset` transformation, so it can be
    passed to `tf.data.Dataset.apply`.

    Parameters
    ----------
    dataset : `tf.data.Dataset` of `(image, label)` pairs

    Returns
    ----------
    dataset : `tf.data.Dataset`
        Dataset whose elements are label groups of `(images, labels)`.
    """
    dataset = dataset.group_by_window(key_func=lambda x, y: y,
                                      reduce_func=lambda _, els: els.batch(k_support + k_query),
                                      window_size=k_support + k_query)
    return dataset


def label_group_filter(images, labels):
    """A predicate to check if a label group of Omniglot images
    has exactly `k_support + k_query` samples. Otherwise we
    cannot make a support set and a query set from this label group,
    so we ignore it.

    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, h, w, c]`
        Images from Omniglot with the same label
    labels : `Tensor` of dtype `tf.int64`
        Shape `[k_support + k_query,]`
        Corresponding labels of input images. They must all be the same here

    Returns
    ----------
    right_size_label_group : `boolean`
        `True` if images have `k_support + k_query` samples else `False`
    """
    right_size_label_group = tf.shape(images)[0] == (k_support + k_query)
    return right_size_label_group


def resize_and_normalize(images, labels):
    """Resize images and normalize them to `[-1, 1]`.

    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, height, width, channels]`
    labels : `Tensor` of dtype `tf.int64`
        Shape `[k_support + k_query,]`
        Corresponding labels of input `images`.

    Returns
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, *image_size, num_of_channels]`
        Resized and normalized `images` with spatial size `image_size` and values in `[-1, 1]`.
        Only the first `num_of_channels` channels are kept,
        since all channels in the Omniglot dataset are identical.
    labels : `Tensor` of dtype `tf.int64`
        Same as input labels.
    """
    images = tf.image.resize((images[:, :, :, :num_of_channels] - 127.5) / 127.5, image_size)
    return images, labels


def data_augment(images, labels):
    """Data augmentation by rotating all images of a class by the same
    random multiple of 90 degrees, which effectively forms a new class.

    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, height, width, channels]`
        The `k_support + k_query` samples in `images` should be from the same class.
    labels : `Tensor` of dtype `tf.int64`
        Shape `[k_support + k_query,]`
        Corresponding labels of input `images`. Here all labels should be the same.

    Returns
    ----------
    images : `Tensor` of dtype `tf.float32`
        Same shape as input `images`
    labels : `Tensor` of dtype `tf.int64`
        Same as input labels. We can keep the old label for the new class
        since all images are rotated by the same angle.
    """
    rotation = tf.random.uniform([], maxval=4, dtype=tf.int32)
    images = tf.image.rot90(images, k=rotation)
    return images, labels


def relabel(images, labels):
    """Relabel images for `n_way` classification.

    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[n_way, k_query + k_support, height, width, channels]`
        Images for a single task managed in a 5D tensor.
        --> 1st dimension: for `n_way` classes
        --> 2nd dimension: for `k_query + k_support` image samples of a class
        --> Other dimensions: image height, width and channels
    labels : `Tensor` of dtype `tf.int64`
        Shape `[n_way, k_query + k_support]`
        Omniglot labels for images in the same structure as input `images`.
        Label values for images of each class can be anything between 0 and 1622 (total characters - 1)

    Returns
    ----------
    images : `Tensor` of dtype `tf.float32`
        Returns the same input `images`. No changes are required here.
    new_labels : `Tensor` of dtype `tf.int64`
        Shape `[n_way, k_query + k_support]`
        New label values must be in `[0, n_way-1]`
    """
    old_labels_shape = tf.shape(labels)
    new_classes = tf.expand_dims(tf.range(old_labels_shape[0]), -1)
    new_labels = tf.tile(new_classes, [1, old_labels_shape[-1]])
    new_labels = tf.cast(new_labels, tf.int64)
    return images, new_labels


def get_support_query_split_func(shuffle=True):
    """Returns a function that splits a task into a support set and a query set.

    Parameters
    ----------
    shuffle : `boolean`
        Flags whether to shuffle before splitting the `n_way` task into support set and query set.
        If not, take the first `k_support` samples into the support set and the rest into the query set.

    Returns
    ----------
    support_query_split : `function`
        A function that splits a task into a support set and a query set.
    """
    def support_query_split(nway_images, nway_labels):
        """Split an `n_way` task into a `k_support`-shot support set and a `k_query` query set.

        Parameters
        ----------
        nway_images : `Tensor` of dtype `tf.float32`
            Shape `[n_way, k_query + k_support, height, width, channels]`
            `k_query + k_support` images for each class in the `n_way` task.
        nway_labels : `Tensor` of dtype `tf.int64`
            Shape `[n_way, k_query + k_support]`
            N-way labels for images in the same structure as input `nway_images`.

        Returns
        ----------
        support_images : `Tensor` of dtype `tf.float32`
            Shape `[n_way*k_support, height, width, channels]`
            Used for adaptation (K-shot learning).
        support_labels : `Tensor` of dtype `tf.int64`
            Shape `[n_way*k_support]`
            Used for adaptation (K-shot learning).
        query_images : `Tensor` of dtype `tf.float32`
            Shape `[n_way*k_query, height, width, channels]`
            Used for the meta-learning step (testing the adapted learner).
        query_labels : `Tensor` of dtype `tf.int64`
            Shape `[n_way*k_query]`
            Used for the meta-learning step (testing the adapted learner).
        """
        images_shape = tf.shape(nway_images)
        perm = tf.random.shuffle(tf.range(images_shape[1])) if shuffle \
            else tf.range(images_shape[1])
        support_images = tf.gather(nway_images, perm[:k_support], axis=1)
        support_images = tf.reshape(support_images, (-1, images_shape[-3], images_shape[-2], images_shape[-1]))
        support_labels = tf.gather(nway_labels, perm[:k_support], axis=1)
        support_labels = tf.reshape(support_labels, [-1])
        query_images = tf.gather(nway_images, perm[k_support:], axis=1)
        query_images = tf.reshape(query_images, (-1, images_shape[-3], images_shape[-2], images_shape[-1]))
        query_labels = tf.gather(nway_labels, perm[k_support:], axis=1)
        query_labels = tf.reshape(query_labels, [-1])
        return support_images, support_labels, query_images, query_labels
    return support_query_split
```
As mentioned in the dataset preparation steps, we use the functions defined above in a `tf.data.Dataset` pipeline.
As you can see in the Omniglot info above, the dataset has `train`, `test` and other splits (not used here). Note that we do not shuffle the test dataset.
```python
train_dataset_task_grouped = (
    omniglot['train']
    .map(get_images_and_labels)
    .shuffle(19280, reshuffle_each_iteration=True)  # 19280 = size of the train split
    .group_by_window(key_func=lambda x, y: y,  # group by label
                     reduce_func=lambda _, els: els.batch(k_support + k_query),  # batch size is k_support + k_query
                     window_size=k_support + k_query)
    .filter(label_group_filter)
    .map(resize_and_normalize)
    .map(data_augment)
    .shuffle(964, reshuffle_each_iteration=True)  # 964 = number of label groups (classes) in train
    .batch(batch_size=n_way, drop_remainder=True)
    .map(relabel)
    .map(get_support_query_split_func(shuffle=True))
    .batch(batch_size=task_batch)
)

test_dataset_nway_grouped = (
    omniglot['test']
    .map(get_images_and_labels)
    .group_by_window(key_func=lambda x, y: y,  # group by label
                     reduce_func=lambda _, els: els.batch(k_support + k_query),  # batch size is k_support + k_query
                     window_size=k_support + k_query)
    .filter(label_group_filter)
    .map(resize_and_normalize)
    .batch(batch_size=n_way, drop_remainder=True)
    .map(relabel)
    .map(get_support_query_split_func(shuffle=False))
)
```
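The `group_by_window` plus `filter` combination is less common than `map` or `batch`. The rough pure-Python emulation below shows its intent: samples are bucketed by label, a bucket is emitted once it holds `window_size` elements, and leftover partial buckets are flushed at the end and then dropped by the size filter. The function name is hypothetical, and the sketch ignores `tf.data` details such as lazy evaluation and exact output order.

```python
import random
from collections import defaultdict


def group_by_window_then_filter(samples, window_size):
    """Emulate group_by_window(key=label, batch(window_size)) + a size filter.

    samples: iterable of (image, label) pairs.
    Returns a list of (images, labels) groups with exactly window_size items.
    """
    buckets = defaultdict(list)
    groups = []
    for image, label in samples:
        buckets[label].append((image, label))
        if len(buckets[label]) == window_size:  # window full -> emit it
            groups.append(buckets.pop(label))
    groups.extend(buckets.values())  # partial windows flushed at end of input
    # label_group_filter: keep only groups of the right size
    full = [g for g in groups if len(g) == window_size]
    return [([x for x, _ in g], [y for _, y in g]) for g in full]


# Labels 0 and 1 have 20 samples each, label 2 only 7
stream = ([(f"a{i}", 0) for i in range(20)]
          + [(f"b{i}", 1) for i in range(20)]
          + [(f"c{i}", 2) for i in range(7)])
random.Random(0).shuffle(stream)

groups = group_by_window_then_filter(stream, window_size=20)
print(sorted(labels[0] for _, labels in groups))  # [0, 1]
```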
Now define a utility function to generate noise from a normal distribution.
```python
def generate_noise(shape):
    """Generate noise from a normal distribution.

    Parameters
    ----------
    shape : `tuple` of `int` or `tf.int` or both.
        Shape of the noise tensor to generate.

    Returns
    ----------
    noise : `Tensor` of dtype `tf.float32`
        Noise tensor shaped `shape` from a normal distribution.
    """
    noise = tf.random.normal(shape)
    return noise
```
The generator should be able to generate fake data that is close to the real data manifold of a specific task \(T\), which means we need to condition the generator on the task. To do so, we compress the information in the task’s support set with a dataset encoder \(E\) into a vector \(h_T\), which contains sufficient statistics for the data distribution of task \(T\). The task representation vector \(h_T\) is then concatenated with the noise input \(z\) and fed to the generator network. The task encoder contains two modules. The instance encoder encodes each sample. The feature aggregation module produces the representation vector \(h_T\) for the whole task set using an aggregation scheme such as averaging or max-pooling.
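Stripped of the convolutional layers, task encoding and conditioning reduce to a few array operations. The NumPy sketch below makes one assumption for illustration only: a random linear map followed by a LeakyReLU stands in for the instance encoder. It mean-pools over the support set to get \(h_T\), then tiles \(h_T\) and concatenates it with the noise \(z\). This mirrors what `TaskEncoder.call` and `ConditionalGenerator.call` do with `tf.reduce_mean`, `tf.tile`, and `Concatenate`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_way, k_support, feat_dim, repr_dim, noise_dim = 5, 5, 28 * 28, 256, 100

# Support set of one task, flattened: (n_way * k_support, 784)
support = rng.normal(size=(n_way * k_support, feat_dim))

# Stand-in instance encoder: linear map + LeakyReLU (slope 0.3, the Keras default)
W = rng.normal(scale=0.05, size=(feat_dim, repr_dim))


def instance_encoder(x):
    h = x @ W
    return np.where(h > 0, h, 0.3 * h)


# Feature aggregation: average over all support samples -> h_T of shape (1, 256)
h_T = instance_encoder(support).mean(axis=0, keepdims=True)

# Condition the generator input: tile h_T for every noise vector and concatenate
z = rng.normal(size=(n_way * k_support, noise_dim))
conditioned = np.concatenate([z, np.tile(h_T, (z.shape[0], 1))], axis=1)
print(h_T.shape, conditioned.shape)  # (1, 256) (25, 356)
```

Averaging makes \(h_T\) independent of the order of the support samples, which is what we want from a representation of a set.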
```python
class TaskEncoder(tf.keras.Model):
    """Takes the support set of a task T and generates a representation vector h_T.

    Parameters
    ----------
    conv_filters : `list` of `int`
        List of the number of filters to use in each `Conv2D` layer.
        The length of the list is the number of `Conv2D` layers to use.
    conv_kernels : `list` of `int` or `tuples` or both.
        List of kernel sizes to use in each corresponding `Conv2D` layer.
    conv_strides : `list` of `int` or `tuples` or both.
        List of strides to use in each corresponding `Conv2D` layer.
    output_units : `int`
        Dimension of the task representation.

    Input shape
    ----------
    N-D tensor with shape: `(n_way*k_support, height, width, channels)`
    The whole support set of a task must be given to this model.

    Output shape
    ----------
    N-D tensor with shape: `(1, output_units)`
    A single representation vector for the whole task set.
    """

    def __init__(self,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 output_units,
                 **kwargs):
        super(TaskEncoder, self).__init__(**kwargs)
        self.conv2d_layers = [tf.keras.layers.Conv2D(filters=f,
                                                     kernel_size=k,
                                                     strides=s,
                                                     padding='same')
                              for f, k, s in zip(conv_filters,
                                                 conv_kernels,
                                                 conv_strides)]
        self.activation_layers = [tf.keras.layers.LeakyReLU()
                                  for _ in conv_filters]
        self.flatten_layer = tf.keras.layers.Flatten()
        self.output_layer = tf.keras.layers.Dense(output_units)
        self.output_activation = tf.keras.layers.LeakyReLU()
        self.output_dropout = tf.keras.layers.Dropout(rate=0.2)

    def call(self, inputs):
        # Instance encoder: encode every support sample independently
        for conv, activation in zip(self.conv2d_layers,
                                    self.activation_layers):
            inputs = conv(inputs)
            inputs = activation(inputs)
        outputs = self.flatten_layer(inputs)
        outputs = self.output_layer(outputs)
        outputs = self.output_activation(outputs)
        outputs = self.output_dropout(outputs)
        # Feature aggregation: average over the support set -> (1, output_units)
        task_repr = tf.reduce_mean(outputs, axis=0, keepdims=True)
        return task_repr


conv_filters = [64, 64, 128, 128, 256, 256]
conv_kernels = [3, 3, 3, 3, 3, 3]
conv_strides = [1, 2, 1, 2, 1, 2]
output_units = 256
task_encoder = TaskEncoder(conv_filters=conv_filters,
                           conv_kernels=conv_kernels,
                           conv_strides=conv_strides,
                           output_units=output_units)
```
Note: No batch normalization is used here.
```python
class ConditionalGenerator(tf.keras.Model):
    """Task-conditioned generator.

    Parameters
    ----------
    conv_start_shape : `tuple` of length 3
        The conditioned noise is projected to `np.prod(conv_start_shape)` units
        and reshaped to `conv_start_shape`, so that convolutional
        operations can be applied to it.
    upsample_scales : `list` of `int`
        Instead of `Conv2DTranspose`, we use `UpSampling2D + Conv2D`. This is the
        list of upsampling scales applied before the corresponding convolutional layers.
    conv_filters : `list` of `int`
        List of the number of filters to use in each `Conv2D` layer.
        The length of the list is the number of `Conv2D` layers to use.
    conv_kernels : `list` of `int` or `tuples` or both.
        List of kernel sizes to use in each `Conv2D` layer.

    Input shape
    ----------
    Tuple of two N-D tensors with shapes:
    `[(n_way*k_support, noise_dim), (1, task_repr_size)]` during adaptation
    or `[(n_way*k_query, noise_dim), (1, task_repr_size)]` during meta-learning.
    The 1st element is the noise and the 2nd is the task representation vector used to condition the generator.

    Output shape
    ----------
    N-D tensor with shape:
    `(n_way*k_support, height, width, channels)` during adaptation
    `(n_way*k_query, height, width, channels)` during meta-learning
    """

    def __init__(self,
                 conv_start_shape,
                 upsample_scales,
                 conv_filters,
                 conv_kernels,
                 **kwargs):
        super(ConditionalGenerator, self).__init__(**kwargs)
        self.concatenation = tf.keras.layers.Concatenate()
        self.noise_embedding_projection = tf.keras.layers.Dense(np.prod(conv_start_shape))
        self.noise_embedding_reshape = tf.keras.layers.Reshape(conv_start_shape)
        self.upsample_layers = [tf.keras.layers.UpSampling2D(size=scale)
                                for scale in upsample_scales]
        self.conv_layers = [tf.keras.layers.Conv2D(filters=f,
                                                   kernel_size=k,
                                                   padding='same')
                            for f, k in zip(conv_filters,
                                            conv_kernels)]
        self.activation_layers = [tf.keras.layers.LeakyReLU()
                                  for _ in range(len(conv_filters) - 1)] \
            + [tf.keras.layers.Activation('tanh')]

    def call(self, inputs):
        noise, task_representation = inputs
        task_encodings = tf.tile(task_representation, [tf.shape(noise)[0], 1])
        conditioned_noise = self.concatenation([noise, task_encodings])  # noise is now task-conditioned
        # From here on, same as a normal GAN generator, except upsampling + conv instead of Conv2DTranspose
        output_image = self.noise_embedding_reshape(self.noise_embedding_projection(conditioned_noise))
        for upsample, conv, activation in zip(self.upsample_layers,
                                              self.conv_layers,
                                              self.activation_layers):
            output_image = upsample(output_image)
            output_image = conv(output_image)
            output_image = activation(output_image)
        return output_image


generator = ConditionalGenerator(conv_start_shape=(7, 7, 256),
                                 upsample_scales=[1, 2, 2],
                                 conv_filters=[128, 64, num_of_channels],
                                 conv_kernels=[5, 5, 5])
```
The MetaGAN discriminator can be any few-shot classifier. We will use the one from *Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks* by Chelsea Finn et al.
```python
class Discriminator(tf.keras.Model):
    """The discriminator/classifier network for MetaGAN.

    Parameters
    ----------
    conv_filters : `list` of `int`
        List of the number of filters to use in each `Conv2D` layer.
        The length of the list is the number of `Conv2D` layers to use.
    conv_kernels : `list` of `int` or `tuples` or both.
        List of kernel sizes to use in each `Conv2D` layer.
    conv_strides : `list` of `int` or `tuples` or both.
        List of strides to use in each `Conv2D` layer.
    n_way : `int`
        Number of classes per task, i.e. the number of output logits.

    Input shape
    ----------
    N-D tensor with shape: `(batch_size, height, width, channels)`.

    Output shape
    ----------
    Tuple of 2 N-D tensors with shapes: `[(batch_size, n_way), (batch_size, flattened_size)]`.
    The 1st element is the classifier logit output and
    the 2nd element is the flattened output of the last convolutional layer.
    """

    def __init__(self,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 n_way,
                 **kwargs):
        super(Discriminator, self).__init__(**kwargs)
        self.conv2d_layers = [tf.keras.layers.Conv2D(filters=f,
                                                     kernel_size=k,
                                                     strides=s,
                                                     padding='same')
                              for f, k, s in zip(conv_filters,
                                                 conv_kernels,
                                                 conv_strides)]
        self.batchnorm_layers = [tf.keras.layers.BatchNormalization()
                                 for _ in conv_filters]
        self.activation_layers = [tf.keras.layers.ReLU()
                                  for _ in conv_filters]
        self.flatten_layer = tf.keras.layers.Flatten()
        self.output_layer = tf.keras.layers.Dense(n_way)

    def call(self, inputs):
        for conv, batchnorm, activation in zip(self.conv2d_layers,
                                               self.batchnorm_layers,
                                               self.activation_layers):
            inputs = conv(inputs)
            # always normalize with the current batch statistics
            inputs = batchnorm(inputs, training=True)
            inputs = activation(inputs)
        flattened_features = self.flatten_layer(inputs)
        class_logits = self.output_layer(flattened_features)
        return class_logits, flattened_features


discriminator = Discriminator(conv_filters=[64, 64, 64, 64],
                              conv_kernels=[3, 3, 3, 3],
                              conv_strides=[2, 2, 2, 2],
                              n_way=n_way)
# We need two discriminators so that we can restore the previous weights after a trial meta-learning step
duplicate_discriminator = Discriminator(conv_filters=[64, 64, 64, 64],
                                        conv_kernels=[3, 3, 3, 3],
                                        conv_strides=[2, 2, 2, 2],
                                        n_way=n_way)
```
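It helps to check the spatial sizes these three configurations produce. With `padding='same'`, a Keras `Conv2D` output has size `ceil(input / stride)`, and `UpSampling2D(size=s)` multiplies the size by `s`. The plain-Python trace below applies those two rules to the hyperparameters used here. The helper name `conv_same` is hypothetical.

```python
import math


def conv_same(size, stride):
    """Spatial size after a Conv2D with padding='same'."""
    return math.ceil(size / stride)


# Task encoder: six convs with strides [1, 2, 1, 2, 1, 2] on a 28x28 input
enc_size = 28
for s in [1, 2, 1, 2, 1, 2]:
    enc_size = conv_same(enc_size, s)
task_encoder_flat = enc_size * enc_size * 256  # last conv has 256 filters

# Generator: start at 7x7, upsample by [1, 2, 2]; stride-1 'same' convs keep the size
gen_size = 7
for scale in [1, 2, 2]:
    gen_size = conv_same(gen_size * scale, 1)

# Discriminator: four stride-2 convs with 64 filters on a 28x28 input
disc_size = 28
for s in [2, 2, 2, 2]:
    disc_size = conv_same(disc_size, s)
disc_flat = disc_size * disc_size * 64

print(enc_size, task_encoder_flat)  # 4 4096
print(gen_size)                     # 28
print(disc_size, disc_flat)         # 2 256
```

So the generator's output matches the 28x28 real inputs. The discriminator's `flattened_features`, which the feature-matching generator loss compares, has 256 dimensions per image.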
We’ll use hyperparameters chosen according to the MetaGAN paper and *Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks* by Chelsea Finn et al.
```python
meta_learning_rate = 1e-3
meta_beta_1 = 0.5
meta_beta_2 = 0.9
adaption_learning_rate = 0.1
adaptation_number_of_steps = 5
EPOCHS = 100
```
For the adaptation steps we use SGD, and for meta-learning we use the Adam optimizer with the parameters defined above.
```python
adaptation_optimizer = tf.keras.optimizers.SGD(learning_rate=adaption_learning_rate)  # used for inner gradient updates
meta_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=meta_learning_rate,
                                                        beta_1=meta_beta_1,
                                                        beta_2=meta_beta_2)
meta_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=meta_learning_rate,
                                                    beta_1=meta_beta_1,
                                                    beta_2=meta_beta_2)
```
As mentioned in the paper, the discriminator loss can be divided into two parts. The unsupervised loss represents the GAN problem, and the supervised loss represents the individual `n_way` class probabilities. To get the classification probabilities, we feed the logits through the softmax function. However, we still need a way to represent the probability of an input image being real rather than fake. That is, we still need to account for the binary classification problem of a regular GAN. The `n_way` logits only model the real classes, so we treat the fake class as having an implicit logit fixed at 0. The unnormalized probability of an input being real is then the sum of the exponentiated real-class logits, \(Z(x) = \sum_{k=1}^{N} \exp(l_k(x))\), and \(P(\text{real} \mid x) = \frac{Z(x)}{Z(x) + 1} = \sigma(\log Z(x))\). In other words, we feed the class logits to a LogSumExp function to obtain the binary classification logit and then apply a sigmoid to it. Using TensorFlow’s built-in `tf.reduce_logsumexp` helps avoid numerical problems: it prevents the overflow and underflow that a naive implementation suffers when the logits are very large positive or very large negative values.
As usual in a vanilla GAN, images coming from the training set get label 1, so we maximize their probability of being real. Fabricated images coming from the generator get label 0, so we maximize their probability of being fake.
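The NumPy sketch below checks the identity behind this trick. It follows the semi-supervised GAN formulation of Salimans et al., which fixes the fake-class logit at 0. It confirms that \(\sigma(\operatorname{logsumexp}(l)) = Z/(Z+1)\), applies the 1/0 real/fake labels through a binary cross-entropy on logits, and shows why the max-shift matters for extreme logits. `logsumexp` and `bce_with_logits` are hand-written helpers here, not the TensorFlow functions.

```python
import numpy as np


def logsumexp(logits, axis=-1):
    """Numerically stable log(sum(exp(logits))) via the max-shift trick."""
    m = np.max(logits, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(logits - m), axis=axis))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bce_with_logits(labels, logits):
    """Mean binary cross-entropy computed on logits (stable form)."""
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))


class_logits = np.array([[2.0, -1.0, 0.5, 0.0, 1.0],       # n_way = 5 class logits
                         [-3.0, -2.5, -4.0, -3.5, -5.0]])

gan_logits = logsumexp(class_logits, axis=1)
Z = np.exp(class_logits).sum(axis=1)
assert np.allclose(sigmoid(gan_logits), Z / (Z + 1.0))  # P(real | x)

# Real samples get label 1, generated samples get label 0
print(bce_with_logits(np.ones(2), gan_logits), bce_with_logits(np.zeros(2), gan_logits))

# With huge logits, the naive log(sum(exp(x))) overflows to inf; the shifted version does not
big = np.array([[1000.0, 999.0]])
print(logsumexp(big, axis=1))  # about 1000.3133
```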
```python
def labeled_loss(target, class_logits):
    """Loss function to calculate the supervised loss.

    Parameters
    ----------
    target : `Tensor` of dtype `tf.int64`
        Shape `(batch_size,)`
        Categorical labels for `n_way` classification
    class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        A logit tensor.

    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
        Supervised cross-entropy loss for labeled input
    """
    losses = tf.keras.losses.sparse_categorical_crossentropy(target, class_logits, from_logits=True)
    loss = tf.reduce_mean(losses)
    return loss
```
The unsupervised loss is computed for both real and generated images:
```python
def unlabeled_loss(class_logits, real=True):
    """Loss function to calculate the unsupervised (real vs. fake) loss.

    Parameters
    ----------
    class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Class logit tensor predicted by the model for unlabeled input.
    real : `boolean`
        Flags whether `class_logits` is for real unlabeled data or for generated data.

    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
        Unsupervised loss for unlabeled input.
    """
    # log(sum_k exp(l_k)): the binary "real" logit, with the fake-class logit fixed at 0
    gan_logits = tf.reduce_logsumexp(class_logits, axis=1)
    labels = tf.ones_like(gan_logits) if real else tf.zeros_like(gan_logits)
    losses = tf.keras.losses.binary_crossentropy(labels, gan_logits, from_logits=True)
    loss = tf.reduce_mean(losses)
    return loss


def discriminator_loss(label_class_logits,
                       label_target,
                       unlabel_class_logits,
                       fake_class_logits):
    """Estimates how well the discriminator distinguishes real images from fakes.

    discriminator_loss = supervised loss + real unsupervised loss + fake unsupervised loss.

    Parameters
    ----------
    label_class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Classifier-predicted class logits for labeled real data.
    label_target : `Tensor` of dtype `tf.int64`
        Shape `(batch_size,)`
        True categorical labels of `label_class_logits` for `n_way` classification.
    unlabel_class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Classifier-predicted class logits for unlabeled real data.
    fake_class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Classifier-predicted class logits for unlabeled generated data.

    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
    """
    supervised_loss = labeled_loss(label_target, label_class_logits)
    real_unsupervised_loss = unlabeled_loss(unlabel_class_logits, real=True)
    fake_unsupervised_loss = unlabeled_loss(fake_class_logits, real=False)
    disc_loss = supervised_loss + real_unsupervised_loss + fake_unsupervised_loss
    return disc_loss
```
As described in the *Improved Techniques for Training GANs* paper, we use feature matching for the generator loss. Have a look at this description of feature matching: “Feature matching is the concept of penalizing the mean absolute error between the average value of some set of features on the training data and the average values of that set of features on the generated samples.” So, we average the discriminator features (the flattened output of its last convolutional layer) over a real training minibatch, and likewise over a generated minibatch. The paper’s objective is the squared L2 distance \(\lVert \mathbb{E}_{x} f(x) - \mathbb{E}_{z} f(G(z)) \rVert_2^2\). Our generator loss therefore uses the mean squared difference between the two average feature vectors rather than the absolute error.
def generator_loss(real_features, fake_features):
    """The generator's loss quantifies how well it was able to trick the discriminator.
    Parameters
    ----------
    real_features : `Tensor` of dtype `tf.float32`
        Shape : `(batch_size, flattened_size)`
        Flattened last convolutional layer output for the real input data batch.
    fake_features : `Tensor` of dtype `tf.float32`
        Shape : `(batch_size, flattened_size)`
        Flattened last convolutional layer output for the generated input data batch.
    Returns
    -------
    loss : Scalar `Tensor` of dtype `tf.float32`
    """
    real_mean_feature = tf.reduce_mean(real_features, axis=0)
    fake_mean_feature = tf.reduce_mean(fake_features, axis=0)
    gen_loss = tf.reduce_mean((real_mean_feature - fake_mean_feature)**2)
    return gen_loss
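To make the feature-matching computation concrete, here is the same loss on small NumPy arrays. This is a sketch: the arrays stand in for the flattened discriminator features, and `feature_matching_loss` is a hypothetical name.

```python
import numpy as np

def feature_matching_loss(real_features, fake_features):
    # Average each feature over the batch dimension, as in `generator_loss`.
    real_mean = real_features.mean(axis=0)
    fake_mean = fake_features.mean(axis=0)
    # Mean squared difference between the two average feature vectors.
    return np.mean((real_mean - fake_mean) ** 2)

real = np.array([[1.0, 2.0], [3.0, 4.0]])   # batch mean: [2.0, 3.0]
fake = np.array([[0.0, 2.0], [2.0, 2.0]])   # batch mean: [1.0, 2.0]
print(feature_matching_loss(real, fake))    # → 1.0
```

Because only the batch means are compared, the real and generated batches need not be the same size, and no individual fake sample is paired with a particular real one.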
def accuracy(class_logits, labels):
    """Accuracy measure for given class logits and target labels.
    Parameters
    ----------
    class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        It's a logit tensor.
    labels : `Tensor` of dtype `tf.int64`
        Shape `(batch_size,)`
        True categorical labels of `class_logits` for `n_way` classification.
    Returns
    -------
    accuracy : `Tensor` of dtype `tf.float32`
        Accuracy measure for given class logits and target labels.
    """
    label_predictions = tf.argmax(class_logits, axis=-1)
    equality = tf.equal(labels, label_predictions)
    return tf.reduce_mean(tf.cast(equality, tf.float32))
Define an adaptation function.
def adaptation(learner, real_support, real_support_label, fake_support):
    """Given a learner model, perform `adaptation_number_of_steps` steps of
    gradient descent using the labeled support set, the generated fake support set,
    and the same labeled support set as the unlabeled support set, since we don't
    have unlabeled Omniglot images.
    Parameters
    ----------
    learner : `tensorflow.keras.Model`
        A few-shot classifier model that needs to be adapted to the given task's support set.
    real_support : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_support, height, width, channels)`
        Target task support images.
    real_support_label : `Tensor` of dtype `tf.int64`
        Shape `(n_way*k_support,)`
        `n_way` labels of the given real support set.
    fake_support : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_support, height, width, channels)`
        Generated images conditioned on the same task as the real support images.
    Returns
    -------
    support_loss : Scalar Tensor of dtype `tf.float32`
        Mean support loss over all adaptation steps of the given `learner`.
    """
    support_disc_losses = []
    for _ in range(adaptation_number_of_steps):
        with tf.GradientTape() as adaptation_tape:
            support_real_class_logits, _ = learner(real_support)
            support_fake_class_logits, _ = learner(fake_support)
            disc_loss = discriminator_loss(support_real_class_logits,
                                           real_support_label,
                                           support_real_class_logits,
                                           support_fake_class_logits)
        adaptation_grads = adaptation_tape.gradient(disc_loss, learner.trainable_variables)
        adaptation_optimizer.apply_gradients(zip(adaptation_grads, learner.trainable_variables))
        support_disc_losses.append(disc_loss)
    return tf.reduce_mean(support_disc_losses)
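The inner loop in `adaptation` is ordinary gradient descent on the support set. The sketch below is a hedged toy analogue that replaces the discriminator with a linear least-squares model. `adapt`, the learning rate and the step count are all hypothetical. It works on a copy of the weights, which shows why the callers keep a backup: `meta_learn_step` restores the discriminator afterwards, and `evaluation` adapts `duplicate_discriminator` so the primary weights stay untouched.

```python
import numpy as np

def adapt(weights, x, y, steps=5, lr=0.1):
    """Toy analogue of `adaptation`: a few gradient steps of linear
    least squares on a support set, applied to a copy of `weights`."""
    w = weights.copy()                    # never mutate the caller's weights
    losses = []
    for _ in range(steps):
        err = x @ w - y
        losses.append(np.mean(err ** 2))
        grad = 2.0 * x.T @ err / len(y)   # gradient of the mean squared error
        w -= lr * grad
    # Like `adaptation`, report the mean loss over all steps.
    return w, float(np.mean(losses))

x_support = np.array([[1.0], [2.0]])
y_support = np.array([2.0, 4.0])          # the task's true slope is 2
w_original = np.zeros(1)
w_adapted, mean_support_loss = adapt(w_original, x_support, y_support)
```

After the call, `w_original` is still zero while `w_adapted` has moved towards the task's slope of 2.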
@tf.function
def meta_learn_step(support_taskbatch,
                    support_taskbatch_labels,
                    query_taskbatch,
                    query_taskbatch_labels):
    """Perform one step of meta-learning given a batch of tasks.
    Parameters
    ----------
    support_taskbatch : `Tensor` of dtype `tf.float32`
        Shape `(task_batch, n_way*k_support, height, width, channels)`
        Support set images for the task batch.
    support_taskbatch_labels : `Tensor` of dtype `tf.int64`
        Shape `(task_batch, n_way*k_support)`
        Support set labels for the task batch.
    query_taskbatch : `Tensor` of dtype `tf.float32`
        Shape `(task_batch, n_way*k_query, height, width, channels)`
        Query set images for the task batch.
    query_taskbatch_labels : `Tensor` of dtype `tf.int64`
        Shape `(task_batch, n_way*k_query)`
        Query set labels for the task batch.
    Returns
    -------
    avg_disc_loss : Scalar Tensor of dtype `tf.float32`
        Average discriminator loss over the task batch on the query set.
    avg_gen_loss : Scalar Tensor of dtype `tf.float32`
        Average generator loss over the task batch on the query set.
    avg_accuracy : Scalar Tensor of dtype `tf.float32`
        Average accuracy over the task batch on the query set.
    """
    number_of_tasks = support_taskbatch.shape[0]
    # Step 1. Store the discriminator weights in another model.
    # Both models must be built before executing this step.
    for dup_wts, wts in zip(duplicate_discriminator.trainable_variables,
                            discriminator.trainable_variables):
        dup_wts.assign(wts)
    # Step 2. Initialize accumulators for the total losses and accuracy over all tasks.
    taskbatch_query_discriminator_loss = tf.constant(0.0)
    task_batch_query_generator_loss = tf.constant(0.0)
    taskbatch_query_accuracy = tf.constant(0.0)
    with tf.GradientTape() as meta_discriminator_tape, tf.GradientTape() as meta_generator_tape:
        # Step 3. Repeat Steps 4-12 for every task in the current task batch.
        for task_no in range(number_of_tasks):
            # Step 4. Find each task's representation vector using its support set and the TaskEncoder model.
            task_representation = task_encoder(support_taskbatch[task_no])
            # Step 5. Adapt the discriminator to the current task by calling `adaptation`
            # with the discriminator and the required support inputs.
            with meta_discriminator_tape.stop_recording(), meta_generator_tape.stop_recording():
                # The adaptation operations need not be recorded for the meta-update.
                # Generate a fake support set with as many samples as the real support set of the current task.
                support_noise = generate_noise((tf.shape(support_taskbatch[task_no])[0], noise_dim))
                support_fake = generator([support_noise, task_representation])
                support_loss = adaptation(discriminator,
                                          support_taskbatch[task_no],
                                          support_taskbatch_labels[task_no],
                                          support_fake)
            # Step 6. Generate a fake query set.
            query_noise = generate_noise((tf.shape(query_taskbatch[task_no])[0], noise_dim))
            query_fake = generator([query_noise, task_representation])
            # Step 7. Find discriminator features and class logits for the real and generated query sets of the current task.
            query_real_class_logits, query_real_features = discriminator(query_taskbatch[task_no])
            query_fake_class_logits, query_fake_features = discriminator(query_fake)
            # Step 8. Find the discriminator loss.
            disc_loss = discriminator_loss(query_real_class_logits,
                                           query_taskbatch_labels[task_no],
                                           query_real_class_logits,
                                           query_fake_class_logits)
            # Step 9. Find the generator loss.
            gen_loss = generator_loss(query_real_features, query_fake_features)
            # Step 10. Add the query discriminator and generator losses to the running totals.
            taskbatch_query_discriminator_loss += disc_loss
            task_batch_query_generator_loss += gen_loss
            # Step 11. Calculate the query accuracy for the current task and add it to the running total.
            query_accuracy = accuracy(query_real_class_logits, query_taskbatch_labels[task_no])
            taskbatch_query_accuracy += query_accuracy
            # Step 12. Restore the pre-adaptation discriminator weights before adapting to the next task.
            for dup_wts, wts in zip(duplicate_discriminator.trainable_variables,
                                    discriminator.trainable_variables):
                wts.assign(dup_wts)
    # Step 13. Find the discriminator and generator gradients. The TaskEncoder is updated together with
    # the Generator, so the total generator loss is differentiated with respect to both sets of variables.
    meta_discriminator_grads = meta_discriminator_tape.gradient(taskbatch_query_discriminator_loss,
                                                                discriminator.trainable_variables)
    meta_generator_grads = meta_generator_tape.gradient(task_batch_query_generator_loss,
                                                        task_encoder.trainable_variables + generator.trainable_variables)
    # Step 14. Use the respective meta-optimizers to update the discriminator, generator and task encoder weights.
    meta_discriminator_optimizer.apply_gradients(zip(meta_discriminator_grads,
                                                     discriminator.trainable_variables))
    meta_generator_optimizer.apply_gradients(zip(meta_generator_grads,
                                                 task_encoder.trainable_variables + generator.trainable_variables))
    # Average the metrics over the task batch before returning them.
    avg_disc_loss = taskbatch_query_discriminator_loss/number_of_tasks
    avg_gen_loss = task_batch_query_generator_loss/number_of_tasks
    avg_accuracy = taskbatch_query_accuracy/number_of_tasks
    return avg_disc_loss, avg_gen_loss, avg_accuracy
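Reading Steps 5, 12 and 13 together, the discriminator's meta-update appears to be a first-order approximation in the spirit of first-order MAML. Adaptation is excluded from the tapes. Step 12 restores the original weights, but the tape has already recorded the query forward passes at the adapted weights, so the query-loss gradient is evaluated there and then applied to the restored weights. The NumPy sketch below illustrates that pattern on toy linear-regression tasks. `first_order_meta_step`, the task setup and the hyperparameters are hypothetical and not part of the original code.

```python
import numpy as np

def mse_and_grad(w, x, y):
    # Mean squared error of a linear model and its gradient w.r.t. w.
    err = x @ w - y
    return np.mean(err ** 2), 2.0 * x.T @ err / len(y)

def first_order_meta_step(w, tasks, inner_steps=3, inner_lr=0.1, meta_lr=0.05):
    """One first-order meta-update over a batch of (support, query) tasks."""
    meta_grad = np.zeros_like(w)
    query_loss = 0.0
    for (xs, ys), (xq, yq) in tasks:
        w_task = w.copy()                     # Steps 1/12: the original weights are kept
        for _ in range(inner_steps):          # Step 5: adapt on the support set
            _, g = mse_and_grad(w_task, xs, ys)
            w_task -= inner_lr * g
        loss, g_query = mse_and_grad(w_task, xq, yq)   # Steps 7-8: query loss
        meta_grad += g_query                  # gradient taken at the *adapted* weights
        query_loss += loss
    # Steps 13-14: apply the summed query gradient to the original weights.
    return w - meta_lr * meta_grad, query_loss / len(tasks)

# Two toy tasks, y = 1*x and y = 3*x, sharing the same inputs for support and query.
x = np.array([[1.0], [2.0]])
tasks = [((x, a * x[:, 0]), (x, a * x[:, 0])) for a in (1.0, 3.0)]
w = np.zeros(1)
_, first_loss = first_order_meta_step(w, tasks)
for _ in range(100):
    w, loss = first_order_meta_step(w, tasks)
```

The meta-learned initialization drifts towards a point (here a slope near 2) from which a few inner steps fit either task well, which is the role the shared discriminator weights play above.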
# Evaluation for a single task
@tf.function
def evaluation(support_images, support_labels, query_images, query_labels):
    """Perform fine-tuning/adaptation using the support set of a task and
    return evaluation metrics calculated on the query set of the same task.
    Parameters
    ----------
    support_images : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_support, height, width, channels)`
    support_labels : `Tensor` of dtype `tf.int64`
        Shape `(n_way*k_support,)`
    query_images : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_query, height, width, channels)`
    query_labels : `Tensor` of dtype `tf.int64`
        Shape `(n_way*k_query,)`
    Returns
    -------
    query_discriminator_loss : Scalar Tensor of dtype `tf.float32`
        Classifier loss on the query set of the given task after adaptation on the same task's support set.
    query_accuracy : Scalar Tensor of dtype `tf.float32`
        Classifier accuracy on the query set of the given task after adaptation on the same task's support set.
    """
    # No meta-learning step is performed during evaluation, so the secondary discriminator is used
    # for adaptation and for computing the evaluation metrics on the query set.
    # Step 1. Copy the primary discriminator weights to the secondary discriminator.
    for dup_wts, wts in zip(duplicate_discriminator.trainable_variables, discriminator.trainable_variables):
        dup_wts.assign(wts)
    # Step 2. Produce a fake support set and adapt the secondary discriminator to the given task.
    task_representation = task_encoder(support_images)
    support_noise = generate_noise((tf.shape(support_images)[0], noise_dim))
    support_fake = generator([support_noise, task_representation])
    support_loss = adaptation(duplicate_discriminator,
                              support_images,
                              support_labels,
                              support_fake)
    # Step 3. Produce a fake query set with as many samples as the real query set.
    query_noise = generate_noise((tf.shape(query_images)[0], noise_dim))
    query_fake = generator([query_noise, task_representation])
    # Step 4. Find class logits for the real and fake query sets. The generator loss is
    # irrelevant here, so the discriminator features are ignored.
    query_real_class_logits, _ = duplicate_discriminator(query_images)
    query_fake_class_logits, _ = duplicate_discriminator(query_fake)
    # Step 5. Find the discriminator loss.
    query_discriminator_loss = discriminator_loss(query_real_class_logits,
                                                  query_labels,
                                                  query_real_class_logits,
                                                  query_fake_class_logits)
    # Step 6. Find the accuracy on the query set.
    query_accuracy = accuracy(query_real_class_logits, query_labels)
    return query_discriminator_loss, query_accuracy
import os
checkpoint_dir = './training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(adaptation_optimizer=adaptation_optimizer,
                                 meta_discriminator_optimizer=meta_discriminator_optimizer,
                                 meta_generator_optimizer=meta_generator_optimizer,
                                 task_encoder=task_encoder,
                                 generator=generator,
                                 discriminator=discriminator)
# Build both the primary and the secondary discriminator up front so that their
# weights exist before they are copied between the two models.
discriminator.build((None, image_size[0], image_size[1], num_of_channels))
duplicate_discriminator.build((None, image_size[0], image_size[1], num_of_channels))
gen_losses = []
disc_losses = []
accuracies = []
test_losses = []
test_accuracies = []
for ep in range(EPOCHS):
    gen_loss = []
    disc_loss = []
    disc_accuracy = []
    ########################################### Training ################################################
    for task_batch_no, (support_images, support_labels, query_images, query_labels) in enumerate(train_dataset_task_grouped):
        d_loss, g_loss, acc = meta_learn_step(support_images,
                                              support_labels,
                                              query_images,
                                              query_labels)
        disc_loss.append(d_loss)
        gen_loss.append(g_loss)
        disc_accuracy.append(acc)
    disc_loss = tf.reduce_mean(disc_loss)
    gen_loss = tf.reduce_mean(gen_loss)
    disc_accuracy = tf.reduce_mean(disc_accuracy)
    disc_losses.append(disc_loss)
    gen_losses.append(gen_loss)
    accuracies.append(disc_accuracy)
    #####################################################################################################
    #################################### Evaluation and Logging ################################################
    if ep % 10 == 0:  # Every 10 epochs
        test_loss = []
        test_accuracy = []
        for task_no, (support_images, support_labels, query_images, query_labels) in enumerate(test_dataset_nway_grouped):
            query_loss, query_accuracy = evaluation(support_images, support_labels, query_images, query_labels)
            test_loss.append(query_loss)
            test_accuracy.append(query_accuracy)
        test_loss = tf.reduce_mean(test_loss)
        test_accuracy = tf.reduce_mean(test_accuracy)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        tf.print('Epoch: ', ep,
                 'Train gen loss: ', gen_loss,
                 'Train disc loss: ', disc_loss,
                 'Train accuracy: ', disc_accuracy,
                 'Test loss: ', test_loss,
                 'Test accuracy: ', test_accuracy)
    else:
        tf.print('Epoch: ', ep,
                 'Train gen loss: ', gen_loss,
                 'Train disc loss: ', disc_loss,
                 'Train accuracy: ', disc_accuracy)
    #####################################################################################################
    checkpoint.save(file_prefix=checkpoint_prefix)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
Evaluation on test dataset
test_loss = []
test_accuracy = []
for task_no, (support_images, support_labels, query_images, query_labels) in enumerate(test_dataset_nway_grouped):
    query_loss, query_accuracy = evaluation(support_images, support_labels, query_images, query_labels)
    test_loss.append(query_loss)
    test_accuracy.append(query_accuracy)
test_loss = tf.reduce_mean(test_loss)
test_accuracy = tf.reduce_mean(test_accuracy)
tf.print('Test loss: ', test_loss,
         'Test accuracy: ', test_accuracy)
Test loss: 1.80416536 Test accuracy: 0.628600478
As mentioned before, the generator does not need to be perfect. Let's see what it has learned to generate.
test_iterator = tfds.as_numpy(test_dataset_nway_grouped)
Run the following cells multiple times to inspect images generated for different tasks.
test_task = next(test_iterator)
support_images, support_labels, query_images, query_labels = test_task
generated = generator([generate_noise((tf.shape(support_images)[0], noise_dim)), task_encoder(support_images)])
Real task images.
fig = plt.figure(figsize=(5, 5))
for i in range(5*5):
    plt.subplot(5, 5, i+1)
    plt.imshow(support_images[i, :, :, -1], cmap='gray')
    plt.axis('off')
plt.show()

Images produced by the generator, conditioned on the same task.
fig = plt.figure(figsize=(5, 5))
for i in range(5*5):
    plt.subplot(5, 5, i+1)
    plt.imshow(generated[i, :, :, -1], cmap='gray')
    plt.axis('off')
plt.show()
