Sagar Sapkota, graduate student in Computer Science at the University of Central Florida (https://spktsagar.com)

PEFT Cheat Sheet: Succinct Explanations to the Numerous PEFT Methods for LLM (2023-10-18)

Parameter-efficient fine-tuning (PEFT) of large language models (LLMs) is a critical area of focus in today's machine learning research, driven by the need to optimize computational resources without compromising performance. Fine-tuning a pre-trained model for a specific task can often require substantial computational power, making it a bottleneck for many real-world applications. PEFT methods aim to mitigate this challenge by efficiently leveraging the existing parameters of a pre-trained model while adding, modifying, or reconfiguring a minimal number of parameters for the target task. This blog post will delve into various PEFT methods for LLMs, categorizing them into additive, selective, reparametrization-based, and hybrid methods; this classification is taken from Lialin et al., 2023. Each method will be succinctly described, highlighting its unique approach and potential benefits.

Traditional Fine-Tuning

Before going into the list of PEFT methods, let's review how we traditionally adapt pretrained models to a downstream task. The process involves taking a pre-trained model, which has already learned useful features from a large-scale dataset, and further training it on a specific task using a smaller, task-specific dataset.

The process of traditional fine-tuning is as follows:

  • Pre-training: A model is first trained on a large-scale dataset. This is often an unsupervised task, such as language modelling, where the model learns to predict the next word in a sentence. During this pre-training phase, the model learns a general understanding of the language, including its syntax, semantics, and some level of world knowledge.
  • Fine-tuning: After pre-training, the model is then fine-tuned on a specific task using a smaller, task-specific dataset. This could be any supervised NLP task like sentiment analysis, question answering, or named entity recognition. During this phase, all the model parameters are updated to optimize for the specific task.

Parameter-Efficient Fine-Tuning

In recent years, the use of large language models has revolutionized the field of natural language processing (NLP). These models, such as GPT-3 and BLOOM, are capable of generating human-like text and understanding linguistic context. However, fine-tuning these models traditionally is computationally expensive and requires a lot of memory due to the large number of parameters involved. To tackle this problem, there has been an upsurge of work in the research community on efficient ways to fine-tune pretrained models for downstream tasks. These methods, both existing and yet to come, are collectively called "Parameter-Efficient Fine-Tuning" methods, or PEFT for short.

If one tried to list all the PEFT methods available online, it might take forever. Here, I try to list and provide quick explanations of some popular ones. Most of them are extracted from the survey paper by Lialin et al., 2023. However, I've included some that the paper does not list or that were published after the survey. The paper categorizes PEFT methods based on the conceptual framework underlying each approach.


1. Additive Methods

Additive methods fine-tune language models by expanding the pre-existing pre-trained model with supplementary parameters or layers and then training only those newly added parameters. Despite the increase in total parameter count, adding parameters can improve training time and memory efficiency by shrinking the size of the gradients and the optimizer states. Consequently, this approach enables fine-tuning larger networks or using larger micro-batch sizes, improving GPU training throughput and reducing communication volume in distributed setups. Based on the way parameters are added, additive methods are divided into Adapters, Soft Prompts, and Others.

1.1 Adapters

Adapters are a method that introduces small, fully-connected networks after Transformer sub-layers.
1.1.1 Adapters
Adapters add fully-connected networks with a small hidden dimension after attention and feed-forward network (FFN) layers in a Transformer. Although this approach reduces the parameters updated during training, it creates inference overhead due to the added layers.
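
As a rough sketch of the idea (not any particular adapter implementation; the dimensions and zero-initialization below are illustrative assumptions), a bottleneck adapter can be written in plain NumPy:

```python
import numpy as np

def adapter_forward(h, W_down, W_up):
    """Bottleneck adapter: down-project, apply ReLU, up-project,
    then add the result back to the input via a residual connection."""
    z = np.maximum(0, h @ W_down)   # (batch, r) after down-projection
    return h + z @ W_up             # residual keeps the original signal path

rng = np.random.default_rng(0)
d_model, r = 768, 16                # hidden size and small bottleneck dim
h = rng.standard_normal((4, d_model))
W_down = rng.standard_normal((d_model, r)) * 0.02
W_up = np.zeros((r, d_model))       # zero-init so the adapter starts as identity

out = adapter_forward(h, W_down, W_up)
```

Only `W_down` and `W_up` (roughly 2·d·r values) would be trained, while the surrounding transformer weights stay frozen; the extra matrix multiplications are the inference overhead mentioned above.
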
1.1.2 AdaMix
AdaMix uses multiple adapters in a mixture-of-experts (MoE) fashion. Unlike a regular MoE, which selects and weights multiple experts using a routing network, AdaMix randomly selects a single expert for each forward pass. This strategy minimizes computational costs and barely degrades the performance.

1.2 Soft Prompts

Soft prompts involve fine-tuning a portion of the model's input embeddings via gradient descent. This approach transforms the problem of finding prompts in a discrete space (textual prompts) into a continuous optimization problem.
1.2.1 Prompt Tuning
Prompt tuning introduces a trainable tensor, commonly referred to as a "soft prompt", which is prepended to the model's input embeddings. This tensor is directly optimized through gradient descent. This method requires storing a small task-specific soft prompt and enables mixed-task inference using the original pre-trained model.
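
A minimal NumPy sketch of the mechanics (shapes and initialization are illustrative, not from any specific implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_prompt, seq_len = 512, 20, 10

# Trainable soft prompt: the only parameters updated during fine-tuning
soft_prompt = rng.standard_normal((n_prompt, d_model)) * 0.5

# Frozen embeddings of the actual input tokens
token_embeds = rng.standard_normal((seq_len, d_model))

# Prepend the soft prompt to the input embeddings before the frozen model
model_input = np.concatenate([soft_prompt, token_embeds], axis=0)
```

Per task, only the small `soft_prompt` tensor needs to be stored, which is what makes mixed-task batches with a single frozen backbone possible.
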
1.2.2 Prefix Tuning
Prefix tuning is a method used to address the instability of prompt tuning. Instead of only adding a soft prompt to the model input, trainable parameters are prepended to the hidden states of all layers. The same prefix is prepended to all of the hidden states.
1.2.3 P-Tuning
P-Tuning is another form of soft prompting, which employs a prompt encoder (a bidirectional long short-term memory network, or LSTM) to optimize the prompt parameters. The prompt tokens can be inserted anywhere in the input sequence and are not restricted to the beginning.
1.2.4 Intrinsic Prompt Tuning (IPT)
Intrinsic Prompt Tuning (IPT) hypothesizes that the space used to define soft prompt parameters contains an "intrinsic task subspace" that can differentiate between various tasks. It introduces an autoencoder to (de)compress the soft prompt. Despite reducing the number of parameters for the soft prompt, the requirement to train the autoencoder makes it practically infeasible.

1.3 Other Additive Approaches

Beyond adapters and soft prompts, there are other methods of adding parameters that do not strictly follow the concepts of adapters or soft prompts.
1.3.1 Knowledge Distillation
Knowledge distillation is a technique that transfers knowledge from a larger, high-performing model (the teacher model) to a smaller model (the student model). The teacher model's output probabilities serve as soft targets for training the student model, enabling the student model to benefit from the teacher model's knowledge and generalize better.
1.3.2 Ladder-Side Tuning (LST)
Ladder-Side Tuning (LST) trains a small transformer network on the side of the pre-trained network. This side network combines the hidden states of the pre-trained backbone network with its own hidden states, using the pre-trained model as a feature extractor. Backpropagation is only computed through the side network, saving on both memory and compute during training.
1.3.3 IA3
(IA)³ is a method that learns new parameters (vectors) lk, lv, and lff, which rescale the key, value, and hidden FFN activations in each transformer layer. This method adds very little overhead during parameter updates in fine-tuning.
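
The rescaling itself is just an elementwise (broadcast) product; the sketch below uses toy dimensions and ones-initialization, which makes fine-tuning start as an identity transform:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_ff, seq = 64, 256, 8

# Learned rescaling vectors: the only trainable parameters
l_k  = np.ones(d_k)
l_v  = np.ones(d_k)
l_ff = np.ones(d_ff)

keys   = rng.standard_normal((seq, d_k))
values = rng.standard_normal((seq, d_k))
ffn_h  = rng.standard_normal((seq, d_ff))

# (IA)^3 rescales the activations elementwise (vectors broadcast over seq)
keys_scaled   = keys * l_k
values_scaled = values * l_v
ffn_scaled    = ffn_h * l_ff

n_trainable = l_k.size + l_v.size + l_ff.size   # tiny vs. full weight matrices
```
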

2. Selective Methods

Selective methods for parameter-efficient fine-tuning involve optimizing a subset of a model's existing parameters. The selection can be based on layer depth, layer type, or even individual parameters. Here are some popular selective methods:

2.1 Quantization

Quantization is a method that reduces the precision of model parameters to lower memory and computational requirements. In traditional deep learning models, parameters are usually stored as 32-bit floating-point numbers. Quantization, however, allows these parameters to be represented with lower bit precision, such as 8-bit integers. This reduction in precision significantly lowers the memory footprint and speeds up computations.
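
To make the idea concrete, here is a toy symmetric 8-bit quantizer in NumPy; real schemes (per-channel scales, zero points, calibration) are more involved, so treat this purely as a sketch:

```python
import numpy as np

def quantize_int8(w):
    """Symmetric 8-bit quantization: store int8 values plus one float scale."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover an approximation of the original float weights."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal(1024).astype(np.float32)   # "weights" at 32-bit
q, scale = quantize_int8(w)                        # stored at 8-bit
w_hat = dequantize(q, scale)                       # used at compute time
```

Storage drops from 4 bytes to 1 byte per weight, at the cost of a bounded rounding error (at most half a quantization step per value here).
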

2.2 BitFit

BitFit is a method that fine-tunes only the biases of the network. For every linear or convolutional layer, the weight matrix is kept constant, and only the bias vector is optimized. This approach is particularly efficient as it reduces the number of parameters that need to be updated during training.
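
The parameter selection is simple enough to sketch without a framework; the parameter names below are hypothetical (in PyTorch one would instead set `requires_grad = False` on every non-bias parameter):

```python
# A minimal sketch of BitFit's selection rule: train only bias vectors.
params = {
    "encoder.layer0.attention.weight": "W0",
    "encoder.layer0.attention.bias": "b0",
    "encoder.layer0.ffn.weight": "W1",
    "encoder.layer0.ffn.bias": "b1",
}

# Only bias parameters are marked trainable; all weight matrices stay frozen
trainable = sorted(name for name in params if name.endswith(".bias"))
```

Since bias vectors are a tiny fraction of a transformer's parameters, the optimizer state and gradients shrink accordingly.
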

2.3 Pruning

Pruning is a technique that involves removing unnecessary weights or connections from a pre-trained model. By identifying and eliminating redundant or less important parameters, the model’s size and computational requirements can be significantly reduced. Pruning can be performed based on different criteria, such as magnitude-based pruning or structured pruning. Magnitude-based pruning removes weights with small magnitudes, while structured pruning removes entire neurons or filters based on their importance.
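
Magnitude-based pruning, the simplest criterion above, can be sketched in a few lines of NumPy (the 50% sparsity level is an arbitrary illustration):

```python
import numpy as np

def magnitude_prune(w, sparsity):
    """Zero out the `sparsity` fraction of weights with smallest magnitude."""
    k = int(w.size * sparsity)
    if k == 0:
        return w.copy()
    threshold = np.sort(np.abs(w), axis=None)[k - 1]   # k-th smallest magnitude
    return np.where(np.abs(w) <= threshold, 0.0, w)

rng = np.random.default_rng(0)
w = rng.standard_normal((32, 32))
w_pruned = magnitude_prune(w, 0.5)   # half of the weights become zero
```

Structured pruning would instead zero whole rows, columns, or filters, which is easier for hardware to exploit than this unstructured mask.
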

2.4 DiffPruning

DiffPruning aims to achieve parameter efficiency by learning a sparse update of a neural network’s weights. The method introduces a learnable binary mask on the weights, denoted by δ = z ◦ ∆W, where ◦ represents the Hadamard product. This parameter mask is learned during model fine-tuning as part of the regularization objective, which is a differentiable approximation to the L0 norm of the update vector δ. This method requires more memory than traditional fine-tuning, as it involves optimizing all parameters during training in addition to the learnable binary mask.
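
The sparse update δ = z ◦ ∆W can be illustrated directly; note that in the actual method z is a learned, relaxed binary mask optimized jointly with ∆W, whereas the sketch below just draws a random mask for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16

delta_W = rng.standard_normal((d, d)) * 0.01     # dense learned update
# Stand-in for the binarized mask z (learned in the real method):
z = (rng.random((d, d)) < 0.1).astype(np.float64)

delta = z * delta_W                  # Hadamard product: the sparse update
W = rng.standard_normal((d, d))      # frozen pretrained weight
W_finetuned = W + delta              # only masked-in entries actually change
```

Only the nonzero entries of `delta` need to be stored per task, even though training itself touches all parameters.
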

2.5 Freeze and Reconfigure (FAR)

The Freeze and Reconfigure (FAR) method selects columns of parameter matrices to prune and reconfigures linear layers into trainable and frozen. In the first stage, the most important rows of parameter matrices are identified for updating. This process is similar to structured pruning and can use any pruning method. In the second stage, the network is reconfigured by splitting each parameter tensor into trainable and frozen components. After training, the parameters can be reconfigured back, removing any inference overhead.

2.6 FishMask

FishMask is a sparse fine-tuning method that selects the top-p parameters of the model based on their Fisher information. Fisher information measures the amount of information that an observable random variable carries about an unknown parameter of a distribution that models the variable.

2.7 ULMFiT

ULMFiT achieves efficient fine-tuning using gradual unfreezing. Instead of fine-tuning all layers at once, which risks catastrophic forgetting, ULMFiT gradually unfreezes the model starting from the last layer. The last layer is unfrozen first, and all unfrozen layers are fine-tuned for one epoch. Then the next group of frozen layers is unfrozen and fine-tuned, and the process repeats until all layers are unfrozen and trained to convergence in the final iteration.
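
The unfreezing schedule can be sketched as a toy loop (four layers, one layer group per stage, all names illustrative):

```python
# ULMFiT-style gradual unfreezing: starting from the last layer,
# unfreeze one more layer group at each stage and train the unfrozen ones.
n_layers = 4
frozen = [True] * n_layers          # all layers start frozen

schedule = []                       # which layer indices are trained per stage
for stage in range(n_layers):
    frozen[n_layers - 1 - stage] = False   # unfreeze the next layer from the top
    trained = [i for i, f in enumerate(frozen) if not f]
    schedule.append(trained)
```
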

3. Reparametrization-based Methods

Reparametrization-based methods aim to find the low-rank representation (essentially smaller dimensions) of the updates that will be incorporated into the parameters of a pretrained model for a downstream task. The principle behind this is that neural networks possess equivalent low-dimensional representations.

3.1 Intrinsic SAID

Intrinsic SAID uses the Fastfood transform to reparametrize the update to the model weights. The update, which will be added to the pretrained weights, is parametrized through the matrices H (a Hadamard matrix), G (a random diagonal matrix with independent standard normal entries), B (a random diagonal matrix with equal-probability ±1 entries), and Π (a random permutation matrix). After training, the matrix M = HGΠHB is added to the pretrained model weights. This method essentially transforms the high-dimensional update into a much lower-dimensional, more manageable optimization problem.

3.2 LoRA

LoRA, or Low Rank Adaptation of LLM, takes inspiration from IntrinsicSAID and proposes a simpler way to perform low-rank fine-tuning. The update for a weight matrix in LoRA is decomposed into a product of just two low-rank matrices, unlike the Fastfood Transform used in IntrinsicSAID. This simplification reduces the complexity of the update operation and makes it more efficient.
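
A minimal NumPy sketch of the LoRA decomposition; the dimensions are illustrative, and the zero-initialization of one factor plus the alpha/r scaling follow the common convention:

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 512, 8, 16                 # hidden size, LoRA rank, scaling

W = rng.standard_normal((d, d))          # frozen pretrained weight
A = rng.standard_normal((r, d)) * 0.01   # trainable low-rank factor
B = np.zeros((d, r))                     # zero-init so the update starts at zero

delta_W = (alpha / r) * (B @ A)          # low-rank update, rank <= r
W_effective = W + delta_W                # merged weight usable at inference

trainable_params = A.size + B.size       # 2*d*r instead of d*d
```

After training, `delta_W` can be merged into `W` once, so, unlike adapters, LoRA adds no inference overhead.
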

3.3 KronA

KronA replaces the matrix factorization in LoRA with a Kronecker product, δW = WA ⊗ WB. This yields a better rank-per-parameter tradeoff because the Kronecker product preserves the ranks of the matrices being multiplied; in other words, rank(A ⊗ B) = rank(A) · rank(B). This allows for more efficient use of parameters while maintaining the rank properties of the original matrices.
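
The rank property is easy to verify numerically with NumPy (the factor shapes below are arbitrary small examples):

```python
import numpy as np

rng = np.random.default_rng(0)

# Two small full-rank factors; their Kronecker product has rank 3 * 4 = 12,
# far higher than a two-matrix LoRA product of comparable parameter count.
Wa = rng.standard_normal((3, 5))
Wb = rng.standard_normal((4, 6))

delta_W = np.kron(Wa, Wb)               # shape (3*4, 5*6) = (12, 30)

rank_a = np.linalg.matrix_rank(Wa)      # 3 (Gaussian matrices are full rank)
rank_b = np.linalg.matrix_rank(Wb)      # 4
rank_kron = np.linalg.matrix_rank(delta_W)
```
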

3.4 AdaLoRA

AdaLoRA proposes an SVD (Singular Value Decomposition) inspired decomposition of the adapter matrices and develops various importance scores to assess which triplets in the SVD decomposition can be removed. This allows adaptively tuning the ranks of the adapter matrices across layers. This method provides a dynamic way to adjust the rank of the adapter matrices, allowing for more flexibility and efficiency in the fine-tuning process.
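
The core pruning step can be sketched with NumPy's SVD; note that AdaLoRA's actual importance scores are learned during training, whereas this sketch simply uses singular values as the importance criterion for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# A weight update with true rank 4, embedded in a 64x64 matrix
delta_W = rng.standard_normal((64, 4)) @ rng.standard_normal((4, 64))

U, s, Vt = np.linalg.svd(delta_W, full_matrices=False)

# Keep only triplets (u_i, s_i, v_i) whose "importance" is significant;
# here the singular value itself stands in for the learned score.
keep = s > 1e-8
delta_W_pruned = (U[:, keep] * s[keep]) @ Vt[keep]
```

Dropping low-importance triplets reduces the effective rank of each adapter, and different layers can end up with different ranks.
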

4. Hybrid Methods

Hybrid methods in parameter-efficient fine-tuning (PEFT) for large language models (LLMs) amalgamate ideas from different PEFT categories to optimize performance while minimizing computational expenses associated with fine-tuning extensive neural networks. They are essentially a harmonious blend of multiple strategies, each contributing its strengths and counteracting the weaknesses of others, thereby leading to enhanced performance and efficiency.

4.1 Quantized LoRA (QLoRA)

Quantized LoRA (QLoRA) is a hybrid method that begins by quantizing the pretrained LLM, followed by standard LoRA training. QLoRA introduces a series of innovative features to conserve memory without compromising performance:

  • 4-bit NormalFloat (NF4), a novel data type ideally suited for normally distributed weights
  • Double quantization, a technique that reduces the average memory footprint by quantizing the quantization constants themselves
  • Paged optimizers, a tool to manage memory spikes

4.2 SparseAdapter

The SparseAdapter method employs a large hidden dimension for the added module and prunes about 40% of the values at initialization. While it consistently outperforms its non-sparse counterpart with the same trainable parameter count, it's important to note that the training and inference costs can be higher due to hardware support requirements for sparse tensors and operations. Additionally, calculating the pruning mask for this method may necessitate obtaining gradients for all newly added parameters.

4.3 MAM Adapters

MAM Adapters is a hybrid approach that combines the concepts of adapters and soft prompting. It capitalizes on the fact that scaled parallel adapters perform better than sequentially-placed adapters, and an adapter placed in parallel to the Feed Forward Network (FFN) outperforms multi-head attention-parallel adapters. Moreover, it utilizes the efficiency of soft prompts in modifying attentions by altering just 0.1% of the parameters.

4.4 UniPELT

UniPELT is a hybrid method that incorporates LoRA, Prefix-tuning, and Adapters. Specifically, it uses LoRA reparametrization for WQ and WV attention matrices, applies prefix-tuning to keys and values of each layer, and adds adapters after the feed-forward layer of the transformer block. For each of these modules, gating is implemented as a linear layer that projects the module input into a dimension of size one, applies a sigmoid activation, and averages the resulting vector over the sequence length.

4.5 Compacter

The Compacter method, as proposed by Karimi Mahabadi et al., 2021, leverages the Kronecker product, low-rank matrices, and parameter sharing across layers to generate adapter weights.

4.6 S4

The S4 method carries out a thorough exploration of diverse combinations of parameter-efficient fine-tuning techniques. Its search space includes dividing consecutive layers into four uneven groups, allocating variable amounts of trainable parameters to each layer, deciding which groups to fine-tune, and determining the PEFT methods to apply to each group.

Conclusion

In conclusion, the pursuit of parameter efficiency in fine-tuning LLMs is a critical aspect of contemporary machine learning research. By successfully leveraging the existing parameters of a pre-trained model and minimizing the addition or modification of new parameters, PEFT methods offer a promising solution to the computational and memory challenges associated with fine-tuning large models. As we continue to push the boundaries of what’s possible with machine learning and artificial intelligence, these methods will undoubtedly play a pivotal role in shaping the future of the field.

Misnomer Alert: Dot Product and Inner Product are not the Same (2023-09-19)

Let's jump in by calculating the inner and dot product of two example vectors using NumPy:

import numpy as np

x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
np.inner(x, y)
32
np.dot(x, y)
32

Both np.inner and np.dot gave the same result. When popular libraries evaluate both functions to the same value, it is tempting to think of them as the same. Moreover, many books and experts, particularly in machine learning, use the terms "inner product" and "dot product" interchangeably, assuming they refer to the same mathematical operation. I also used them that way until I read the book "Mathematics for Machine Learning" by M. P. Deisenroth, A. A. Faisal, and C. S. Ong. In this blog post, I aim to clear up this misconception and shed light on the differences between these fundamental products in linear algebra.

Dot Product

Why not start with the definition of the dot product, the so-called "inner product" most of us know. The dot product of vectors \(\mathbf{x}, \mathbf{y} \in \mathbb{R}^{n}\) is defined as:

\[\mathbf{x}\cdotp\mathbf{y} = \mathbf{x}^{\top}\mathbf{y}=\sum_{i=1}^{n}x_iy_i\]

where \(x_i\) and \(y_i\) are the \(i\)th elements of vectors \(\mathbf{x}\) and \(\mathbf{y}\) respectively. This is the dot product we use to calculate the length of a vector and the angle or similarity between two vectors. The NumPy functions np.dot and np.inner performed exactly this operation in the example above. However, the inner product is much more than this definition. Spoiler alert: the dot product is a specific instance of the inner product. We'll see how shortly.

Inner Product

Please bear with me while I give you the formal (mathematical) definition of the inner product using fundamental concepts of linear algebra. Let's start with the linear mapping:

Consider two vector spaces \(V\) and \(W\); a mapping \(\Phi:V\rightarrow W\) is called a linear mapping if

\[\forall \mathbf{x}, \mathbf{y} \in V \; \forall\lambda,\varphi \in \mathbb{R}: \Phi(\lambda\mathbf{x} + \varphi\mathbf{y})=\lambda\Phi(\mathbf{x})+\varphi\Phi(\mathbf{y})\]

In simple and intuitive terms, a transformation between two vector spaces that preserves the origin, collinearity, and parallelism is called a linear mapping.

Let’s expand the linear mapping to a two-argument function called a bilinear mapping:

Consider two vector spaces \(V\) and \(W\); a mapping with two arguments \(\Omega: V\times V\rightarrow W\) is called a bilinear mapping if \(\forall \mathbf{x}, \mathbf{y}, \mathbf{z} \in V \; \forall\lambda,\varphi \in \mathbb{R}:\)

\[\Omega(\lambda\mathbf{x}+\varphi\mathbf{y}, \mathbf{z})=\lambda\Omega(\mathbf{x}, \mathbf{z})+\varphi\Omega(\mathbf{y}, \mathbf{z})\] \[\Omega(\mathbf{x}, \lambda\mathbf{y}+\varphi\mathbf{z})=\lambda\Omega(\mathbf{x}, \mathbf{y})+\varphi\Omega(\mathbf{x}, \mathbf{z})\]

With these two definitions in hand, we can define the inner product formally as:

Consider a vector space \(V\); a bilinear mapping \(\Omega:V\times V\rightarrow \mathbb{R}\) is called an inner product on \(V\) if it satisfies the following two conditions:

  • \(\Omega\) is symmetric, i.e., \(\forall\mathbf{x},\mathbf{y}\in V:\Omega(\mathbf{x},\mathbf{y})=\Omega(\mathbf{y},\mathbf{x})\)
  • \(\Omega\) is positive definite, i.e., \(\forall\mathbf{x}\in V\setminus\{\mathbf{0}\}: \Omega(\mathbf{x}, \mathbf{x})>0\) and \(\Omega(\mathbf{x}, \mathbf{x})=0\) if and only if \(\mathbf{x}=\mathbf{0}\)

From the definition, we can say that an inner product can be any bilinear function that takes two vectors as arguments, outputs a real number, is symmetric (meaning we can interchange the arguments), and evaluates to a positive real number whenever both arguments are the same nonzero vector.

Relationship Between Dot Product and Inner Product

To find the relationship between the dot product and the inner product, let’s expand the definition of the inner product in terms of a basis of the vector space and the coordinate representations of vectors.

Consider \(B=(\mathbf{b}_1, \dots, \mathbf{b}_n)\) to be a basis of the \(n\)-dimensional vector space \(V\). Also, consider two vectors \(\mathbf{x}, \mathbf{y}\in V\) in terms of their coordinate vectors \(\hat{\mathbf{x}}, \hat{\mathbf{y}}\) with respect to the basis \(B\).

\[\mathbf{x}=\sum_{i=1}^{n}\varphi_i\mathbf{b}_i, \space \mathbf{y}=\sum_{j=1}^{n}\lambda_j\mathbf{b}_j\] \[\hat{\mathbf{x}}=\begin{bmatrix}\varphi_1\\\vdots\\\varphi_n\end{bmatrix}, \space \hat{\mathbf{y}}=\begin{bmatrix}\lambda_1\\\vdots\\\lambda_n\end{bmatrix}\]

where \(\varphi_i\) and \(\lambda_j\) are the \(i\)th and \(j\)th elements of the coordinate representations/coordinate vectors \(\hat{\mathbf{x}}\) and \(\hat{\mathbf{y}}\) respectively.

Now, we can write the inner product \(\Omega\), typically represented as \(\left \langle \cdot, \cdot \right \rangle\), as the following:

\[\left \langle \mathbf{x}, \mathbf{y} \right \rangle=\left \langle \sum_{i=1}^{n}\varphi_i\mathbf{b}_i, \sum_{j=1}^{n}\lambda_j\mathbf{b}_j \right \rangle\]

Using the property of bilinearity defined above:

\[\left \langle \mathbf{x}, \mathbf{y} \right \rangle=\sum_{i=1}^{n}\sum_{j=1}^{n}\varphi_i\left \langle \mathbf{b}_i, \mathbf{b}_j \right \rangle\lambda_j=\hat{\mathbf{x}}^\top\mathbf{A}\hat{\mathbf{y}}\]

where \(A_{ij}=\left \langle \mathbf{b}_i, \mathbf{b}_j \right \rangle\), i.e., the inner product of the basis vectors \(\mathbf{b}_i\) and \(\mathbf{b}_j\).

From here, we can see that an inner product can be defined by a square matrix \(\mathbf{A}\in\mathbb{R}^{n\times n}\) that is symmetric and positive definite.

Now we can finally relate the dot product to the inner product. When the symmetric, positive definite matrix governing the inner product is the identity matrix, the inner product is called the dot product. Formally, the dot product on an \(n\)-dimensional vector space \(V\) is defined as the bilinear mapping \(\Omega: V\times V\rightarrow \mathbb{R}\) where:

\[\forall \mathbf{x},\mathbf{y}\in V: \Omega(\mathbf{x},\mathbf{y})=\mathbf{x}^\top\mathbf{I}_n\mathbf{y}=\mathbf{x}^\top\mathbf{y}\]

where \(\mathbf{I}_n\) is an \(n\times n\) identity matrix.

Programmatic Implementation

Let’s implement a full-fledged inner product using NumPy:

def inner_product(x, y, A):
    """Returns inner product of x and y governed by matrix A
    Arguments:
    x = 1-D array of shape (n,)
    y = 1-D array of shape (n,)
    A = 2-D array of shape (n, n)
    """
    x_vec, y_vec = x[:, np.newaxis], y[:, np.newaxis]
    xTA = np.matmul(x_vec.T, A)
    in_prod = np.matmul(xTA, y_vec)
    return in_prod
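
The inner_product function above assumes that A is a valid inner-product matrix without checking it. As a small sketch (the function name is mine), we can verify symmetry and positive definiteness via the eigenvalues:

```python
import numpy as np

def is_valid_inner_product_matrix(A, tol=1e-10):
    """Check that A is symmetric and positive definite, i.e., that
    x^T A y defines a valid inner product."""
    symmetric = np.allclose(A, A.T)
    # A symmetric matrix is positive definite iff all its eigenvalues are > 0
    return bool(symmetric and np.all(np.linalg.eigvalsh(A) > tol))

A = np.array([[9., 6.],
              [6., 5.]])   # both eigenvalues positive: valid
B = np.array([[1., 2.],
              [2., 1.]])   # eigenvalues 3 and -1: not positive definite
```

The matrix A here is the same symmetric positive definite example used below, while B shows a symmetric matrix that fails the test.
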

Define example vectors and matrices:

x = np.array([1, 2])
y = np.array([3, 4])

# An example symmetric positive definite matrix
A = np.array(
    [[9, 6],
     [6, 5]]
)

# An identity matrix
I = np.array(
    [[1, 0],
     [0, 1]]
)

Calculate the inner product governed by A:

inner_product(x, y, A)
array([[127]])

Calculate the inner product governed by I; we get the dot product.

inner_product(x, y, I)
array([[11]])

Let’s verify with the NumPy implementation.

np.dot(x, y)
11

You’ve come to the end of the blog post. This post has illuminated the distinction and relationship between the inner product and the dot product in linear algebra. While most of us use them interchangeably, it’s crucial to understand that the inner product is a broader concept encompassing a bilinear mapping governed by a symmetric, positive definite matrix. On the other hand, the dot product is a specific instance of the inner product when the governing matrix is the identity matrix.

Defeating The Size: Working with Large Tabular Data on AWS S3 using Snowpark (2023-09-10)

Welcome to this blog post, where we’ll dive into a powerful combination of tools for working with large tabular data: AWS S3, Snowpark, and Snowflake. These tools, when used in tandem, enable seamless processing and analysis of large datasets stored in AWS S3. In this blog, we’ll walk through sample code that leverages these technologies to work with the classic Iris dataset. The code can be easily adapted to handle much larger datasets, making it a valuable addition to your data engineering and machine learning toolkit.

Setting the Stage

Before we get started, let’s set up our environment. We’ll be using Python along with some essential libraries.

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris

The first block of code imports the necessary libraries. numpy and pandas are widely used for numerical computations and data manipulation, while load_iris is a convenient function for loading the Iris dataset.

Loading and Preparing the Data

Next, for demo purposes, we load the Iris dataset and convert it into a Pandas DataFrame.

iris = load_iris()
df = pd.DataFrame(iris["data"], columns=iris["feature_names"])
df["target"] = iris["target"]
df.head()
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0

Here, we’ve converted the Iris data into a structured DataFrame, making it easier to work with. This DataFrame includes features like sepal length, sepal width, petal length, petal width, and a target variable.

Uploading Data to AWS S3

Now, let’s upload our DataFrame to AWS S3.

import boto3

# You can use your custom bucket and prefixes
bucket, key = "sagemaker-lal-stage", "iris.csv"
s3_client = boto3.client("s3")
s3_client.put_object(
    Bucket=bucket,
    Key=key,
    Body=df.to_csv(index=False),
    ContentType="text/csv",
)

By executing this code, we place the data in an S3 bucket, making it accessible for further processing. You can upload data of any size to S3 using any method; just note the bucket name and the prefixes of the data files.

Establishing a Connection with Snowflake

The next step involves connecting to Snowflake using Snowpark, a powerful tool for data processing in Snowflake.

from snowflake.snowpark import Session

connection_parameters = {
    "account": "####SF_ACCOUNT_NAME######",
    "user": "#####USER#######",
    "password": "#####PASSWORD#######",
    "warehouse": "#####WAREHOUSE########",
    "database": "#####DATABASE########",
    "schema": "#####SCHEMA(PUBLIC)########",
}

session = Session.builder.configs(connection_parameters).create()

We establish a connection using a session instantiated with the specified parameters. This allows us to interact with the Snowflake database.

Creating a Temporary Table

Now, let’s create a temporary table in Snowflake where we’ll load our data from S3.

table_name = "iris_dataset"
session.sql(f"""create temporary table {table_name} (
    SEPAL_LENGTH float,
    SEPAL_WIDTH float,
    PETAL_LENGTH float,
    PETAL_WIDTH float,
    TARGET integer
)"""
).collect()

This code creates a temporary table in Snowflake with the same structure as our DataFrame; this is where we’ll load our data. Temporary tables are destroyed when the session is terminated. In my experience, temporary tables are really fast compared to the standard and transient tables available in Snowflake, so it is preferable to work with temporary tables if we don’t need the data to persist in Snowflake.

Copying Data from S3 to Snowflake

With our temporary table in place, let’s copy the data from S3 into Snowflake.

session.sql(f"""copy into {table_name}
from 's3://{bucket}/{key}'
credentials=( AWS_KEY_ID='#######AWS_KEY_ID#######' AWS_SECRET_KEY='#######AWS_SECRET_KEY#######')
file_format=(TYPE=CSV COMPRESSION=NONE SKIP_HEADER=1)
"""
).collect()

This command transfers the data from S3 to our Snowflake table using the specified credentials (which must have permission to get, put, and delete objects in AWS S3). All files with prefix {key} in bucket {bucket} are processed by the command above. Note that we uploaded the data to S3 with a header row, so we tell Snowflake to skip it with SKIP_HEADER=1. Also, the uploaded data was not compressed; you can specify a compression type if your data is compressed.

Analyzing Data with Snowpark

Now that our data is in Snowflake, we can perform any operation we like. We’ll be using Snowpark’s DataFrame capabilities for this.

sdf = session.table(table_name)

This line creates a Snowpark DataFrame from our Snowflake table.

We can display the table by converting the Snowpark DataFrame to a Pandas DataFrame.

sdf.to_pandas()
SEPAL_LENGTH SEPAL_WIDTH PETAL_LENGTH PETAL_WIDTH TARGET
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2

150 rows × 5 columns

Computing Pairwise Correlations

For a demo, let’s start by computing the pairwise correlation between sepal length and sepal width. For reference, the formula for the Pearson correlation is: \(r=\frac{\sum ( x_{i} -\overline{x})( y_{i} -\overline{y})}{\sqrt{\sum ( x_{i} -\overline{x})^{2}\sum ( y_{i} -\overline{y})^{2}}}\)
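
Before expressing the formula in Snowpark, it can help to sanity-check it locally with NumPy against np.corrcoef (the synthetic data below is purely illustrative):

```python
import numpy as np

def pearson(x, y):
    """Pearson correlation, computed directly from the formula above."""
    x_diff = x - x.mean()
    y_diff = y - y.mean()
    return (x_diff * y_diff).sum() / np.sqrt(
        (x_diff ** 2).sum() * (y_diff ** 2).sum()
    )

rng = np.random.default_rng(0)
x = rng.standard_normal(100)
y = 0.5 * x + rng.standard_normal(100)   # correlated with x by construction

r = pearson(x, y)
```

The Snowpark version below mirrors this computation term by term, only pushed down to the warehouse.
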

from snowflake.snowpark import DataFrame, Column
from snowflake.snowpark import functions as spf

def pair_correlation(df: DataFrame, x: str, y: str) -> DataFrame:
    # broadcast mean using `over`
    x_diff = df[x] - spf.mean(df[x]).over()
    y_diff = df[y] - spf.mean(df[y]).over()
    
    # Store results in columns
    df = df.with_columns(
        ["x_diff", "y_diff"],
        [x_diff, y_diff],
    )
    
    numerator = spf.sum(df["x_diff"]*df["y_diff"])
    denominator = spf.sqrt(spf.sum(spf.pow(df["x_diff"], 2))*spf.sum(spf.pow(df["y_diff"], 2)))
    
    # prepare dataframe with pair values in first two columns and correlation value in last column
    return df.select(
        spf.lit(x).alias("FEAT1"),
        spf.lit(y).alias("FEAT2"),
        (numerator/denominator).alias("VALUE"),
    )

Here, we define a function pair_correlation that accepts a Snowpark DataFrame and the names of two columns whose correlation is to be determined, and returns a new DataFrame with the result. The function leverages Snowpark’s powerful built-in functions for data manipulation.

pair_correlation(sdf, "SEPAL_LENGTH", "SEPAL_WIDTH").to_pandas()
FEAT1 FEAT2 VALUE
0 SEPAL_LENGTH SEPAL_WIDTH -0.11757

You can continue with other analyses as you require.

Wrapping Up

With the analysis complete, we close our Snowflake session.

session.close()

Additionally, we clean up our S3 bucket by deleting the uploaded data.

s3_client.delete_object(
    Bucket=bucket,
    Key=key,
)

And there you have it! We’ve walked through the steps of loading data into Snowflake from AWS S3, conducting various correlation analyses, and finally, cleaning up our environment.

This powerful combination of Snowpark and Snowflake opens up a world of possibilities for handling large tabular datasets. Whether you’re a data scientist or a machine learning engineer, having these tools in your arsenal can significantly enhance your data processing capabilities. Experiment with your own datasets and unlock valuable insights!

Happy coding!

]]>
Sagar Sapkota[email protected]https://spktsagar.com
Terraform a Scalable Comprehensive Sagemaker MultiModel Pipeline2023-07-19T00:00:00-07:002023-07-19T00:00:00-07:00https://spktsagar.com/posts/2023/07/terraform-sagemaker-multimodel-pipelineTable of Contents

In the ever-growing realm of machine learning, managing complex workflows efficiently is crucial. One such project that I recently built is a comprehensive, scalable SageMaker pipeline designed for training, deploying, and monitoring multiple models on varying datasets behind a single endpoint. Built on the robust Amazon SageMaker platform, this project offers an end-to-end solution for defining and managing machine learning workflows. It covers everything from data preprocessing, model training, and tuning to the final deployment phase.

The beauty of this pipeline lies in its scalability and cost-effectiveness. It employs multi-model endpoints, which use a shared serving container and the same fleet of resources to host all models. This approach significantly cuts down hosting costs by improving endpoint utilization compared to using single-model endpoints. The pipeline is meticulously designed to tune, train, and deploy multiple models on individual datasets, providing users with the flexibility to work with the data and infrastructures they prefer.

Another aspect of this project is that all the resources involved in implementing the pipeline are orchestrated using a Terraform script. In the upcoming sections of this blog post, we will delve deeper into the implementation details of this project, exploring its architecture, code, and resources. Buckle up to gain a deeper understanding of this fascinating SageMaker pipeline project!

Architecture

Now that we have a high-level understanding of the project, let’s dive deeper into the specific architecture of the pipeline. The detailed architecture diagram, presented next, provides a visual representation of the entire system, showcasing the relationships between different components and steps involved in the pipeline. This diagram will help us comprehend the structure of the system and the sequential order of events in the pipeline.

The pipeline of this project is divided into two sections: the Modeling Pipeline and the Deployment Pipeline. Let’s delve deeper into the steps involved in each of these sections:

Modeling Pipeline

The Modeling Pipeline is the initial phase where the model is trained and prepared for deployment. Here’s what happens in each step:

  • Preprocessing: The raw data is cleaned and transformed into a suitable format for model training. This step is crucial as quality data is a prerequisite for training an effective model. It involves handling missing values, removing outliers, encoding categorical variables, and so on.

  • Hyperparameter Tuning: This step involves searching for the best hyperparameters for the model. Hyperparameters are the configuration variables that govern the training process of a model. For instance, the learning rate in a neural network is a hyperparameter. A hyperparameter tuning algorithm, like grid search or random search, is used to explore different combinations of hyperparameters to find the optimal set that minimizes the loss function.

  • Refit Best Model: After the best hyperparameters are found, the model is trained again using these hyperparameters. This step ensures that the model is the best version of itself before it is evaluated and potentially deployed.

  • Evaluate Best Model: The performance of the best model is evaluated in this step. This is done using a holdout validation set that the model has never seen before. Evaluation metrics like accuracy, precision, recall, or AUC-ROC (for classification tasks), or MSE, MAE, R2 score (for regression tasks) are computed.

  • Registration Metric Check: The model’s performance metrics are checked against a predefined threshold or previous models’ performance to decide whether to register the model in the registry. This step ensures that only models that meet the quality standards are registered for deployment.

  • Model Package Registration Step: If the model passes the registration metric check, it is registered to the SageMaker Model Registry. This registry serves as a repository where trained models are stored before they are deployed.

Deployment Pipeline

The Deployment Pipeline is the second phase where the registered models are deployed for serving predictions.

  • The pipeline listens to approval/rejection events in the SageMaker Model Registry via AWS EventBridge. An approval event triggers the deployment of the approved model.

  • The approved models are deployed to an endpoint’s multi-model artifact location in S3 using a Lambda function. AWS Lambda is a serverless compute service that lets you run your code without provisioning or managing servers.

  • Another AWS Lambda function with a Function URL is used to interact with the SageMaker endpoint. This function can be used to send data to the endpoint for inference and receive predictions.

  • The scalability of the SageMaker endpoint is managed by AWS Application Auto Scaling. This service can automatically adjust the capacity of the endpoint to maintain steady, predictable performance at the lowest possible cost.

Overall, these steps ensure a streamlined process from data preprocessing to model deployment, providing an efficient and scalable solution for machine learning workflows.

Implementation

Now that we’ve understood the architecture of the pipeline, it’s time to delve into the implementation details. In the following section, we will walk through the process of implementing the pipeline to train separate models for two different datasets: the Breast Cancer dataset and the Bank Note Authentication dataset. The pipeline performs a hyperparameter search across multiple SKLearn, XGBoost, and LightGBM classifiers and deploys the best-found model for each dataset to a single scalable multi-model SageMaker endpoint. As you go through the blog, you’ll find that you can adjust the scripts to include more datasets and custom preprocessing and training steps. Our implementation will ensure that every step in the pipeline, from beginning to end, is scalable.

In this blog, I’ll describe the main code snippets and components of the implementation. The complete implementation can be found in the following repo. This post is complementary to the GitHub repo; please read its README and USAGE_GUIDELINE.

Extending Prebuilt Sagemaker SKLearn Image

Docker images are the backbone of SageMaker: whatever processing we do in each step of the pipeline runs inside a container. In this section, we will discuss how to extend a prebuilt SageMaker SKLearn image. The main purpose of this is to add libraries that are not included in the prebuilt image, in this case LightGBM and XGBoost. You can also create a custom Docker image from scratch if you want.

The Dockerfile for extending the SageMaker SKLearn Image can be written as follows:
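A minimal sketch of such a Dockerfile (the base image URI below is a placeholder; substitute the prebuilt SageMaker SKLearn image URI and tag for your account and region):

```dockerfile
# Placeholder base image: use the prebuilt SageMaker SKLearn image URI for your region
FROM <account>.dkr.ecr.<region>.amazonaws.com/sagemaker-scikit-learn:<tag>

# Add the gradient-boosting libraries missing from the prebuilt image
RUN pip install --no-cache-dir xgboost lightgbm
```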

From this Dockerfile, we then build a Docker image and push it to ECR using the following script. The script requires the AWS account ID, AWS region, and image name to be passed to it when run. Don’t worry about running this script for now; we run it via Terraform.

ML Modeling Pipeline

Now that we have the Docker container on which we can run our modeling jobs, let’s create the scripts for those jobs. Let’s begin by creating a Python script where we’ll define the pipeline with the SageMaker Python SDK.

Pipeline Definition:

We’ll create a SageMaker pipeline definition in a Python file called pipeline.py. This file will define all the steps in the Modeling Pipeline. For now, let’s import the necessary dependencies and create the arguments and SageMaker sessions needed for the pipeline. Further down, we’ll discuss the implementation of each step.

Preprocessing Step:

We’ll prepare a Python script for downloading, splitting, and saving two specific datasets: the Breast Cancer and the Banknote Authentication datasets. The main block of the script uses argparse to parse command-line arguments for selecting the dataset and specifying the test size, random state, stratification, target column name, and file name. It then calls the appropriate function to get the requested dataset. Finally, the script creates directories for the training and testing datasets and saves them as CSV files. You can modify the script to include other datasets and preprocessing steps; you just need to handle the dataset argument passed to the script and process that dataset only.
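The script itself lives in the repo; its core is roughly the following sketch (the argument and function names are illustrative, and the Banknote dataset, which has no scikit-learn loader, is elided):

```python
import argparse

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

def get_dataset(name):
    """Hypothetical dispatcher; the real script also handles the Banknote dataset."""
    if name == "breast_cancer":
        return load_breast_cancer(return_X_y=True, as_frame=True)
    raise ValueError(f"unknown dataset: {name}")

def split_dataset(name, test_size=0.2, random_state=42):
    # download the requested dataset and produce a stratified train/test split
    X, y = get_dataset(name)
    return train_test_split(X, y, test_size=test_size,
                            random_state=random_state, stratify=y)

def build_parser():
    # mirrors the command-line interface described above
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", default="breast_cancer")
    p.add_argument("--test-size", type=float, default=0.2)
    p.add_argument("--random-state", type=int, default=42)
    p.add_argument("--target-col", default="target")
    return p

# In the processing container, __main__ parses the args, calls split_dataset,
# and writes the train/test CSVs under /opt/ml/processing/{train,test}.
```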

We’ll call this script from the preprocessing step of the pipeline. In pipeline.py, we define the preprocessing step using SageMaker’s ScriptProcessor, which runs the above script. We’ll also define the parameters that the user passes during pipeline execution. Note that we specify where the output artifacts of the preprocessing step, such as the train/test split, should be written so that SageMaker uploads them to S3.

Hyperparameter Tuning With CrossValidation Step:

Hyperparameter tuning is essentially running multiple training jobs with different combinations of model hyperparameters and selecting the combination that gives the best result. When we tune with k-fold cross-validation, we select the model with the highest average performance across the folds. Tuning with cross-validation and training differ in that tuning reports a performance metric, whereas training fits the model on the whole set. We’ll prepare a script that can do both, based on a cross-validation flag passed to it. Also note that we’re not searching for the best hyperparameters of a single model in this pipeline: RandomForest, Naive Bayes, NeuralNet, LogisticRegression, XGBoost, and LightGBM all compete with each other during tuning. So for each model we create a script that does both cross-validation and training on the whole dataset. For example, the following is the script for LightGBM:
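The full script is embedded from the repo; its shape is roughly the following sketch, with a scikit-learn classifier standing in for LightGBM and simplified argument handling (the function name and metric are illustrative):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def run(X, y, cross_validation=True, n_folds=5, **hyperparams):
    """With the cross-validation flag: report the mean CV score for the tuner
    to optimize. Without it (the refit step): fit on the whole training set."""
    model = LogisticRegression(**hyperparams)  # stand-in for lgb.LGBMClassifier
    if cross_validation:
        scores = cross_val_score(model, X, y, cv=n_folds, scoring="roc_auc")
        # SageMaker tuning parses the objective metric from the job logs
        print(f"validation:roc_auc={scores.mean():.6f}")
        return scores.mean()
    return model.fit(X, y)
```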

You can find the scripts for the other estimators in my GitHub repo.

Now that the scripts are ready (we’ll upload them to S3 using Terraform later), we can implement the tuning step in the pipeline. First we’ll create a dictionary of estimators, where each key is an estimator name matching its training script’s name and each value is a SageMaker estimator. Along with the parameters required for the tuning step, we define the hyperparameter search space for each estimator. Please have a look at the following code snippet.

With those estimator and hyperparameter definitions, we can create the tuning step of the pipeline in the following way.

Refit Best Model:

After the tuning job completes, we need to retrain the model with the best-found parameters on the whole training dataset. For that we’ll create a script called refit.py. This Python script is essentially a command-line interface for launching AWS SageMaker training jobs: it extracts the hyperparameters from the specified best training job, disables cross-validation, and runs the Python script associated with the chosen algorithm using those hyperparameters. Have a look at the script.
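The heart of refit.py is taking the best job's hyperparameters and re-launching training without cross-validation. A sketch of that transformation (the internal key names are assumptions about what SageMaker attaches to tuning jobs, and the flag name is illustrative):

```python
def build_refit_hyperparameters(best_job_hyperparameters):
    """Strip SageMaker-internal keys from the best training job's hyperparameters
    and force a full-dataset fit for the refit step."""
    internal = {"_tuning_objective_metric", "sagemaker_submit_directory",
                "sagemaker_program", "sagemaker_region", "sagemaker_job_name",
                "sagemaker_container_log_level"}
    # SageMaker reports hyperparameter values as (sometimes quoted) strings
    params = {k: v.strip('"') for k, v in best_job_hyperparameters.items()
              if k not in internal}
    params["cross-validation"] = "False"  # disable CV: train on the whole set
    return params

# refit.py passes these params to a SageMaker Estimator whose entry point
# is the training script for the winning algorithm.
```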

We can use this script to create the best estimator in SageMaker and add a Training Step to the pipeline, which acts as the Refit Best Model step.

Evaluation Step:

Now that we’ve trained the model with the best-found parameters, we can measure its performance on the test split we created in the preprocessing step. Let’s create an evaluation script. The script is designed to evaluate the performance of a binary classification model. It begins by parsing command-line arguments for the test data file, the metric to register, the features, and the target feature. The script then extracts a model from a tarball file, loads the test data from a CSV file, and prepares it for evaluation. The model is used to make predictions on the test data, and several evaluation metrics are calculated, including precision, recall, accuracy, F1 score, ROC AUC score, and the confusion matrix. Finally, a report including all the calculated metrics is generated and saved to a JSON file.
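The metric computation at the heart of that script can be sketched as follows (using scikit-learn's metric functions; the report layout shown is an assumption):

```python
import json

from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score, roc_auc_score)

def evaluation_report(y_true, y_pred, y_score):
    """Compute the metrics described above as a JSON-serializable report."""
    return {
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "roc_auc": roc_auc_score(y_true, y_score),
        "confusion_matrix": confusion_matrix(y_true, y_pred).tolist(),
    }

# The script writes this dict with json.dump so the pipeline's condition
# step can read the chosen metric back out of the report file.
```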

The next step is to add the evaluation step to the pipeline. Following is the code snippet for that. The snippet starts by defining parameters for model registration in the registry, including the evaluation metric, the minimum threshold for the metric, and the approval status. It then sets up an evaluation step that assesses the best model using a ScriptProcessor object, which runs a Python script with specified arguments, inputs, and outputs. The evaluation step is encapsulated in a ProcessingStep object, which represents this step in the SageMaker pipeline.

Model Registry Step:

After the evaluation is complete, if the model’s performance exceeds a predefined threshold, we can register the model in the model registry. The following Python code snippet leverages the Amazon SageMaker Python SDK to create and register a machine learning model. Initially, it defines the model metrics and creates an instance of the model using the SKLearnModel class. The model is then registered using its register method, where important parameters like the content types for the request and response, the instances for inference and transformation, and the model metrics are specified. A condition check verifies whether the model’s evaluation metric meets the required threshold; depending on the outcome, either the model registration step or a failure step is executed.

This concludes the modeling pipeline. Now, we’ll delve into the deployment pipeline.

Deployment Pipeline

The Deployment Pipeline is the phase where approved machine learning models, registered in the SageMaker Model Registry, are deployed for serving predictions. The pipeline uses AWS Lambda functions to deploy the models to an S3 location and to interact with the SageMaker endpoint for sending data and receiving predictions. AWS Application Auto Scaling manages the scalability of the endpoint, ensuring steady performance at the lowest cost. This process provides an efficient and scalable solution for machine learning workflows. Let’s begin the pipeline definition with the Lambda function that handles events in the SageMaker Model Registry, such as model approval or rejection, and updates the model served for each dataset.

Deployer Lambda Function:

The Python script below is designed to manage multiple machine learning models in AWS, specifically using Amazon SageMaker and Amazon S3, via the Boto3 library, the Python SDK for AWS. The script fetches the latest approved model package from a model group, stores a map of group names (datasets) to latest models in an S3 bucket, updates the model in the endpoint’s model-artifact S3 bucket, and handles the EventBridge events that are triggered when models are approved or rejected in the SageMaker Model Registry.
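The full function is embedded from the repo; its event-handling core reduces to updating the group-to-model map on each registry event, roughly as follows (the EventBridge detail keys shown are assumptions about SageMaker's model-package state-change event):

```python
def apply_registry_event(model_map, detail):
    """Apply one Model Registry approval/rejection event to the
    dataset-group -> latest-model map (stored as JSON in S3 by the real function)."""
    group = detail["ModelPackageGroupName"]
    if detail["ModelApprovalStatus"] == "Approved":
        model_map[group] = detail["ModelPackageArn"]
    else:
        # a rejection removes the group's model from serving
        model_map.pop(group, None)
    return model_map

# lambda_handler would load the map from S3, call apply_registry_event with
# event["detail"], copy the approved artifact into the endpoint's multi-model
# S3 prefix, and write the updated map back.
```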

Invoker Lambda Function:

The following Python script is an AWS Lambda function designed to interact with a SageMaker endpoint. It imports the necessary libraries, sets up AWS clients, and defines two functions. The function get_groupname2model_map retrieves a JSON object from an S3 bucket mapping group names (datasets) to the latest model names. The main function, lambda_handler, retrieves the group-to-model map and checks whether the specified model name exists in it. If it does, it invokes the SageMaker endpoint with that model as the target model, passing in the data from the event body, and returns the result with a 200 status code. If the model name doesn’t exist, it returns a 404 status code with an error message. Essentially, this script serves as an interface between an HTTP request and the multiple machine learning models hosted on AWS SageMaker.
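Factoring out the routing decision, the handler's logic is roughly the following (the names here are illustrative):

```python
import json

def route_request(model_map, model_name, payload, invoke_endpoint):
    """Pure routing core of the invoker Lambda: look up the target model,
    invoke the multi-model endpoint, or return 404 for unknown models."""
    if model_name not in model_map:
        return {"statusCode": 404,
                "body": json.dumps({"error": f"no model for '{model_name}'"})}
    result = invoke_endpoint(model_map[model_name], payload)
    return {"statusCode": 200, "body": json.dumps(result)}

# In the real handler, invoke_endpoint wraps
# boto3.client("sagemaker-runtime").invoke_endpoint(
#     EndpointName=..., TargetModel=..., Body=payload)
```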

Multimodel Endpoint:

We’ll need a SageMaker endpoint, along with a model and an endpoint configuration, in multi-model mode. These resources are created with Terraform, which is discussed in the following section.

Terraform The Pipeline

We’ll create all the resources, including the AWS infrastructure, Docker images, and pipeline definition, using the following Terraform file. Here is a description of each resource used in this Terraform configuration.

  • First, we define the required providers and variables for the pipeline.
  • Next, we create an AWS IAM Role with access to the necessary resources.
  • An ECR repository is then created to host the custom images used in the pipeline.
  • Since Terraform doesn’t have a built-in capability to build and push Docker images, we run the build_and_push.sh script we created at the beginning using Terraform’s null_resource.
  • An S3 bucket for storing artifacts for and from the pipeline is created.
  • A tarball of the source scripts is created and uploaded to the bucket created above, along with the separate preprocessing and evaluation scripts.
  • The pipeline.py file is run with Terraform’s null_resource.
  • The produced pipeline definition file is uploaded to the S3 bucket.
  • A SageMaker pipeline is then created from the content of the pipeline definition file.
  • Next, a SageMaker endpoint model in MultiModel mode is created. It points to a location in the S3 bucket where all the models for the endpoint are stored. Our deployer Lambda function updates the models at this location.
  • A SageMaker endpoint configuration with the pre-scaling attributes is created.
  • The endpoint is then deployed with the configuration just created above.
  • AWS Application Auto Scaling is deployed with the endpoint as the target. The scaling policy is based on the CPU utilization of the instances behind the SageMaker endpoint.
  • The deployment Lambda function is created.
  • An EventBridge rule with the deployment Lambda function as its target is created to handle events from the model registry.
  • The invocation Lambda function is created, along with a function URL pointing to it.

We’ve come to the end of the blog. Here, we created a scalable end-to-end ML pipeline. For usage details, please refer to my GitHub repo.

]]>
Sagar Sapkota[email protected]https://spktsagar.com
Implementation and Empirical Analysis of Multi-Armed Bandit Problem2023-06-17T00:00:00-07:002023-06-17T00:00:00-07:00https://spktsagar.com/posts/2023/06/multi-armedbandit

Welcome to my latest blog post! Today, I am excited to share my recent exploration into the fascinating world of reinforcement learning, specifically focusing on the multi-armed bandit problem and its various solutions. As a foundation for my implementation, I closely followed the insightful book, Reinforcement Learning: An Introduction (second edition) by Richard S. Sutton and Andrew G. Barto. In this post, I will walk you through my journey, discussing key concepts and algorithms presented in Chapter 2 of the book, while also providing you with code examples and explanations to help you grasp these intriguing topics. So, let’s dive into the world of multi-armed bandits and reinforcement learning together!

Multi-Armed Bandit Problem

The Multi-Armed Bandit problem is a classic reinforcement learning challenge that exemplifies the exploration-exploitation tradeoff dilemma. Imagine a gambler (an agent, in RL terminology) in front of a row of slot machines, also known as “one-armed bandits.” The gambler needs to decide which machines to play, how many times to play each machine, in which order to play them, and whether to continue with the current machine or try a different one. Each machine provides a random reward from a probability distribution specific to that machine, which is unknown to the gambler. The objective is to maximize the sum of rewards earned through a sequence of lever pulls.

The critical tradeoff the gambler faces at each trial is between “exploitation” of the machine that has the highest expected payoff and “exploration” to gather more information about the expected payoffs of the other machines. In practice, multi-armed bandits have been used to model problems such as managing research projects in a large organization or optimizing marketing strategies.

In the following sections, I will delve into the implementation of the Multi-Armed Bandit problem. This implementation will follow the order of problems and solutions presented in the book by Sutton and Barto. Throughout this process, I will make an effort to compare the performance of various algorithms with each other, allowing us to evaluate their effectiveness in addressing the exploration-exploitation tradeoff.

Let’s begin by importing the libraries and setting them up!

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# For each multiarmed bandit experiment we'll have two plots, displayed horizontally
matplotlib.rcParams['figure.figsize'] = [20, 5]
matplotlib.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=['g', 'b', 'r', "y"])

Stationary Multi-Armed Bandit

The simplest setting for the multi-armed bandit problem is when the reward for choosing an arm remains constant over time. In this scenario, each arm has a fixed probability distribution for its rewards, and the gambler’s objective remains the same: to maximize the cumulative reward over a series of trials.

class StationaryMultiArmedBandit:
    """
    Represents a stationary multi-armed bandit problem.

    Attributes:
        k (int): Number of arms.
        runs (int): Number of independent runs.
        random_state (int, optional): Random seed for reproducibility.
    """
    def __init__(
            self,
            k,
            runs,
            random_state=None,
    ):
        self.k = k
        self.runs = runs
        self.random_state = random_state

        self.setup()

    def setup(self):
        """Set up the seed for reproducibility and reward distribution"""
        self.nprandom = np.random.RandomState(self.random_state)
        self.q_star = self.nprandom.normal(
            loc=0.0,
            scale=1.0,
            size=(self.runs, self.k),
        )

    def get_reward(self, action):
        """Given the action, return the reward"""
        reward = self.nprandom.normal(
            loc=self.q_star[np.arange(self.runs), action],
            scale=1.0,
        )
        return reward

    def get_correct_action(self):
        """
        Get the correct action for each run.
        Correct action for each run is the one with highest mean reward
        """
        return self.q_star.argmax(axis=1)

    def plot_reward_distribution(self, run=0):
        """Plot the reward distribution for the given run."""
        samples = self.nprandom.normal(
            loc=self.q_star[run],
            scale=1.0,
            size=(10_000, self.k),
        )
        plt.violinplot(samples, showmeans=True)
        plt.xlabel('Action')
        plt.ylabel('Reward Distribution')
        plt.show()

For all experiments, we’ll aggregate the performance across 2000 independent 10-armed bandit runs. Let’s create variables to store these numbers.

runs = 2000
k = 10

Now let’s create an instance of the Stationary Multi-Armed Bandit problem and look at its reward distribution. This distribution is unknown to the agent; an ideal agent should be able to learn the action that gives the highest reward.

st_bandit = StationaryMultiArmedBandit(k=k, runs=runs)
st_bandit.plot_reward_distribution(run=0)
print(f"Correct action for run 0: {st_bandit.get_correct_action()[0] + 1}")

[Plot: reward distribution across the 10 arms for run 0]

Correct action for run 0: 5

Agent

An agent in reinforcement learning is an entity that learns the actions that yield the highest reward in the long run. The agent has two primary roles:

  1. acting on the environment based on the current estimate of action values and
  2. updating the action values estimate based on the rewards it receives.

One of the challenges faced by the agent is the exploration-exploitation dilemma. In this dilemma, the agent must decide whether to explore new actions to gain more knowledge about the environment or exploit its current knowledge to maximize immediate rewards. Striking a balance between exploration and exploitation is critical for the agent’s success in the long run, as excessive exploration may lead to suboptimal rewards, while excessive exploitation may prevent the agent from discovering better actions. Various types of agents can be developed based on how they handle this exploration-exploitation trade-off, and in our implementation, we will compare different methods to understand their effectiveness in addressing this challenge.

Epsilon Greedy Sample Average Agent

The epsilon-greedy sample average method addresses the exploration-exploitation dilemma by choosing between exploration and exploitation randomly. With a probability of epsilon, the agent selects a random action for exploration, while with a probability of 1-epsilon, the agent exploits the current best action based on its estimated action values.

In comparison to the pure greedy algorithm, which always selects the action with the highest estimated value, the epsilon-greedy algorithm introduces a level of exploration, allowing the agent to discover potentially better actions and improve its long-term performance.

The sample average estimate is a technique used to update the estimated action values. For each action, the agent maintains a running average of the rewards it has received when selecting that action. When a new reward is received, the agent updates its estimate by incorporating the new reward into the running average. This way, the agent continuously refines its estimates based on its experience, which in turn helps it make better decisions in the exploration-exploitation trade-off.

class EpsilonGreedySampleAverageAgent:
    """
    An epsilon-greedy agent using sample-average method for action value estimation.

    Attributes
    ----------
    k : int
        Number of actions.
    runs : int
        Number of independent runs.
    epsilon : float, optional
        Probability of choosing a random action (exploration), default is 0.1.
    random_state : int, optional
        The random number generator seed to be used, default is None.
    """
    def __init__(
            self,
            k,
            runs,
            epsilon=0.1,
            random_state=None,
    ):
        self.k = k
        self.runs = runs
        self.epsilon = epsilon
        self.random_state = random_state

        self.setup()

    def setup(self):
        """Initialize the Q and N arrays for action value estimation and action counts."""
        self.nprandom = np.random.RandomState(self.random_state)
        self.Q = np.zeros((self.runs, self.k))
        self.N = np.zeros((self.runs, self.k))

    def get_action(self):
        """Choose an action based on epsilon-greedy policy."""
        greedy_action = np.argmax(
            self.nprandom.random(self.Q.shape) * (self.Q==self.Q.max(axis=1, keepdims=True)), # breaking ties randomly
            axis=1
        )
        random_action = self.nprandom.randint(0, self.k, size=(self.runs, ))

        action = np.where(
            self.nprandom.random((self.runs, )) < self.epsilon,
            random_action,
            greedy_action,
        )
        return action

    def get_step_size(self, action):
        """Calculate the step size for updating action value estimates.
        For the sample-average method we return 1 / (number of times the action has been chosen so far)."""
        return 1/self.N[np.arange(self.runs), action]

    def update(self, action, reward):
        """Update the action value estimates based on the chosen action and received reward."""
        self.N[np.arange(self.runs), action] += 1
        step_size = self.get_step_size(action)
        self.Q[np.arange(self.runs), action] += (reward - self.Q[np.arange(self.runs), action])*step_size

Now, we’ll create a test bed which will run an episode of interplay between the agent and the k-armed bandit environment for the provided number of steps. We’ll also include a function that plots the average reward the agent receives and the percentage of optimal actions the agent takes at each step.

class MultiArmedBanditTestBed:
    """A test bed for running experiments with multi-armed bandits and agents.

    Attributes:
        bandit (object): A multi-armed bandit object.
        agent (object): An agent object.
        steps (int): The number of steps for the experiment.
    """
    def __init__(
            self,
            bandit,
            agent,
            steps,
    ):
        self.bandit = bandit
        self.agent = agent
        self.steps = steps

    def run_experiment(self):
        """Runs the experiment for the given number of steps and returns the average rewards and optimal actions.

        Returns:
            tuple: A tuple containing two lists: average rewards and average optimal actions for each step.
        """
        avg_reward = []
        avg_optimal_action = []

        for _ in range(self.steps):
            action = self.agent.get_action()
            reward = self.bandit.get_reward(action)
            self.agent.update(action, reward)

            correct = action == self.bandit.get_correct_action()

            avg_reward.append(reward.mean())
            avg_optimal_action.append(correct.mean())

        return avg_reward, avg_optimal_action

    @classmethod
    def run_and_plot_experiments(cls, steps, exp_bandit_agent_dict):
        """Runs multiple experiments and plots the results.

        Args:
            steps (int): The number of steps for the experiments.
            exp_bandit_agent_dict (dict): A dictionary with labels as keys and (bandit, agent) tuples as values.
        """
        fig, (ax_reward, ax_optimal_action) = plt.subplots(nrows=1, ncols=2)

        for label, (bandit, agent) in exp_bandit_agent_dict.items():
            test_bed = cls(bandit, agent, steps)
            avg_reward, avg_optimal_action = test_bed.run_experiment()
            ax_reward.plot(avg_reward, label=label)
            ax_optimal_action.plot(avg_optimal_action, label=label)

        ax_reward.set_ylabel("Average reward")
        ax_reward.set_xlabel("Steps")

        ax_optimal_action.set_ylabel("% Optimal Action")
        ax_optimal_action.set_xlabel("Steps")

        ax_reward.legend()
        ax_optimal_action.legend()

        plt.show()

Experiment 1: Greedy Vs \(ϵ\)-Greedy

Let’s run three agents: greedy, \(\epsilon=0.1\)-greedy, and \(\epsilon=0.01\)-greedy agent on the stationary 10-arm bandit environment.

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "greedy": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0)
        ),
        "epsilon=0.1": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.1)
        ),
        "epsilon=0.01": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
    },
)

[Plots: average reward and % optimal action per step for the greedy, epsilon=0.1, and epsilon=0.01 agents]

In the plots above, we can clearly see that the epsilon-greedy agents outperform the pure greedy agent, because the greedy agent did no exploration and got stuck on a suboptimal action. Comparing the two epsilon-greedy agents, the agent with the higher epsilon explored more and performed better in the initial stage, but the agent with the lower epsilon value outperforms it in the long run. This clearly shows the challenge of finding the balance between exploration and exploitation.

Non-Stationary Multi-Armed Bandit Problem

Let’s shift the multi-armed bandit problem a bit towards the realistic full reinforcement learning paradigm. The bandit implemented earlier was stationary, as its reward distribution never changed during its lifetime. Let’s implement a bandit whose reward distribution changes with the steps the agent takes.

class NonStationaryMultiArmedBandit(StationaryMultiArmedBandit):
    def setup(self):
        """For the non-stationary bandit, we start with the same average reward for all actions. Let's say zero."""
        self.nprandom = np.random.RandomState(self.random_state)
        self.q_star = np.zeros((self.runs, self.k))

    def get_reward(self, action):
        """Before returning the reward for the action taken in the current step,
        shift the reward distribution by a drift sampled from a normal distribution with mean 0 and std 0.01."""
        self.q_star += self.nprandom.normal(loc=0.0, scale=0.01, size=(self.runs, self.k))
        return super(NonStationaryMultiArmedBandit, self).get_reward(action)
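To see what this drift does, here is a quick standalone simulation (my own illustration, using the same \(\mathcal{N}(0, 0.01)\) increments): each arm's true value performs an independent random walk, so after \(t\) steps the values have spread with standard deviation roughly \(0.01\sqrt{t}\).

```python
import numpy as np

# Each arm's true value q* performs an independent random walk with
# N(0, 0.01) increments, as in NonStationaryMultiArmedBandit above.
rng = np.random.RandomState(0)
q_star = np.zeros(10)
for t in range(10_000):
    q_star += rng.normal(loc=0.0, scale=0.01, size=10)

# After 10,000 steps the values have typically drifted on the order of
# 0.01 * sqrt(10_000) = 1, so the best arm keeps changing over time.
print(q_star)
```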

Experiment 2: Sample Average Action Value Estimation Method on Stationary and Non-Stationary Problem

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "stationary": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
        "nonstationary": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
    },
)

(figure: sample-average agent on stationary vs non-stationary bandits)

It is clearly seen that, for the non-stationary bandit problem, the sample average method falls significantly behind.

Epsilon Greedy with Constant Step Size

The sample average method gives equal weight to every reward, irrespective of the step in which it was obtained. In a nonstationary setting, however, it makes sense to give more weight to recent rewards than to long-past ones. One of the most popular ways of doing this is to use a constant step-size parameter.
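A quick numerical check of why a constant step size weights recent rewards more, using the standard expansion from the book, \(Q_{n+1} = (1-\alpha)^n Q_1 + \sum_{i=1}^{n} \alpha(1-\alpha)^{n-i} R_i\):

```python
import numpy as np

alpha, n = 0.1, 50
# Weight given to the i-th reward (i = 1 is the oldest, i = n the most recent):
weights = np.array([alpha * (1 - alpha) ** (n - i) for i in range(1, n + 1)])

# Recent rewards dominate: the latest reward gets weight alpha = 0.1,
# while the oldest gets only alpha * 0.9**49 ≈ 0.0006.
# Together with the (1-alpha)^n weight on the initial estimate Q_1,
# the weights form a convex combination.
total = (1 - alpha) ** n + weights.sum()
print(round(total, 10))  # 1.0
```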

class EpsilonGreedyAgent(EpsilonGreedySampleAverageAgent):
    """Epsilon greedy agent with constant step size."""
    def __init__(self, k, runs, alpha=0.1, epsilon=0.1, random_state=None):
        super(EpsilonGreedyAgent, self).__init__(k, runs, epsilon, random_state)
        self.alpha = alpha

    def get_step_size(self, action):
        """Instead of a step size based on the number of times the action has been chosen,
        return the constant step size `alpha` provided to the agent at instantiation."""
        return self.alpha

Experiment 3: Epsilon Greedy with Constant Step-Size in Non-Stationary Environment

Let’s run an experiment with combinations of stationary/non-stationary bandits and sample-average/constant step-size action value estimation methods.

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "stationarysampleaverage": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
        "stationaryconstantstepsize": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
        "nonstationarysampleaverage": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedySampleAverageAgent(k=k, runs=runs, epsilon=0.01)
        ),
        "nonstationaryconstantstepsize": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

(figure: sample-average vs constant step-size agents on stationary and non-stationary bandits)

From the figure, we can see that the constant step-size and sample average action value estimation methods show comparable performance in the stationary setting. In the non-stationary setting, however, the constant step-size method performs significantly better than the sample average method, even though both score lower than in the stationary setting.

Optimistic Initial Values

Optimistic initial values is a technique used in the multi-armed bandit problem to encourage exploration in the early stages of learning. Instead of starting with initial action values set to zero or a neutral value, this approach sets the initial action values to a high, optimistic value, sometimes even higher than the maximum possible reward.

The optimistic initial values encourage exploration because the agent is initially biased to believe that all actions have high rewards. As the agent selects actions and receives actual rewards, it updates its estimates, and the optimistic values gradually fall towards their true values. This process continues until the value estimates of suboptimal actions fall below the estimates of the optimal action, and the agent starts exploiting the optimal action more frequently.

While optimistic initial values can be effective in stationary problems, the approach has limitations. It is not well-suited for nonstationary problems, as its drive for exploration is temporary and may not adapt well to changes in the environment. Additionally, the technique relies heavily on the initial conditions, and finding the best initial values may require some trial and error or domain knowledge.
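To make the "optimistic values gradually fall towards their true values" behavior concrete, here is a small illustration of my own (not from the post): with a constant step size \(\alpha\), repeated updates shrink the optimistic bias geometrically, by a factor of \((1-\alpha)\) per pull of that arm.

```python
# Sketch: Q starts optimistically at 5 and, for clarity, the arm pays a
# deterministic reward of 0. Each update Q += alpha * (R - Q) leaves
# (1 - alpha) of the remaining bias, so after n pulls the bias is
# (1 - alpha)**n * 5.
alpha, init_q = 0.1, 5.0
q = init_q
for _ in range(50):
    q += alpha * (0.0 - q)
print(round(q, 4))  # (1 - 0.1)**50 * 5 ≈ 0.0258
```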

class OptimisticEpsilonGreedyAgent(EpsilonGreedyAgent):
    def __init__(self, k, runs, alpha=0.1, init_q=5, epsilon=0.1, random_state=None):
        """This behaves like the epsilon greedy agent,
        but it starts with high, optimistic action values."""
        super(OptimisticEpsilonGreedyAgent, self).__init__(k, runs, alpha, epsilon, random_state)
        self.init_q = init_q

        self.Q += self.init_q

Experiment 4: Optimistic Pure Greedy vs Epsilon Greedy

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000,  # To make spike clear
    exp_bandit_agent_dict={
        "optimisticgreedy": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0)
        ),
        "realisticepsilongreedy": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

(figure: optimistic greedy vs realistic \(\epsilon\)-greedy over the first 1000 steps)

We can see that even with \(\epsilon=0\) (pure greedy), the agent with optimistic initial values outperforms the \(\epsilon\)-greedy agent in the long run. Initially, it underperforms because the optimistic values force it to explore. Note the spike at about the 10th step: since all value estimates start higher than their true values, the agent may randomly hit the optimal action and then quickly abandon it in favor of actions it has not yet tried. This behavior produces the noticeable spike in performance around timestep 10, while the agent is still in the early stages of exploring the different actions.

Experiment 5: Optimistic Pure Greedy vs Epsilon Greedy on Non-Stationary Setting

The experiment above with optimistic initial action values was done in the stationary setting. Let’s run it in the non-stationary setting and see what happens.

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "optimisticgreedy": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0)
        ),
        "realisticepsilongreedy": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

(figure: optimistic greedy vs realistic \(\epsilon\)-greedy on the non-stationary bandit)

The experiment clearly shows the limitation of using optimistic initial values to force exploration. The trick is not well suited to nonstationary problems because its drive for exploration is inherently temporary, whereas a non-stationary task creates a need for exploration at every step.

Experiment 6: Effects of Initial Action Values

Let’s run an experiment comparing different initial action value estimates.

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=5000,
    exp_bandit_agent_dict={
        "optimistic_egreedy_initq=500": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=500, epsilon=0.01)
        ),
        "optimistic_egreedy_initq=5": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            OptimisticEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0.01)
        ),
    },
)

(figure: effect of initial action values, init_q=5 vs init_q=500)

It is clear from the figure that the initial action value we choose affects the beginning steps. As steps accumulate, the effect of the initial values lessens. I would like to quote the book here: “Indeed, any method that focuses on the initial conditions in any special way is unlikely to help with the general nonstationary case. The beginning of time occurs only once, and thus we should not focus on it too much.” So, in the long run, one shouldn’t worry about the choice of initial action values. But there are tricks to avoid their effect, and let’s explore one.

Unbiased Constant Step-Size

This trick, given in Exercise 2.7 of Sutton’s book, avoids the effect of initial action values. The trick is to use the step size \(\beta_{n} \doteq \alpha /\overline{\omicron }_{n}\), where \(\overline{\omicron }_{n} \doteq \overline{\omicron }_{n-1} +\alpha ( 1-\overline{\omicron }_{n-1})\) for \(n \geq 1\), with \(\overline{\omicron}_{0}\doteq0\).
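A quick sketch of this step-size sequence (my own numeric check of the formulas above): the first step size is exactly \(1\), which wipes out whatever initial estimate we chose, and \(\beta_n\) then decays toward \(\alpha\).

```python
# Unbiased constant step size from Exercise 2.7:
# beta_n = alpha / o_n, with o_n = o_{n-1} + alpha * (1 - o_{n-1}), o_0 = 0.
alpha = 0.1
o, betas = 0.0, []
for _ in range(5):
    o = o + alpha * (1 - o)
    betas.append(alpha / o)

# First step size is exactly 1; subsequent ones decay toward alpha.
print([round(b, 4) for b in betas])  # [1.0, 0.5263, 0.369, 0.2908, 0.2442]
```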

class UnbiasedEpsilonGreedyAgent(OptimisticEpsilonGreedyAgent):
    """Unbiased Constant Step-Size Agent.
    Inherited from `OptimisticEpsilonGreedyAgent` so that initial action values can be set."""
    def __init__(self, k, runs, alpha=0.1, init_q=5, epsilon=0.1, random_state=None):
        super(UnbiasedEpsilonGreedyAgent, self).__init__(k, runs, alpha, init_q, epsilon, random_state)
        self.step_trace = np.zeros((self.runs, self.k))

    def get_step_size(self, action):
        """Calculate the step size for the given action using trace."""
        self.step_trace[np.arange(self.runs), action] += self.alpha*(1 - self.step_trace[np.arange(self.runs), action])
        return self.alpha / self.step_trace[np.arange(self.runs), action]

Experiment 7: Unbiased Constant Step-Size with Different Initial Action Values

Let’s run the unbiased constant step-size agent with different initial values.

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10000,
    exp_bandit_agent_dict={
        "unbiased_initq=500": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            UnbiasedEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=500, epsilon=0.01)
        ),
        "unbiased_initq=5": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            UnbiasedEpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, init_q=5, epsilon=0.01)
        ),
    },
)

(figure: unbiased constant step size with init_q=5 vs init_q=500)

With this trick, the effect of the initial action values is gone. With the unbiased constant step size, the step-size parameter for the first update is exactly \(1\), which means the agent ignores the initial action value estimate and sets the estimate to the first reward it gets.

Upper-Confidence Bound Action Selection

In the sections above, we focused mostly on action value estimation methods: sample average, constant step size, and unbiased constant step size. Remember that an agent also has to decide which action to take at each step. The greedy method always exploits the action with the highest action value estimate so far, i.e., it does no exploration. The epsilon greedy method explores by selecting a random action with some probability at each step. But this random selection gives all actions equal preference, without regard to which non-greedy actions are more promising than others.

Upper Confidence Bound (UCB) action selection aims to balance exploration and exploitation using confidence bounds assigned to each action. The UCB algorithm is rooted in the principle of optimism in the face of uncertainty: the more uncertain we are about an action, the more important it becomes to explore it. The algorithm addresses the exploration-exploitation dilemma by selecting actions that maximize the sum of the estimated reward and an exploration term. By doing so, it ensures that the agent explores uncertain actions while still exploiting actions with high estimated rewards, allowing it to learn the optimal action over time. Let’s implement it as described in Section 2.7 of the book.
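In formula form (as given in the book), UCB selects

\[A_t \doteq \underset{a}{\operatorname{argmax}} \left[ Q_t(a) + c\,\sqrt{\frac{\ln t}{N_t(a)}} \right],\]

where \(N_t(a)\) is the number of times action \(a\) has been selected prior to time \(t\) and \(c > 0\) controls the degree of exploration. When \(N_t(a) = 0\), \(a\) is considered a maximizing action; adding a small constant to the denominator, as the implementation does, has roughly the same effect.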

class UCBActionAgent(EpsilonGreedyAgent):
    def __init__(self, k, runs, alpha=0.1, confidence=1, random_state=None):
        self.k = k
        self.runs = runs
        self.alpha = alpha
        self.confidence = confidence
        self.random_state = random_state

        self.setup()

    def get_action(self):
        current_step = self.N.sum(axis=1, keepdims=True) + 1
        return np.argmax(
            self.Q + self.confidence * np.sqrt(
                np.log(current_step)/(self.N + 1e-5)  # To avoid divide by zero error
            ),
            axis=1
        )

Experiment 8: UCB vs \(\epsilon\)-Greedy on Stationary Setting

Let’s run an experiment comparing epsilon greedy with UCB at various confidence levels.

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000,  # To make spike clear
    exp_bandit_agent_dict={
        "UCB C=1": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            UCBActionAgent(k=k, runs=runs, alpha=0.1, confidence=1)
        ),
        "UCB C=2": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            UCBActionAgent(k=k, runs=runs, alpha=0.1, confidence=2)
        ),
        "e-greedy, e=0.1": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.1)
        ),
    },
)

(figure: UCB with C=1 and C=2 vs \(\epsilon\)-greedy on the stationary bandit)

UCB with \(C=1\) performs better than epsilon greedy, but we don’t see an improvement with \(C=2\). \(C\) controls the degree of exploration: the higher the confidence level, the more the agent explores, and here the extra exploration does not pay off.

Experiment 9: UCB vs \(\epsilon\)-Greedy on Non-Stationary Setting

Let’s run the above experiment in non-stationary setting.

# e-greedy is better than UCB in nonstationary setting
MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=10_000,
    exp_bandit_agent_dict={
        "UCB C=1": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            UCBActionAgent(k=k, runs=runs, alpha=0.1, confidence=1)
        ),
        "e-greedy, e=0.1": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.1)
        ),
    },
)

(figure: UCB vs \(\epsilon\)-greedy on the non-stationary bandit)

This experiment shows the limitation of UCB in the non-stationary setting.

Gradient Bandit Algorithms

Instead of choosing actions indiscriminately or using uncertainty estimates to drive exploration, a more sophisticated approach is to learn a preference for each action. An agent can do so with gradient bandit algorithms. Unlike other bandit algorithms that maintain an action-value estimate for each action, gradient bandit algorithms maintain a preference value for each action and use a soft-max distribution to derive the probabilities of selecting each action.

The core idea of gradient bandit algorithms is to update the preferences based on the received rewards and a baseline reward value, which can be the average of all observed rewards so far. The update rule is designed to increase the preference for actions that yield higher rewards than the baseline and decrease the preference for actions with lower rewards.

In each iteration, the agent selects an action according to the soft-max distribution derived from the action preferences and updates the preferences based on the received reward and the baseline. This process continues until the agent converges towards the optimal action or a stopping criterion is met. Gradient bandit algorithms can adapt to changing environments and provide a good balance between exploration and exploitation.
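Concretely, the preference update (Eq. 2.12 in Sutton & Barto) is

\[H_{t+1}(a) \doteq H_t(a) + \alpha\,(R_t - \bar{R}_t)\,(\mathbb{1}_{a=A_t} - \pi_t(a)),\]

where \(\pi_t(a)\) is the soft-max probability of selecting action \(a\) and \(\bar{R}_t\) is the average reward serving as the baseline: the preference for the taken action rises when the reward beats the baseline and falls otherwise, while the other actions move in the opposite direction.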

class GradientAgent(EpsilonGreedyAgent):
    """Gradient Bandit Algorithm"""
    def __init__(self, k, runs, alpha=0.1, random_state=None):
        self.k = k
        self.runs = runs
        self.alpha = alpha
        self.random_state = random_state

        self.setup()

    def setup(self):
        """Set up the initial preference and average reward for the GradientAgent"""
        self.nprandom = np.random.RandomState(self.random_state)
        # initial preference is same for all actions
        self.H = np.zeros((self.runs, self.k))
        self.avg_R = np.zeros((self.runs,))

    def action_proba(self):
        """Calculate the probability of each action via a numerically stable soft-max."""
        # Subtracting the row-wise max before exponentiating avoids overflow
        # if preferences grow large; it does not change the probabilities.
        exp = np.exp(self.H - self.H.max(axis=1, keepdims=True))
        prob = exp/exp.sum(axis=1, keepdims=True)

        # Persist for update method
        self.current_action_prob = prob
        return prob

    def get_action(self):
        """Get the action to take based on the action probabilities."""
        prob = self.action_proba()
        # TODO find better way to vectorize the following.
        return np.apply_along_axis(
            lambda row: self.nprandom.choice(np.arange(self.k), p=row),
            arr=prob,
            axis=1,
        )

    def update(self, action, reward):
        """Update the preferences and average reward based on the given action and reward."""
        step_size = self.get_step_size(action)

        # Get already calculated action probability if available
        prob  = self.current_action_prob if hasattr(self, "current_action_prob") else self.action_proba()
        prob = -prob
        prob[np.arange(self.runs), action] = 1 + prob[np.arange(self.runs), action]
        deltaR = reward - self.avg_R

        self.H += step_size*prob*deltaR[:, np.newaxis]
        self.avg_R += step_size*deltaR

Experiment 10: Gradient Bandit on Stationary Setting

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000,
    exp_bandit_agent_dict={
        "gradient": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            GradientAgent(k=k, runs=runs, alpha=0.1)
        ),
        "e-greedy, e=0.01": (
            StationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

(figure: gradient bandit vs \(\epsilon\)-greedy on the stationary bandit)

We can see almost a 100% improvement in optimal action selection with the gradient bandit algorithm.

Experiment 11: Gradient Bandit on Non-Stationary Setting

MultiArmedBanditTestBed.run_and_plot_experiments(
    steps=1000,
    exp_bandit_agent_dict={
        "gradient": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            GradientAgent(k=k, runs=runs, alpha=0.1)
        ),
        "e-greedy, e=0.01": (
            NonStationaryMultiArmedBandit(k=k, runs=runs),
            EpsilonGreedyAgent(k=k, runs=runs, alpha=0.1, epsilon=0.01)
        ),
    },
)

(figure: gradient bandit vs \(\epsilon\)-greedy on the non-stationary bandit)

The gradient bandit algorithm also struggles in the non-stationary setting.

This concludes the blog on multi-armed bandit problems. We implemented and compared various action value estimation and action selection algorithms in stationary and non-stationary settings. All experiments showed the inherent challenge of the exploration vs. exploitation dilemma in reinforcement learning.

References:

  • Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning, second edition: An Introduction. MIT Press.
]]>
Critical Analysis: Why do tree-based models still outperform deep learning on typical tabular data?
2023-06-11 · https://spktsagar.com/posts/2023/06/annotated-critics-tree-better-than-deeplearning

In recent years, the machine learning community has witnessed significant advancements in deep learning models. However, a perplexing phenomenon encountered by industrial machine learning practitioners is that even the simplest tree-based models often outperform advanced deep learning models on real-world projects involving tabular data. In this blog post, I will delve into a critical analysis of a research paper that sheds light on this issue and provides valuable insights for practitioners working with tabular data.

The paper “Why do tree-based models still outperform deep learning on typical tabular data?” by Grinsztajn et al. [2022] establishes a new standard set of datasets with clear characteristics of tabular data and benchmarks various tree-based and deep learning models on these datasets. The results debunk common myths about the performance of neural networks on tabular data, highlighting the importance of understanding inductive biases and the impact of uninformative features and irregular functions on model performance.

By discussing the strengths and weaknesses of the paper, I aim to provide a comprehensive understanding of why tree-based models continue to outshine deep learning models on typical tabular data. This analysis will be particularly useful for industrial machine learning practitioners who are often puzzled by the seemingly inferior performance of deep learning models on tabular data. The key strengths and weaknesses of the paper, the approach taken by the authors, and the clarity of the writing are discussed below.

In the PDF attached below, I have included annotations for the critical analysis. These annotations were made while I was reading the paper.

Key Contributions and Their Significances:

  • The paper establishes a new standard set of datasets with clear characteristics of tabular data. The authors also provided precise processing methods used to create them. They claim that having such homogeneous datasets allows researchers to investigate inductive biases purely suited for tabular data.
  • The authors benchmarked standard tree-based models (RF, GBTs, XGBoost) popular among practitioners and SOTA deep learning models for tabular data (MLP, ResNet, FT Transformer, SAINT) on those datasets. While doing so, the authors took different hyperparameter optimization budgets into consideration. The variance this introduces is addressed intelligently by shuffling the random search order multiple times. The results debunked two myths: hyperparameter tuning does not make neural nets state-of-the-art, and categorical variables are not the main weakness of neural networks.
  • While investigating why tree-based models outperform deep learning models by transforming data to alter their performance gap, the authors shed light on their different inductive biases. These findings are of significant importance as they guide future research to make tabular-specific neural networks robust to uninformative features, deal with irregular functions, and be rotationally non-invariant in a computationally cheaper way.

Strengths:

  • Though not for all experiment settings, the author provides sound justification for some of their choices. For instance, Bayesian optimization was not chosen over random search as it doesn’t allow reshuffling of the search order, and their ablation study also shows it doesn’t provide a significant improvement over random search. Also, their choices of data preparation steps cohere with the goal of making homogeneous datasets.
  • The conclusions drawn are backed by complementary empirical evidence. For instance, while gauging the effects of uninformative features, authors draw the same conclusion by both adding and removing uninformative features or by training on informative and uninformative features separately. Similarly, their empirical conclusion about the link between rotational invariance and uninformative features is validated by the theoretical link provided by Ng [2004].
  • The paper is well-sectioned with pertinent information. The introduction covers all the important aspects of the whole paper. It is an easy-to-follow paper. The codebase is also available publicly.

Weaknesses:

  • The authors claim they provide new comprehensive datasets for a standard benchmark. However, the criteria used while creating these datasets ignore many features of real-world datasets, questioning their usability in standard benchmarking.
  • Kadra et al. [2021a] use a “cocktail” of regularizations on MLPs and get competitive with XGBoost on a similar random search budget. Rather than speculating that the performance was due to the presence of “deterministic” datasets, the authors could have proven it empirically by measuring the performance of regularized MLPs on their newly created datasets.
  • Although I anticipate conclusions drawn in the paper hold for small and large datasets, experiments only with medium-sized datasets leave a place for doubt.
  • I think techniques to remove side issues contradict some criteria mentioned in 3.1, such as “Not too easy,” and “real-world data.”
  • No clear explanation is given why multi-class tasks are binarised, why only the top 5 features based on RF importance ranking were taken to study the impact of irregular functions, why the search order was shuffled within a single random search run instead of considering a new one, and why “ReduceOnPlateau” LR Scheduler was chosen for MLP.

Summary:

Overall, the significance of the contributions and strengths of the paper beats its weaknesses. That is why I think it was accepted in NeurIPS 2022. Also, this paper provides explanations for practitioners perplexed by the inferior performance of deep learning models on tabular data.

Annotated Paper:

References:

  • Léo Grinsztajn, Edouard Oyallon, and Gaël Varoquaux. “Why do tree-based models still outperform deep learning on tabular data?” (July 2022). arXiv: 2207.08815 [cs.LG]
  • Andrew Y. Ng. Feature selection, L 1 vs. L 2 regularization, and rotational invariance. In Twenty-First International Conference on Machine Learning - ICML ’04, page 78, Banff, Alberta, Canada, 2004. ACM Press. doi: 10.1145/1015330.1015435.
  • Arlind Kadra, Marius Lindauer, Frank Hutter, and Josif Grabocka. Well-tuned Simple Nets Excel on Tabular Datasets, November 2021a.
]]>
Finetuning XLS-R (Wav2Vec2) on OpenSLR Nepali ASR Dataset
2022-08-23 · https://spktsagar.com/posts/2022/08/finetune-XLS-R-Nepali-ASR

This blog/tutorial on finetuning XLS-R on OpenSLR’s Nepali ASR dataset is adapted from Hugging Face’s blog on “Fine-tuning XLS-R for Multi-Lingual ASR with 🤗 Transformers.” The existing XLS-R for Nepali was actually finetuned on OpenSLR’s Nepali Text to Speech dataset, which contains voices from only one speaker, recorded at high quality. It is therefore doubtful that this model would work on utterances made in the wild, as we would normally produce them. To avoid this problem, speech samples recorded in real life, with natural characteristics like noise and pauses, should be used, and the Large Nepali ASR training data set is the one such dataset we have for the Nepali language. However, the unavailability of computational resources is keeping me from completing the finetuning and hyperparameter optimization of this remarkable model on my mother tongue; so far I get a 21% word error rate on the test split I created. Anyone who would like to contribute to finishing this training is heartily welcome. Let’s begin the implementation journey, though.

Introduction to Wav2Vec2

Wav2Vec2 is a pretrained model for Automatic Speech Recognition (ASR) and was released in September 2020 by Alexei Baevski, Michael Auli, and Alex Conneau. Soon after the superior performance of Wav2Vec2 was demonstrated on one of the most popular English datasets for ASR, called LibriSpeech, Facebook AI presented a multi-lingual version of Wav2Vec2, called XLSR. XLSR stands for cross-lingual speech representations and refers to the model’s ability to learn speech representations that are useful across multiple languages.

XLSR’s successor, simply called XLS-R (referring to “XLM-R for Speech”), was released in November 2021 by Arun Babu, Changhan Wang, Andros Tjandra, et al. XLS-R used almost half a million hours of audio data in 128 languages for self-supervised pre-training and comes in sizes ranging from 300 million up to two billion parameters. You can find the pretrained checkpoints on the 🤗 Hub:

Similar to BERT’s masked language modeling objective, XLS-R learns contextualized speech representations by randomly masking feature vectors before passing them to a transformer network during self-supervised pre-training (i.e. diagram on the left below).

For fine-tuning, a single linear layer is added on top of the pre-trained network to train the model on labeled data of audio downstream tasks such as speech recognition, speech translation and audio classification (i.e. diagram on the right below).

(figure: Wav2Vec2/XLS-R architecture for self-supervised pre-training, left, and fine-tuning, right)

XLS-R shows impressive improvements over previous state-of-the-art results on speech recognition, speech translation, and speaker/language identification; cf. Tables 3-6, Tables 7-10, and Tables 11-12, respectively, of the official paper.

Notebook Setup

In this notebook, we will give an in-detail explanation of how XLS-R - more specifically the pre-trained checkpoint Wav2Vec2-XLS-R-300M - can be fine-tuned for ASR.

XLS-R is fine-tuned using Connectionist Temporal Classification (CTC), which is an algorithm that is used to train neural networks for sequence-to-sequence problems, such as ASR and handwriting recognition.

I highly recommend reading the well-written blog post Sequence Modeling with CTC (2017) by Awni Hannun.

First, let’s try to get a good GPU in our colab! With Google Colab’s free version it’s sadly becoming much harder to get access to a good GPU. With Google Colab Pro, however, one should easily get either a V100 or P100 GPU.

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)
Sun Oct 23 02:39:42 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Before we start, let’s install datasets, transformers, and pytorch mutually compatible with each other. Also, we need the torchaudio to load audio files and jiwer to evaluate our fine-tuned model using the word error rate (WER) metric \({}^1\).

!pip --no-cache-dir install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
!pip --no-cache-dir install transformers==4.23.1
!pip --no-cache-dir install datasets==2.6.0
!pip --no-cache-dir install evaluate==0.3.0
!pip --no-cache-dir install jiwer

If you are training in environments like Colab where storage space is limited or environments are temporary, it is strongly recommended to save your model checkpoints somewhere else. Since we’re working with Hugging Face’s Transformers, why not the Hugging Face Hub? The 🤗 Hub has integrated version control, so you can be sure that no model checkpoint is lost during training.

To do so, you have to store your authentication token from the Hugging Face website (sign up here if you haven’t already!).

from huggingface_hub import notebook_login

notebook_login()
Login successful
Your token has been saved to /root/.huggingface/token

The Hugging Face Hub requires git-lfs, as it is needed to upload large model files.

!apt install git-lfs

\({}^1\) In the paper, the model was evaluated using the phoneme error rate (PER), but by far the most common metric in ASR is the word error rate (WER). To keep this notebook as general as possible we decided to evaluate the model using WER.
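To make the metric concrete, here is a minimal pure-Python sketch of what WER measures: the word-level edit distance (substitutions + insertions + deletions) divided by the number of reference words. In practice jiwer computes this for us; the function below is only illustrative.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = minimum edits to turn the first i reference words
    # into the first j hypothesis words
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i  # i deletions
    for j in range(len(hyp) + 1):
        dp[0][j] = j  # j insertions
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            substitution = dp[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
            dp[i][j] = min(substitution, dp[i - 1][j] + 1, dp[i][j - 1] + 1)
    return dp[len(ref)][len(hyp)] / len(ref)

# one substitution ("sat" -> "sit") and one deletion ("the") over 6 reference words
print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))
```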

Prepare Data, Tokenizer, Feature Extractor

ASR models transcribe speech to text, which means that we both need a feature extractor that processes the speech signal to the model’s input format, e.g. a feature vector, and a tokenizer that processes the model’s output format to text.

In 🤗 Transformers, the XLS-R model is thus accompanied by both a tokenizer, called Wav2Vec2CTCTokenizer, and a feature extractor, called Wav2Vec2FeatureExtractor.

Let’s start by creating the tokenizer to decode the predicted output classes to the output transcription.

Create Wav2Vec2CTCTokenizer

A pre-trained XLS-R model maps the speech signal to a sequence of context representations as illustrated in the figure above. However, for speech recognition the model has to map this sequence of context representations to its corresponding transcription, which means that a linear layer has to be added on top of the transformer block (shown in yellow in the diagram above). This linear layer is used to classify each context representation into a token class, analogous to how, after pretraining, a linear layer is added on top of BERT’s embeddings for further classification - cf. the ‘BERT’ section of this blog post.

The output size of this layer corresponds to the number of tokens in the vocabulary, which does not depend on XLS-R’s pretraining task but only on the labeled dataset used for fine-tuning. So in the first step, we will take a look at the chosen OpenSLR Nepali ASR dataset and define a vocabulary based on its transcriptions.

Dataset Preparation

First, let’s go to the official OpenSLR website for Nepali ASR. This data set contains transcribed audio data for Nepali. It consists of zips containing flac (an audio file format) files and a TSV file. The file utt_spk_text.tsv contains a FileID, an anonymized UserID, and the transcription of the audio in the file. The data set has been manually quality checked, but there might still be errors.

Since downloading, extracting, and preprocessing take a lot of work and time, I’ve uploaded the preprocessed dataset along with the original dataset to the Hugging Face Hub so that we can interact with it through Hugging Face’s Datasets API. In summary, these are the steps I took to prepare the dataset:

  1. Download the dataset from https://www.openslr.org/54/
  2. Extract the zip files containing flac audio files
  3. Load the audio files and apply the following preprocessing function to each of them
     import torchaudio
     # The pretrained Wav2Vec2 model was trained on speeches with sample rate 16KHz
     SAMPLING_RATE = 16000
     def process_audio_file(orig_path, new_path):
         """Read and process file in `orig_path` and save it to `new_path`"""
         waveform, sampling_rate = torchaudio.load(orig_path)
         if sampling_rate != SAMPLING_RATE:
             # Resample to 16KHz if the audio originally has different sampling rate.
             waveform = torchaudio.functional.resample(waveform, sampling_rate, SAMPLING_RATE)
         #  Though the ASR models should be resilient to silences at the ends of audio,
         # the leading and trailing silences are removed using Voice Activity Detection(VAD)
         # implemented in torchaudio with default parameters to reduce the demands 
         # for computational resources
         waveform = torchaudio.functional.vad(waveform, sample_rate=SAMPLING_RATE)
         # save the processed audio files to new location
         torchaudio.save(new_path, waveform, sample_rate=SAMPLING_RATE)
    
  4. The processed audio files are again zipped in a similar fashion as in the original OpenSLR Nepali ASR dataset.
  5. The zip files and TSV file containing transcript and audio path mappings are uploaded to “spktsagar/openslr-nepali-asr-cleaned/data”.
  6. The dataset loading script, which can be found here, was developed and pushed to the same repo.

Now, we load the dataset with the datasets API. Since the dataset is not split into train/val/test, the whole dataset will be downloaded and split into train and validation sets later. When I fine-tuned the Wav2Vec2 model on the preprocessed, cleaned dataset, the model was not learning (WER was always 1 on the validation/test set). Anyone who would like to debug the preprocessed dataset and fine-tuning on it is heartily welcome.

from datasets import load_dataset

DATASET_TYPE = 'original'  # change to `original` or `cleaned` for downloading original or cleaned version of openslr dataset

dataset = load_dataset("spktsagar/openslr-nepali-asr-cleaned", name=DATASET_TYPE, split='train')
dataset
Dataset({
    features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
    num_rows: 157905
})

You can see in the info that the dataset contains the following fields:

  • utterance_id
  • speaker_id
  • utterance
  • transcription
  • num_frames

For the description of them, please read the dataset card here.

Text Preprocessing

Although the transcriptions are fairly clean, some of them contain characters other than Nepali. We will remove those examples from our dataset.

import string

def check_english_chars(text):
    """Return True if the text contains any English (ASCII) letters."""
    return any(c in string.ascii_letters for c in text)

# Use dataset.filter to drop examples flagged by the above function
dataset = dataset.filter(
    lambda ex: not check_english_chars(ex),
    input_columns=['transcription'],
)
dataset
Dataset({
    features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
    num_rows: 157904
})

Let’s see the list of all the characters we have now in the dataset.

''.join(sorted(set([c for s in dataset['transcription'] for c in s])))
' !%.;?\\\xa0ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ािीुूृेैॉॊोौ्ॐ॑ॠ।०१२३४५६७८९॰\u200c\u200d\u200e\u200f“'

You can see there are some characters and symbols that we don’t use in Nepali. We will remove those from the transcription.

remove_chars = ['!', '%', '.', ';', '?', '\\', '।', '\xa0', '\u200c', '\u200d', '\u200e', '\u200f', '“']

def remove_special_characters(row):
    row['transcription'] = ''.join(
        [c for c in row['transcription'] if c not in remove_chars]
    ).strip()
    return row

dataset = dataset.map(remove_special_characters)
''.join(sorted(set([c for s in dataset['transcription'] for c in s])))
' ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ािीुूृेैॉॊोौ्ॐ॑ॠ०१२३४५६७८९॰'

In CTC, it is common to classify speech chunks into letters, so we will do the same here. Let’s extract all distinct letters of the training and test data and build our vocabulary from this set of letters.

We write a mapping function that concatenates all transcriptions into one long transcription and then transforms the string into a set of chars. It is important to pass the argument batched=True to the map(...) function so that the mapping function has access to all transcriptions at once.

def extract_all_chars(batch):
    all_text = " ".join(batch["transcription"])
    vocab = list(set(all_text))
    return {"vocab": [vocab]}

vocab_all = dataset.map(extract_all_chars, batched=True,
                        batch_size=-1, keep_in_memory=True,
                        remove_columns=dataset.column_names)
vocab_list = sorted(list(set(vocab_all["vocab"][0])))

Finally, we also add a padding token that corresponds to CTC’s “blank token”. The “blank token” is a core component of the CTC algorithm. For more information, please take a look at the “Alignment” section here.

UNK_TOKEN = '__UNK__'
PAD_TOKEN = '__PAD__'

vocab_list = [PAD_TOKEN, UNK_TOKEN, *vocab_list]

Now, we create an enumerated dictionary so that we have token to id mapping.

vocab_dict = {v: k for k, v in enumerate(vocab_list)}

# for printing vocab in single line
', '.join([f"{k}: {v}" for k, v in (vocab_dict.items())])
'__PAD__: 0, __UNK__: 1,  : 2, ँ: 3, ं: 4, ः: 5, अ: 6, आ: 7, इ: 8, ई: 9, उ: 10, ऊ: 11, ऋ: 12, ए: 13, ऐ: 14, ओ: 15, औ: 16, क: 17, ख: 18, ग: 19, घ: 20, ङ: 21, च: 22, छ: 23, ज: 24, झ: 25, ञ: 26, ट: 27, ठ: 28, ड: 29, ढ: 30, ण: 31, त: 32, थ: 33, द: 34, ध: 35, न: 36, प: 37, फ: 38, ब: 39, भ: 40, म: 41, य: 42, र: 43, ऱ: 44, ल: 45, व: 46, श: 47, ष: 48, स: 49, ह: 50, ़: 51, ा: 52, ि: 53, ी: 54, ु: 55, ू: 56, ृ: 57, े: 58, ै: 59, ॉ: 60, ॊ: 61, ो: 62, ौ: 63, ्: 64, ॐ: 65, ॑: 66, ॠ: 67, ०: 68, १: 69, २: 70, ३: 71, ४: 72, ५: 73, ६: 74, ७: 75, ८: 76, ९: 77, ॰: 78'

To make it clearer that " " has its own token class, we give it a more visible character |. In addition, we have also added an “unknown” token so that the model can later deal with characters not encountered in the training set.

WORD_DELIMITER = '|'

vocab_dict[WORD_DELIMITER] = vocab_dict[" "]
del vocab_dict[" "]
len(vocab_dict)
79

Cool, now our vocabulary is complete and consists of 79 tokens, which means that the linear layer that we will add on top of the pretrained XLS-R checkpoint will have an output dimension of 79.

Let’s now save the vocabulary as a json file.

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In a final step, we use the json file to load the vocabulary into an instance of the Wav2Vec2CTCTokenizer class.

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token=UNK_TOKEN, pad_token=PAD_TOKEN, word_delimiter_token=WORD_DELIMITER)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Create Wav2Vec2FeatureExtractor

Speech is a continuous signal and to be treated by computers, it first has to be discretized, which is usually called sampling. The sampling rate hereby plays an important role in that it defines how many data points of the speech signal are measured per second. Therefore, sampling with a higher sampling rate results in a better approximation of the real speech signal but also necessitates more values per second.

A pretrained checkpoint expects its input data to have been sampled more or less from the same distribution as the data it was trained on. The same speech signal sampled at two different rates has a very different representation, e.g., doubling the sampling rate results in twice as many data points. Thus, before fine-tuning a pretrained checkpoint of an ASR model, it is crucial to verify that the sampling rate of the data that was used to pretrain the model matches the sampling rate of the dataset used to fine-tune the model.

XLS-R was pretrained on audio data of Babel, Multilingual LibriSpeech (MLS), Common Voice, VoxPopuli, and VoxLingua107 at a sampling rate of 16kHz. As stated earlier, the OpenSLR Nepali ASR dataset already has a sampling rate of 16kHz.

# Define a global variable to store our sampling rate
SPEECH_SAMPLING_RATE = 16000

Long input sequences require a lot of memory. Since XLS-R is based on self-attention, the memory requirement scales quadratically with the input length (cf. this reddit post). In case this demo crashes with an “Out-of-memory” error for you, you might want to use the following code to filter out all sequences longer than 5 seconds for training.

MAX_FRAMES = SPEECH_SAMPLING_RATE*5  # 5 sec

dataset = dataset.filter(
    lambda ex: ex < MAX_FRAMES,
    input_columns=['num_frames'],
)

dataset
Dataset({
    features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
    num_rows: 143974
})

This seems to have worked! Let’s listen to a couple of audio files to better understand the dataset and verify that the audio was correctly loaded.

Note: You can click the following cell a couple of times to listen to different speech samples.

import random
import IPython.display as ipd

sample_idx = random.randrange(len(dataset))  # randrange excludes len(dataset), avoiding an index error

print(dataset[sample_idx]['transcription'])
ipd.Audio(dataset[sample_idx]['utterance']["array"], autoplay=True, rate=SPEECH_SAMPLING_RATE)
सुवेदी पार्टीको जिल्ला

A Wav2Vec2FeatureExtractor object requires the following parameters to be instantiated:

  • feature_size: Speech models take a sequence of feature vectors as an input. While the length of this sequence obviously varies, the feature size should not. In the case of Wav2Vec2, the feature size is 1 because the model was trained on the raw speech signal \({}^2\).
  • sampling_rate: The sampling rate at which the model is trained on.
  • padding_value: For batched inference, shorter inputs need to be padded with a specific value.
  • do_normalize: Whether the input should be zero-mean-unit-variance normalized. Usually, speech models perform better when the input is normalized.
  • return_attention_mask: Whether the model should make use of an attention_mask for batched inference. In general, XLS-R model checkpoints should always use the attention_mask.

from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=SPEECH_SAMPLING_RATE,
                                             padding_value=0.0, do_normalize=True,
                                             return_attention_mask=True)

Great, XLS-R’s feature extraction pipeline is thereby fully defined!
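To illustrate what the do_normalize parameter does, here is a small pure-Python sketch of zero-mean, unit-variance normalization on a hypothetical waveform; the feature extractor applies the equivalent per-utterance operation (with a small epsilon for numerical stability).

```python
waveform = [0.1, 0.3, -0.2, 0.0, 0.05]  # hypothetical raw audio samples

# normalize the utterance to zero mean and (approximately) unit variance
mean = sum(waveform) / len(waveform)
var = sum((x - mean) ** 2 for x in waveform) / len(waveform)
normalized = [(x - mean) / (var + 1e-7) ** 0.5 for x in waveform]

# the normalized signal now has mean ~0 and variance ~1
print(sum(normalized) / len(normalized))
```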

For improved user-friendliness, the feature extractor and tokenizer are wrapped into a single Wav2Vec2Processor class so that one only needs a model and processor object.

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer
)

If one wants to re-use the just created processor including tokenizer and feature extractor with the fine-tuned model of this notebook, it is strongly advised to upload the processor to the 🤗 Hub. Let’s call the repo to which we will upload the files "wav2vec2-large-xls-r-300m-nepali-openslr":

repo_name = "wav2vec2-large-xls-r-300m-nepali-openslr"

and upload the tokenizer to the 🤗 Hub.

processor.push_to_hub(repo_name)

Great, you can see the just created repository under https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-nepali-openslr

Preprocess Data

So far, we have not looked at the actual values of the speech signal, just the transcription. In addition to transcription, our dataset includes the columns utterance_id, speaker_id, utterance, and num_frames. The utterance field has three subfields: path, array, and sampling_rate. path states the absolute path of the audio file, and array is the NumPy array of the same audio. Let’s take a look.

dataset[45]
{'utterance_id': 'a176fcb0d8',
 'speaker_id': '6a6d1',
 'utterance': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/6bc81d0b078b9cc240efb7e2885d7c845ea238b71125a16d73b19c06621b39f2/asr_nepali/data/a1/a176fcb0d8.flac',
  'array': array([ 0.00030518,  0.00039673,  0.00039673, ..., -0.00057983,
         -0.00036621, -0.00027466], dtype=float32),
  'sampling_rate': 16000},
 'transcription': '० पनि एक',
 'num_frames': 36800}

We will wrap Hugging Face’s Dataset in a PyTorch dataset so that audio files are loaded and processed lazily, as we are restricted by storage availability and memory size.

import torch

class NepaliASRProcessedDataset(torch.utils.data.Dataset):
    """Takes HF dataset and processor, and process the audio files
    and transcription with the processor only when items are requested
    """
    def __init__(
        self,
        dataset,
        processor,
    ):
        self.dataset = dataset
        self.processor = processor
    
    def __len__(self):
        """Length of dataset"""
        return len(self.dataset)
    
    def __getitem__(self, idx):
        """Return processed data at `idx` index."""
        example = self.dataset[idx]
        
        # Return dict
        return_dict = {}

        # first, process the audio with Wav2Vec2 feature extractor
        return_dict['input_values'] = self.processor(
            audio=example['utterance']['array'],
            sampling_rate=example['utterance']['sampling_rate'],
            return_attention_mask=False,  # will be calculated during batching
        )['input_values'][0]
        # add the length of extracted features of audio
        return_dict['input_length'] = len(return_dict['input_values'])

        # second, process the transcription with Wav2Vec2 tokenizer
        return_dict['labels'] = self.processor(
            text=example['transcription'],
            return_attention_mask=False,  # will be calculated during batching
        )['input_ids']
        return return_dict
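The key idea of the wrapper above is that nothing is processed until an item is requested. A torch-free toy version of the same pattern, with a stub processing function standing in for the real Wav2Vec2Processor, makes the lazy behavior explicit:

```python
class LazyProcessedDataset:
    """Toy stand-in for NepaliASRProcessedDataset: rows are processed on access."""

    def __init__(self, rows, process_fn):
        self.rows = rows              # raw, unprocessed examples
        self.process_fn = process_fn  # applied lazily in __getitem__

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        # processing happens here, one example at a time
        return self.process_fn(self.rows[idx])

# stub "processor": split the transcription into characters as labels
ds = LazyProcessedDataset(
    [{'transcription': 'नमस्ते'}, {'transcription': 'abc'}],
    lambda row: {'labels': list(row['transcription'])},
)
print(len(ds), ds[1]['labels'])
```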

Train/Test Split

Since our dataset has no separate split for training and evaluation, we will create one manually, splitting the dataset into an 85% train set and a 15% test set.

test_size = 0.15
dataset = dataset.sort('utterance_id')
split_dict = dataset.train_test_split(test_size=test_size, seed=42)
train_dataset, test_dataset = split_dict['train'], split_dict['test']
train_dataset, test_dataset
(Dataset({
     features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
     num_rows: 122377
 }), Dataset({
     features: ['utterance_id', 'speaker_id', 'utterance', 'transcription', 'num_frames'],
     num_rows: 21597
 }))

Convert the Hugging Face train/test datasets to PyTorch datasets:

train_dataset = NepaliASRProcessedDataset(train_dataset, processor)
test_dataset = NepaliASRProcessedDataset(test_dataset, processor)

Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗’s Trainer for which we essentially need to do the following:

  • Define a data collator. In contrast to most NLP models, XLS-R has a much larger input length than output length. E.g., a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically, meaning that all training samples are only padded to the longest sample in their batch and not to the overall longest sample. Therefore, fine-tuning XLS-R requires a special padding data collator, which we will define below.

  • Evaluation metric. During training, the model should be evaluated on the word error rate. We define a compute_metrics function accordingly.

  • Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

  • Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.

Set-up Trainer

Let’s start by defining the data collator. The code for the data collator was copied from this example.

Without going into too many details, in contrast to the common data collators, this data collator treats the input_values and labels differently and thus applies separate padding functions to them. This is necessary because, in speech, input and output are of different modalities, meaning that they should not be treated by the same padding function. Analogous to the common data collators, the padding tokens in the labels are replaced with -100 so that those tokens are not taken into account when computing the loss.

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


LARGE_NEG = -100

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths,
        # and need different padding methods
        batch = {}
        input_features = [{"input_values": feature["input_values"]} for feature in features if 'input_values' in feature]
        label_features = [{"input_ids": feature["labels"]} for feature in features if 'labels' in feature]

        if input_features:
            batch.update(self.processor.pad(
                input_features,
                padding=self.padding,
                return_tensors="pt",
            ))
        if label_features:
            labels_batch = self.processor.tokenizer.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

            # replace padding with large negative number to ignore loss correctly
            labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), LARGE_NEG)

            batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

Next, the evaluation metric is defined. As mentioned earlier, the predominant metric in ASR is the word error rate (WER), hence we will use it in this notebook as well.

import evaluate
import numpy as np

wer_metric = evaluate.load("wer")

The model will return a sequence of logit vectors: \(\mathbf{y}_1, \ldots, \mathbf{y}_m\) with \(\mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0]\) and \(n >> m\).

A logit vector \(\mathbf{y}_i\) contains the log-odds for each token in the vocabulary we defined earlier, thus \(\text{len}(\mathbf{y}_i) =\) config.vocab_size. We are interested in the most likely prediction of the model and thus take the argmax(...) of the logits. Also, we transform the encoded labels back to the original string by replacing LARGE_NEG with the pad_token_id and decoding the ids while making sure that consecutive tokens are not grouped into the same token in CTC style \({}^1\).

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == LARGE_NEG] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Now, we can load the pretrained checkpoint of Wav2Vec2-XLS-R-300M. The tokenizer’s pad_token_id must be used as the model’s pad_token_id, which in the case of Wav2Vec2ForCTC is also CTC’s blank token \({}^2\). To save GPU memory, we enable PyTorch’s gradient checkpointing and also set the loss reduction to “mean”.

Because the dataset is quite large (~100h of data) and ASR data is quite noisy, fine-tuning Facebook’s wav2vec2-xls-r-300m checkpoint seems to require some hyper-parameter tuning. Therefore, one has to play around a bit with different values for dropout, SpecAugment’s masking dropout rate, layer dropout, and the learning rate until training is stable enough.

Note: Since I was not able to run hyperparameter optimization on Colab, I’m not sure if the current set of hyperparameters is the best one. Feel free to adapt these parameters and let me know. I’ve used the default ones from Wav2Vec2 models.

from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.075,
    layerdrop=0.1,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

The first component of XLS-R consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the paper does not need to be fine-tuned anymore. Thus, we can set the requires_grad to False for all parameters of the feature extraction part.

model.freeze_feature_encoder()

In a final step, we define all parameters related to training. To give more explanation on some of the parameters:

  • group_by_length makes training more efficient by grouping training samples of similar input length into one batch. This can significantly speed up training time by heavily reducing the overall number of useless padding tokens that are passed through the model
  • learning_rate and weight_decay were heuristically tuned until fine-tuning has become stable. Note that those parameters strongly depend on the dataset used and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the docs.

During training, a checkpoint will be uploaded asynchronously to the hub every 800 training steps (the save_steps value below). This allows you to play around with the demo widget even while your model is still training.

Note: If one does not want to upload the model checkpoints to the hub, simply set push_to_hub=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=10,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=800,
  eval_steps=800,
  logging_steps=800,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
  hub_strategy='checkpoint',
  resume_from_checkpoint='last-checkpoint',
)

Now, all instances can be passed to Trainer and we are ready to start training!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor.feature_extractor,
)

\({}^1\) To allow models to become independent of the speaker’s rate, in CTC, consecutive tokens that are identical are simply grouped as a single token. However, the encoded labels should not be grouped when decoding since they don’t correspond to the predicted tokens of the model, which is why the group_tokens=False parameter has to be passed. If we didn’t pass this parameter, a word like "hello" would incorrectly be decoded as "helo".

\({}^2\) The blank token allows the model to predict a word, such as "hello" by forcing it to insert the blank token between the two l’s. A CTC-conform prediction of "hello" of our model would be [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD].
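The two footnotes can be sketched in a few lines of code: greedy CTC decoding first collapses consecutive repeats and then drops the blank token, so the blank between the two l’s keeps them distinct. This is illustrative only (the real decoding is done by the tokenizer’s batch_decode), and the id-to-character mapping below is hypothetical.

```python
def ctc_greedy_collapse(ids, blank_id=0):
    """Collapse consecutive repeats, then drop blanks (greedy CTC decoding)."""
    out, prev = [], None
    for i in ids:
        if i != prev and i != blank_id:
            out.append(i)
        prev = i
    return out

# hypothetical mapping: 0 = [PAD]/blank, 1 = h, 2 = e, 3 = l, 4 = o
id2char = {1: 'h', 2: 'e', 3: 'l', 4: 'o'}
# the CTC-conform frame-level prediction from the footnote above:
frames = [0, 0, 1, 2, 2, 3, 3, 0, 3, 4, 4, 0]
print(''.join(id2char[i] for i in ctc_greedy_collapse(frames)))  # hello
```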

Training

Training will take multiple hours depending on the GPU allocated to this notebook.

In case you want to use this google colab to fine-tune your model, you should make sure that your training doesn’t stop due to inactivity. A simple hack to prevent this is to paste the following code into the console of this tab (right mouse click -> inspect -> Console tab and insert code).

function ConnectButton(){
    console.log("Connect pushed"); 
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click() 
}
setInterval(ConnectButton,60000);

Depending on the GPU allocated to your Google Colab, you might see an "out-of-memory" error here. In this case, it’s probably best to reduce per_device_train_batch_size to 8 or even less and increase gradient_accumulation_steps.

trainer.train(
    resume_from_checkpoint=True,  # Set to false if you want to start from the beginning
)

If the training loss and validation WER go down nicely, you can upload the result of the training to the 🤗 Hub by executing:

trainer.push_to_hub()

You can now share this model with all your friends, family, favorite pets: they can all load it with the identifier “your-username/the-name-you-picked” so for instance:

from transformers import AutoModelForCTC, Wav2Vec2Processor

model = AutoModelForCTC.from_pretrained("spktsagar/wav2vec2-large-xls-r-300m-nepali-openslr")
processor = Wav2Vec2Processor.from_pretrained("spktsagar/wav2vec2-large-xls-r-300m-nepali-openslr")

For more examples of how XLS-R can be fine-tuned, please take a look at the official speech recognition examples.

Evaluation

As a final check, let’s load the model and verify that it indeed has learned to transcribe Nepali speech.

Let’s first load the fine-tuned checkpoint.

model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)

Now, we will take a few random examples from the test set, run them through the model, and take the argmax(...) of the logits to retrieve the predicted token ids. Those token ids will be decoded to retrieve the transcriptions.

# take only 5 random examples from the test set
pred = trainer.predict(
    torch.utils.data.Subset(
        test_dataset,
        random.sample(list(range(len(test_dataset))), 5)
    )
)
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)

pred.label_ids[pred.label_ids == LARGE_NEG] = processor.tokenizer.pad_token_id

pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

Now, let’s see first few reference transcriptions and predicted transcriptions.

list(zip(label_str, pred_str))[:5]
[('उपस्थित गराए', 'उपस्थित गराए'),
 ('प्रतिशतले वृद्धि', 'प्रतिशदले वृद्धि'),
 ('उनीहरू जहाँ जुन', 'उनीहरू जाँ जुन'),
 ('टेम्प्लेटहरू पनि', 'टेम्प्लेटहरू पनि'),
 ('रूपमा खाइन्छ', 'रूपमा खाइन्छ')]

Alright! The transcription can definitely be recognized from our prediction, but it is not perfect yet. Training the model a bit longer, spending more time on the data preprocessing, and especially using a language model for decoding would certainly improve the model’s overall performance.

You can play with and use the model I trained from here. Thank you!!

]]>
Sagar Sapkota[email protected]https://spktsagar.com
Implementation of MetaGAN: An Adversarial Approach to Few-Shot Learning2020-03-12T00:00:00-07:002020-03-12T00:00:00-07:00https://spktsagar.com/posts/2020/03/impl-metaganBackground

Humans can recognize objects from a few examples. Having seen many animal images before, given very few images of a novel animal, we can recognize it easily. Deep learning models, by contrast, must be trained from scratch to learn a new task. Transfer learning and fine-tuning are techniques for adapting trained models to a new task. The problem is that such models are trained on only a single task; adapting them to a completely new task requires manually verifying the similarity between the tasks. One recent approach to this is meta-learning. The purpose of meta-learning schemes is to share information between models being trained on similar tasks, using adaptation strategies to extract patterns that are useful for more than one task.

Learning from a small number of samples presents another difficulty for machine learning. The few-shot and zero-shot learning frameworks teach models to generalize to new datasets using relatively few samples. A K-shot classification problem, for instance, requires the model to generalize using just K examples per class; in the extreme case of zero-shot learning, the model generalizes using no examples at all.

When few-shot learning is combined with meta-learning, the model must adapt to new tasks using only a few instances and a few training iterations. For example, with a model trained on handwritten character recognition tasks across various languages, given only a few handwritten examples per character in a completely new language and very few training iterations, the model needs to generalize to that new language. To do this, a series of tasks is used to train the model or learner (the character recognition model) during the meta-learning phase (e.g., character recognition in different languages). Instead of forcing the learner to focus on a specific task, we let it acquire intrinsic features that are generally applicable to all tasks in the task distribution \(P(\mathrm{T})\). Our goal is to identify model parameters (meta-learning phase) that are responsive to task changes, such that minor changes in the parameters (adaptation phase) result in large improvements in the loss on any task drawn from \(P(\mathrm{T})\).
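The two-level optimization just described can be made concrete with a toy example. This is not MetaGAN; it is a minimal MAML-style meta-update on a made-up 1-D problem (a scalar parameter, two toy tasks, arbitrary learning rates), only to show the inner adaptation step and the outer meta-update:

```python
import tensorflow as tf

theta = tf.Variable(0.0)           # meta-parameters (a single scalar for the toy)
inner_lr, meta_lr = 0.1, 0.01      # adaptation / meta learning rates (arbitrary)

def task_loss(param, target):
    return (param - target) ** 2   # toy per-task loss

tasks = [1.0, 2.0]                 # two toy tasks drawn from P(T)

with tf.GradientTape() as outer_tape:
    meta_loss = 0.0
    for target in tasks:
        # adaptation phase: one inner gradient step starting from theta
        with tf.GradientTape() as inner_tape:
            support_loss = task_loss(theta, target)
        grad = inner_tape.gradient(support_loss, theta)
        theta_adapted = theta - inner_lr * grad
        # evaluate the adapted parameters on the task's query loss
        meta_loss += task_loss(theta_adapted, target)

# meta-update: move theta so that one adaptation step helps on every task
meta_grad = outer_tape.gradient(meta_loss, theta)
theta.assign_sub(meta_lr * meta_grad)
print(theta.numpy())  # ≈ 0.0384
```

The outer gradient flows through the inner gradient step, which is exactly why one meta-update nudges \(\theta\) toward a point from which a single adaptation step works well on every task.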

Fig 1: Diagram showing meta-learning and adaptation phase. Source: MAML Paper

MetaGAN is a simple and general framework for few-shot learning problems. Given a K-shot (K samples per class in a training task), N-way (N classes per task) classification problem, a conditional task generator generates samples that are indistinguishable from true data samples drawn from the task used to condition it. We then train the classifier (the discriminator in GAN terms, but with N output units for N-way classification) and the generator in an adversarial setup.

What is the gain of using a GAN in few-shot meta-learning? In a few-shot classification problem, the model tries to find a decision boundary for each task with just a few samples per class. With very few samples, many decision boundaries are possible, but most of them will not generalize well. Meta-learning mitigates this by learning a strategy, shared across different tasks, for forming a decision boundary from a few samples, in the hope that this strategy generalizes to new tasks. Although this is plausible, there can still be problems. For example, some objects look more similar than others: it may be easier to form a decision boundary between a Chinese character and an English character than between a Chinese character and a Korean character. If the training data does not contain tasks that try to separate Chinese characters from Korean characters, the learner may find it difficult to extract the right features to separate these two classes. On the other hand, expecting all possible class combinations during training leads to a combinatorial explosion. This is where MetaGAN helps. The generator in MetaGAN generates fake data, which forces the classifier (discriminator) to learn a sharper decision boundary. Instead of only learning to separate Chinese and Korean characters, the classifier is also forced to distinguish real from fake Chinese and Korean characters, as shown in the figure below. Moreover, we don’t need the generator to produce data that exactly matches the true data; it is actually better if the generator learns a distribution that is slightly off the true data manifold.

Fig 2: Decision boundary with MetaGAN (left) and decision boundary without MetaGAN (right). Colors represent different classes: gray means fake classes, and the green and bluish ones can be images of real characters from different languages. + and - mark real and fake samples. Source: MetaGAN Paper

Objective

The objective of this blog is to show you an implementation of MetaGAN. Some familiarity with GANs is expected before you delve into this implementation; along the way, you will learn about meta-learning and semi-supervised classification. After reading this blog, you’ll see that GANs can be used for purposes other than as generative models. In vanilla GAN training, since our only purpose is to generate samples that are plausible with respect to real data, we discard the discriminator once training is finished. When the discriminator is extended to output class labels, however, we can use it for supervised and semi-supervised classification, which helps utilize unlabeled data alongside very few labeled samples to increase performance. With meta-learning implemented on top of that, you will see that with very few training iterations, our classifier can achieve very significant performance.

Implementation

Basic Library Imports

Start with basic library imports and environment setup.

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
tf.__version__
'2.8.2'
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    print(e)

Dataset Preparation

Omniglot Dataset

We will use the Omniglot dataset, which contains 1623 different handwritten characters from 50 different alphabets. Each character has 20 samples, so in total 32,460 samples exist in the dataset. This dataset is widely used for few-shot learning problems and is available through TensorFlow Datasets. Let’s import it.

import tensorflow_datasets as tfds
omniglot, info = tfds.load('omniglot', with_info=True)
info  # see the info about omniglot dataset here
tfds.core.DatasetInfo(
    name='omniglot',
    full_name='omniglot/3.0.0',
    description="""
    Omniglot data set for one-shot learning. This dataset contains 1623 different
    handwritten characters from 50 different alphabets.
    """,
    homepage='https://github.com/brendenlake/omniglot/',
    data_path='~/tensorflow_datasets/omniglot/3.0.0',
    file_format=tfrecord,
    download_size=17.95 MiB,
    dataset_size=12.29 MiB,
    features=FeaturesDict({
        'alphabet': ClassLabel(shape=(), dtype=tf.int64, num_classes=50),
        'alphabet_char_id': tf.int64,
        'image': Image(shape=(105, 105, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=1623),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'small1': <SplitInfo num_examples=2720, num_shards=1>,
        'small2': <SplitInfo num_examples=3120, num_shards=1>,
        'test': <SplitInfo num_examples=13180, num_shards=1>,
        'train': <SplitInfo num_examples=19280, num_shards=1>,
    },
    citation="""@article{lake2015human,
      title={Human-level concept learning through probabilistic program induction},
      author={Lake, Brenden M and Salakhutdinov, Ruslan and Tenenbaum, Joshua B},
      journal={Science},
      volume={350},
      number={6266},
      pages={1332--1338},
      year={2015},
      publisher={American Association for the Advancement of Science}
    }""",
)

Being aware of the following few-shot learning problem statement (extracted from the MetaGAN paper) helps in preparing the dataset for MetaGAN.

Given a distribution of tasks \(P(T)\), a sample task \(T\) from \(P(T)\) is given by a joint distribution \(P^T_{X \times Y}(x, y)\), where the task is to predict \(y\) given \(x\). We have a set of training sample tasks \(\{T_i\}^N_{i=1}\). Each training sample task \(T\) is a tuple \(T = (S_T, Q_T)\), where the support set is denoted \(S_T = S^s_T \cup S^u_T\) and the query set is denoted \(Q_T = Q^s_T \cup Q^u_T\). The supervised support set \(S^s_T = \{(x_1, y_1), (x_2, y_2), \cdots, (x_{N \times K}, y_{N \times K})\}\) contains \(K\) labeled samples from each of the \(N\) classes (this is usually known as \(K\)-shot \(N\)-way classification). The optional unlabeled support set \(S^u_T = \{x^u_1, x^u_2, \cdots, x^u_M\}\) contains unlabeled samples from the same set of \(N\) classes, and can be empty in purely supervised cases. \(Q^s_T = \{(x_1, y_1), (x_2, y_2), \cdots, (x_T, y_T)\}\) is the supervised query set. \(Q^u_T = \{x_1, x_2, \cdots, x_P\}\) is the optional unlabeled query set. The objective of the model is to minimize the loss of its predictions on the query set, given the support set as input.

Simply put, for each character class in the Omniglot dataset, we will keep K-shot (k_support in code) samples in the support set and k_query samples in the query set. Since the Omniglot dataset doesn’t have any unlabeled samples, we don’t need to prepare an unlabeled support set or query set; whenever we require an unlabeled set, we will substitute the corresponding labeled support and query sets. A task is prepared by selecting N classes, each having k_support support samples and k_query query samples. The support set will be used to fine-tune the learner (adaptation), and the query set will be used to evaluate the adapted learner. The evaluation loss accumulated over the query sets of a number of tasks will be used to update the learner toward a position from which, with very few gradient updates, it can be adapted to a new task.

We will be doing \(5\)-way, \(5\)-shot, \(15\)-query meta-learning with 32 tasks per meta-update. The Omniglot dataset contains images of size \(105 \times 105\) with \(3\) channels; we will resize them to \(28 \times 28\). All channels in Omniglot images are identical, so a single channel suffices. We will create a task by randomly selecting \(5\) labels from the train set, regardless of their alphabets, which means a single task can contain characters from different languages.

n_way = 5
k_support = 5  # alias of K in K-shot
k_query = 15
task_batch = 32  # number of task for single meta update
image_size = [28, 28]
num_of_channels = 1
noise_dim = 100  # number of dimension in latent space from where noise is sampled
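As a quick sanity check on these settings, here is the per-task and per-meta-update sample arithmetic (plain Python, restating the constants above):

```python
# values copied from the configuration above
n_way, k_support, k_query, task_batch = 5, 5, 15, 32

support_per_task = n_way * k_support               # labeled images used for adaptation
query_per_task = n_way * k_query                   # images used to evaluate the adapted learner
images_per_meta_update = task_batch * (support_per_task + query_per_task)
print(support_per_task, query_per_task, images_per_meta_update)  # 25 75 3200
```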

Dataset preparation steps

  1. Take only the images and their corresponding labels; ignore all other info in the Omniglot samples.
  2. Group the train and test Omniglot samples by label into batches of (k_support + k_query) samples. Filter out groups with fewer than (k_support + k_query) samples. Shuffle the dataset beforehand so that samples filtered out in one iteration get a chance to be included in the next.
  3. Resize the images and normalize them to [-1, 1].
  4. Randomly rotate all images in a class by one of 0, 90, 180, or 270 degrees to create a new class.
  5. Take n_way random labels and batch them to form a single task. Ignore tasks with fewer than n_way classes.
  6. Relabel the images in a task to classes [0, 1, …, n_way-1] for the n_way different classes.
  7. Split each n_way task into a support set and a query set.
  8. For the training dataset, batch task_batch tasks into one. One task batch is used for one meta-learning step.
def get_images_and_labels(sample):
    """Returns the image and corresponding label from an Omniglot sample.
    
    An Omniglot sample is a dictionary with the following structure:
    `{alphabet: (), alphabet_char_id: (), image: (105, 105, 3), label: ()}`
    
    Parameters
    ----------
    sample : `dict` of Omniglot sample
    
    Returns
    ----------
    image : `Tensor` of dtype `tf.float32`
        Image tensor shaped [105, 105, 3] in `sample` dictionary
    
    label : `Tensor` of dtype `tf.int64`
        Scalar Label tensor in `sample` dictionary
    """
    image = tf.cast(sample['image'], tf.float32)
    label = tf.cast(sample['label'], tf.int64)
    return image, label
def get_label_group_func(dataset):
    """Groups Omniglot samples by label and reduces each group
    to a batch of size `k_support + k_query`.
    
    Returns
    ----------
    dataset : `tf.data.Dataset`
        The transformed dataset with samples grouped by label
        into batches of size `k_support + k_query`.
    """
    dataset = dataset.group_by_window(key_func=lambda x, y: y,
                                                reduce_func=lambda _, els: els.batch(k_support + k_query),
                                                window_size=k_support + k_query)
    return dataset
def label_group_filter(images, labels):
    """A predicate to check whether a label group of Omniglot images
    has exactly `k_support + k_query` samples. Otherwise we cannot
    make a support set and a query set from this label group, so
    we ignore it.
    
    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, h, w, c]`
        Images from Omniglot with same labels
    labels : `Tensor` of dtype `tf.int64`
        Shape `[k_support + k_query,]`
        Corresponding labels of input images. They must be same here
    
    Returns
    ----------
    right_size_label_group : `boolean`
        `True` if images have `k_support + k_query` samples else `False`
    """
    right_size_label_group =  tf.shape(images)[0] == (k_support + k_query)
    return right_size_label_group
def resize_and_normalize(images, labels):
    """Resize image and normalize them in between `[-1, 1]`.
    
    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, height, width, channels]`
    labels : `Tensor` of dtype `tf.int64`
        Shape `[k_support + k_query,]`
        Corresponding labels of input `images`.
    
    Returns
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, *image_size, num_of_channels]`
        Resized and normalized `images` with spatial size `image_size` and values in `[-1, 1]`.
        Only the first `num_of_channels` channels are returned;
        all channels in the Omniglot dataset are identical.
    labels : `Tensor` of dtype `tf.int64`
        Same as input labels.
    """
    images = tf.image.resize((images[:, :, :, :num_of_channels]-127.5)/127.5, image_size)
    return images, labels
def data_augment(images, labels):
    """Data augmentation by randomly rotating each image by a
    multiple of 90 degrees to form new classes.
    
    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[k_support + k_query, height, width, channels]`
        `k_support + k_query` samples in `images` should be from same class.
    labels : `Tensor` of dtype `tf.int64`
        Shape `[k_support + k_query,]`
        Corresponding labels of input `images`. Here all labels should be same. 
    
    Returns
    ----------
    images: `Tensor` of dtype `tf.float32`
        Shape same as input `images`
    labels : `Tensor` of dtype `tf.int64`
        Same as input labels. We can keep the old label for the
        new class since all images in the class are rotated by the same angle.
    """
    rotation = tf.random.uniform([], maxval=4, dtype=tf.int32)
    images = tf.image.rot90(images, k=rotation)
    return images, labels
def relabel(images, labels):
    """Relabel images for `n_way` classification.
    
    Parameters
    ----------
    images : `Tensor` of dtype `tf.float32`
        Shape `[n_way, k_query + k_support, height, width, channels]`
        Images for a single task managed in 5D tensor.
        --> 1st dimension: for `n_way` classes
        --> 2nd dimension: for `k_query + k_support` images samples for a class
        --> Other dimensions: Images height, width and channels
    labels : `Tensor` of dtype `tf.int64`
        Shape `[n_way, k_query + k_support]` 
        Omniglot labels for images in same structure  as input `images`.
        Label values for images from each class can be anything between 0 and 1622 (total characters - 1)
    
    Returns
    ----------
    images : `Tensor` of dtype `tf.float32`
        Returns same input `images`. No changes required here.
    new_labels : `Tensor` of dtype `tf.int64`
        Shape `[n_way, k_query + k_support]`
        New labels value must be between `[0, n_way-1]`
    """
    old_labels_shape = tf.shape(labels)
    new_classes = tf.expand_dims(tf.range(old_labels_shape[0]), -1)
    new_labels = tf.tile(new_classes, [1, old_labels_shape[-1]])
    new_labels = tf.cast(new_labels, tf.int64)

    return images, new_labels
def get_support_query_split_func(shuffle=True):
    """Returns a function that will split a task into support set and query set
    
    Parameters
    ----------
    shuffle : `boolean`
        Flags whether to shuffle before splitting `n_way` task into support set and query set.
        If not, the first `k_support` samples go into the support set and the remaining into the query set.
    
    Returns
    ----------
    support_query_split : `function`
        A function that will split a task into support set and query set.
    """
    def support_query_split(nway_images, nway_labels):
        """Split `n_way` task into `k_support`s support set and `k_query` query set.
        
        Parameters
        ----------
        nway_images : `Tensor` of dtype `tf.float32`
            Shape `[n_way, k_query + k_support, height, width, channels]`
            `k_query + k_support` images for each class in `n_way` task.
        nway_labels : `Tensor` of dtype `tf.int64`
            Shape `[n_way, k_query + k_support]`
            N-way labels for images in same structure  as input `nway_images`.
        
        Returns
        ----------
        support_images : `Tensor` of dtype `tf.float32`
            Shape `[n_way, k_support, height, width, channels]`
            Used for adaptation(K-shot learning).
        support_labels : `Tensor` of dtype `tf.int64`
            Shape `[n_way, k_support]`
            Used for adaptation(K-shot learning).
        query_images : `Tensor` of dtype `tf.float32`
            Shape `[n_way, k_query, height, width, channels]`
            Used for metalearning step(Testing adapted learner).
        query_labels : `Tensor` of dtype `tf.int64`
            Shape `[n_way, k_query]`
            Used for metalearning step(Testing adapted learner).
        """
        
        images_shape = tf.shape(nway_images)
        
        perm = tf.random.shuffle(tf.range(images_shape[1])) if shuffle \
                else tf.range(images_shape[1])

        support_images = tf.gather(nway_images, perm[:k_support], axis=1)
        support_images = tf.reshape(support_images, (-1, images_shape[-3], images_shape[-2], images_shape[-1]))
        support_labels = tf.gather(nway_labels, perm[:k_support], axis=1)
        support_labels = tf.reshape(support_labels, [-1])

        query_images = tf.gather(nway_images, perm[k_support:], axis=1)
        query_images = tf.reshape(query_images, (-1, images_shape[-3], images_shape[-2], images_shape[-1]))
        query_labels = tf.gather(nway_labels, perm[k_support:], axis=1)
        query_labels = tf.reshape(query_labels, [-1])

        return support_images, support_labels, query_images, query_labels
    
    return support_query_split

As described in the dataset preparation steps, we use the above-defined functions in a tf.data.Dataset pipeline. As you can see in the Omniglot info above, the dataset has train, test, and other splits (not used here). Remember, we do no shuffling in the test dataset.

train_dataset_task_grouped = omniglot['train']\
                                .map(get_images_and_labels)\
                                .shuffle(19280, reshuffle_each_iteration=True)\
                                .group_by_window(key_func=lambda x, y: y,  # Group by label
                                                reduce_func=lambda _, els: els.batch(k_support + k_query),  # Batch size is k_support + k_query
                                                window_size=k_support + k_query)\
                                .filter(label_group_filter)\
                                .map(resize_and_normalize)\
                                .map(data_augment)\
                                .shuffle(964, reshuffle_each_iteration=True)\
                                .batch(batch_size=n_way, drop_remainder=True)\
                                .map(relabel)\
                                .map(get_support_query_split_func(shuffle=True))\
                                .batch(batch_size=task_batch)

test_dataset_nway_grouped = omniglot['test']\
                                .map(get_images_and_labels)\
                                .group_by_window(key_func=lambda x, y: y,  # Group by label
                                                reduce_func=lambda _, els: els.batch(k_support + k_query),  # Batch size is k_support + k_query
                                                window_size=k_support + k_query)\
                                .filter(label_group_filter)\
                                .map(resize_and_normalize)\
                                .batch(batch_size=n_way, drop_remainder=True)\
                                .map(relabel)\
                                .map(get_support_query_split_func(shuffle=False))
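To see the group-by-label trick in isolation, here is a self-contained toy version of the same pipeline on synthetic data (tiny made-up sizes, not the blog’s hyperparameters):

```python
import tensorflow as tf

k_sup, k_qry, n = 2, 3, 2                    # toy sizes for the demo
images = tf.zeros([60, 4, 4, 1])             # 6 fake classes x 10 tiny "images"
labels = tf.repeat(tf.range(6, dtype=tf.int64), 10)

ds = tf.data.Dataset.from_tensor_slices((images, labels))
ds = ds.group_by_window(key_func=lambda x, y: y,                      # group by label
                        reduce_func=lambda _, els: els.batch(k_sup + k_qry),
                        window_size=k_sup + k_qry)
ds = ds.filter(lambda x, y: tf.shape(x)[0] == (k_sup + k_qry))        # keep full groups only
ds = ds.batch(n, drop_remainder=True)                                 # n label groups -> one task

for task_images, task_labels in ds.take(1):
    print(task_images.shape, task_labels.shape)  # (2, 5, 4, 4, 1) (2, 5)
```

Each element is now one n-way task: a 5D image tensor of shape `[n, k_sup + k_qry, height, width, channels]` plus matching labels, exactly the structure that `relabel` and the support/query split operate on.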

Now define a utility function to generate noise from a normal distribution.

def generate_noise(shape):
    """Generate noise from normal distribution.
    
    Parameters
    ----------
    shape: `tuple` of `int` or `tf.int` or both.
        Shape of noise tensor to generate.
    
    Returns
    ----------
    noise: `Tensor` of dtype `tf.float32`
        Noise tensor shaped `shape` from normal distribution.
    """
    noise = tf.random.normal(shape)
    return noise

Model Creation

Generator

The generator should be able to generate fake data that is close to the real data manifold of a specific task \(T\). That means we need to condition the generator on a per-task basis. For that, we compress the information in the task’s support set with a dataset encoder \(E\) into a vector \(h_T\), which contains sufficient statistics for the data distribution of task \(T\). The task representation vector \(h_T\) is then concatenated with the noise input \(z\) to form the input to the generator network. The task encoder contains two modules: the instance encoder encodes each sample, and the feature-aggregation module produces the representation vector \(h_T\) for the whole task set by some aggregation scheme such as averaging or max-pooling.
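This conditioning amounts to tiling a single \(h_T\) across the noise batch and concatenating. A quick shape check with toy tensors (sizes match the hyperparameters used in this post, but the tensors here are random stand-ins):

```python
import tensorflow as tf

noise_dim, task_repr_size, batch = 100, 256, 4    # toy batch of 4 samples to generate
z = tf.random.normal([batch, noise_dim])          # one noise vector per generated sample
h_T = tf.random.normal([1, task_repr_size])       # single representation for the whole task
h_tiled = tf.tile(h_T, [tf.shape(z)[0], 1])       # repeat h_T across the noise batch
conditioned = tf.concat([z, h_tiled], axis=-1)    # task-conditioned generator input
print(conditioned.shape)  # (4, 356)
```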

Implementation of Task Encoder Model.

class TaskEncoder(tf.keras.Model):
    """Takes a task $$T$$ support set and generates a representation vector $$h_T$$.
    
    Parameters
    ----------
    conv_filters : `list` of `int`
        List of number of filters to use in each `Conv2D` layers.
        The length of list is the number of `Conv2D` layers to use. 
    conv_kernels : `list` of `int` or `tuples` or both.
        List of kernel size to use in each corresponding `Conv2D` layers.
    conv_strides : `list` of `int` or `tuples` or both.
        List of strides to use in each corresponding `Conv2D` layers.
    output_units : `int`
        Dimension of Task representation.
    
    Input shape
    ----------
        N-D tensor with shape: `(n_way*k_support, height, width, channels)`
        Whole support set for a task must be given to this model
    
    Output shape
    ----------
        N-D tensor with shape: `(1, output_units)`
        Single representation vector for the whole task set.
    """
    def __init__(self,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 output_units,
                 **kwargs):
        super(TaskEncoder, self).__init__(**kwargs)
        
        self.conv2d_layers = [tf.keras.layers.Conv2D(filters=f,
                                                     kernel_size=k,
                                                     strides=s,
                                                     padding='same')
                                for f, k, s in zip(conv_filters,
                                                   conv_kernels,
                                                   conv_strides)]
        self.activation_layers = [tf.keras.layers.LeakyReLU()
                                    for _ in conv_filters]
        self.flatten_layer = tf.keras.layers.Flatten()
        self.output_layer = tf.keras.layers.Dense(output_units)
        self.output_activation = tf.keras.layers.LeakyReLU()
        self.output_dropout = tf.keras.layers.Dropout(rate=0.2)
    
    def call(self, inputs):
        for conv, activation in zip(self.conv2d_layers,
                                    self.activation_layers):
            inputs = conv(inputs)
            inputs = activation(inputs)
        
        outputs = self.flatten_layer(inputs)
        outputs = self.output_layer(outputs)
        outputs = self.output_activation(outputs)
        outputs = self.output_dropout(outputs)
        task_repr = tf.reduce_mean(outputs, axis=0, keepdims=True)

        return task_repr
conv_filters=[64, 64, 128, 128, 256, 256]
conv_kernels=[3, 3, 3, 3, 3, 3]
conv_strides=[1, 2, 1, 2, 1, 2]
output_units=256
task_encoder = TaskEncoder(conv_filters=conv_filters,
                           conv_kernels=conv_kernels,
                           conv_strides=conv_strides,
                           output_units=output_units)

Implementation of Task Conditioned Generative Model

Note: no batch normalization is used here.

class ConditionalGenerator(tf.keras.Model):
    """Task conditioned generator
    
    Parameters
    ----------
    conv_start_shape : `tuple` of length 3
        The conditioned noise will be projected to `np.prod(conv_start_shape)`
        and reshaped to `conv_start_shape`. To this output we can perform
        convolutional operation.
    upsample_scales : `list` of `int`
        Instead of `Conv2DTranspose`, we use `Upsample + Conv2D`. It is the 
        list of scale sizes for upsampling before corresponding convolutional layers.
    conv_filters : `list` of `int`
        List of number of filters to use in each `Conv2D` layers.
        The length of list is the number of `Conv2D` layers to use. 
    conv_kernels : `list` of `int` or `tuples` or both.
        List of kernel size to use in each `Conv2D` layers.
    
    Input shape
    ----------
        Tuple of N-D tensor of length 2 and with shapes:
            `[(n_way*k_support, noise_dim), (1, task_repr_size)]` during adaptation
            or `[(n_way*k_query, noise_dim), (1, task_repr_size)]` during metalearning
        1st element in tuple is noise and 2nd is task representation vector to condition generator.
        
    Output shape
    ----------
        N-D tensor with shape: 
            `(n_way*k_support, height, width, channels)` during adaptation
            `(n_way*k_query, height, width, channels)` during metalearning    
    """
    def __init__(self,
                 conv_start_shape,
                 upsample_scales,
                 conv_filters,
                 conv_kernels,
                 **kwargs):
        super(ConditionalGenerator, self).__init__(**kwargs)
        
        self.concatenation = tf.keras.layers.Concatenate()
        self.noise_embedding_projection = tf.keras.layers.Dense(np.prod(conv_start_shape))
        self.noise_embedding_reshape = tf.keras.layers.Reshape(conv_start_shape)

        self.upsample_layers = [tf.keras.layers.UpSampling2D(size=scale)
                                    for scale in upsample_scales]
        self.conv_layers = [tf.keras.layers.Conv2D(filters=f,
                                                   kernel_size=k,
                                                   padding='same')
                                for f, k in zip(conv_filters,
                                                conv_kernels)]
        self.activation_layers = [tf.keras.layers.LeakyReLU()
                                    for _ in range(len(conv_filters) - 1)]\
                                    + [tf.keras.layers.Activation('tanh'),]
    
    def call(self, inputs):
        noise, task_representation = inputs

        task_encodings = tf.tile(task_representation, [tf.shape(noise)[0], 1])

        conditioned_noise = self.concatenation([noise, task_encodings])  # noise is now task-conditioned

        # Now, same as a normal GAN, except upsample + Conv2D instead of Conv2DTranspose.
        output_image = self.noise_embedding_reshape(self.noise_embedding_projection(conditioned_noise))
        for upsample, conv, activation in zip(self.upsample_layers,
                                              self.conv_layers,
                                              self.activation_layers):
            output_image = upsample(output_image)
            output_image = conv(output_image)
            output_image = activation(output_image)
        
        return output_image
generator = ConditionalGenerator(conv_start_shape=(7, 7, 256),
                                 upsample_scales=[1, 2, 2],
                                 conv_filters=[128, 64, num_of_channels],
                                 conv_kernels=[5, 5, 5])

Discriminator

The MetaGAN discriminator can be any few-shot classifier. We will be using the one from Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (Finn et al.).

Implementation of Discriminator/Classifier Model

class Discriminator(tf.keras.Model):
    """The Discriminator/Classifier network for MetaGAN.

    Parameters
    ----------
    conv_filters : `list` of `int`
        List of number of filters to use in each `Conv2D` layers.
        The length of list is the number of `Conv2D` layers to use. 
    conv_kernels : `list` of `int` or `tuples` or both.
        List of kernel size to use in each `Conv2D` layers.
    conv_strides : `list` of `int` or `tuples` or both.
        List of strides to use in each `Conv2D` layers.
    
    Input shape
    ----------
        N-D tensor with shape: `(batch_size, height, width, channels)`.
    Output shape:
    ----------
        Tuple of 2 N-D tensor with shape: `[(batch_size, n_way), (batch_size, flattened_size)]`.
        Tuple of length 2 with 1st element classifier logit output and
        2nd element flattened last convolutional layer output.
    """
    def __init__(self,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 n_way,
                 **kwargs):
        super(Discriminator, self).__init__(**kwargs)

        self.conv2d_layers = [tf.keras.layers.Conv2D(filters=f,
                                                     kernel_size=k,
                                                     strides=s,
                                                     padding='same')
                                for f, k, s in zip(conv_filters,
                                                   conv_kernels,
                                                   conv_strides)]
        self.batchnorm_layers = [tf.keras.layers.BatchNormalization()
                                for _ in conv_filters]
        self.activation_layers = [tf.keras.layers.ReLU()
                                for _ in conv_filters]
        
        self.flatten_layer = tf.keras.layers.Flatten()
        self.output_layer = tf.keras.layers.Dense(n_way)
        
    def call(self, inputs):
        for conv, batchnorm, activation in zip(self.conv2d_layers,
                                               self.batchnorm_layers,
                                               self.activation_layers):
            inputs = conv(inputs)
            inputs = batchnorm(inputs, training=True)
            inputs = activation(inputs)
        
        flattened_features = self.flatten_layer(inputs)
        class_logits = self.output_layer(flattened_features)

        return class_logits, flattened_features

discriminator = Discriminator(conv_filters=[64, 64, 64, 64],
                              conv_kernels=[3, 3, 3, 3],
                              conv_strides=[2, 2, 2, 2],
                              n_way=n_way)

# We need a second discriminator so we can restore the pre-adaptation weights after each trial meta-learning step
duplicate_discriminator = Discriminator(conv_filters=[64, 64, 64, 64],
                                        conv_kernels=[3, 3, 3, 3],
                                        conv_strides=[2, 2, 2, 2],
                                        n_way=n_way)

Optimizers and Loss Functions

We’ll use hyperparameters chosen according to the MetaGAN paper and Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (Chelsea Finn et al.)

meta_learning_rate=1e-3
meta_beta_1 = 0.5
meta_beta_2 = 0.9
adaption_learning_rate=0.1
adaptation_number_of_steps = 5
EPOCHS = 100

For the adaptation steps, use SGD; for meta-learning, use the Adam optimizer with the parameters defined above.

adaptation_optimizer = tf.keras.optimizers.SGD(learning_rate=adaption_learning_rate)  # used for inner gradient update
meta_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=meta_learning_rate,
                                                        beta_1=meta_beta_1,
                                                        beta_2=meta_beta_2)
meta_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=meta_learning_rate,
                                                    beta_1=meta_beta_1,
                                                    beta_2=meta_beta_2)

As mentioned in the paper, the discriminator loss can be divided into two parts: an unsupervised loss that represents the GAN problem, and a supervised loss that represents the individual n_way class probabilities. To get the classification probabilities, we feed the logits through the softmax function. However, we still need a way to represent the probability of an input image being real rather than fake; that is, we still need to account for the binary classification problem of a regular GAN. Since the probability of an input being real corresponds to the sum of its unnormalized probabilities over all real classes, we can feed the class logits through a LogSumExp function to obtain a single binary-classification logit, which is then passed to a sigmoid. Using TensorFlow’s built-in LogSumExp helps avoid numerical problems: it prevents the overflow/underflow that can occur when exponentiating very large positive or negative values.
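A quick NumPy sketch of this trick (hypothetical logit values, not from the notebook): the LogSumExp of the class logits acts as the single binary "realness" logit, and subtracting the max before exponentiating keeps extreme logits from overflowing:

```python
import numpy as np

def logsumexp(x, axis=-1):
    # numerically stable log(sum(exp(x))): subtract the max before exponentiating
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))

# class logits for a single sample of a 3-way classifier
class_logits = np.array([[2.0, -1.0, 0.5]])
gan_logit = logsumexp(class_logits)        # binary "realness" logit
p_real = 1.0 / (1.0 + np.exp(-gan_logit))  # sigmoid of that logit

# extreme logits: a naive log(sum(exp(x))) would overflow, the stable form stays finite
extreme_logit = logsumexp(np.array([[1000.0, 999.0, 998.0]]))
```

For moderate logits the stable form agrees with the naive computation, while for the extreme row `np.exp(1000.0)` alone would already be infinite.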

As usual in a vanilla GAN, for images coming from the training set we maximize their probability of being real by assigning them labels of 1. For fabricated images coming from the generator, we maximize their probability of being fake by assigning them labels of 0.

Implementation of Labeled/Supervised Cross Entropy Loss

def labeled_loss(target, class_logits):
    """Loss function to calculate supervised loss.
    
    Parameters
    ----------
    target : `Tensor` of dtype `tf.int64`
        Shape `(batch_size,)`
        Categorical labels for `n_way` classification
    class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        It's a logit tensor.
    
    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
        Supervised cross entropy loss for labeled input
    """
    losses = tf.keras.losses.sparse_categorical_crossentropy(target, class_logits, from_logits=True)
    loss = tf.reduce_mean(losses)
    return loss

Implementation of Unlabeled/Unsupervised Loss

For both real and generated images

def unlabeled_loss(class_logits, real=True):
    """Loss function to calculate unsupervised loss.
    
    Parameters
    ----------
    class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        It's a class logit tensor predicted by model for unlabeled input.
    real : `boolean`
        Flags whether `class_logits` is for real unlabeled data or for unlabeled generated data.
    
    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
        Unsupervised loss for unlabeled input.
    """
    gan_logits = tf.reduce_logsumexp(class_logits, axis=1)
    labels = tf.ones_like(gan_logits) if real else tf.zeros_like(gan_logits)
    losses = tf.keras.losses.binary_crossentropy(labels, gan_logits, from_logits=True)
    loss = tf.reduce_mean(losses)
    return loss

Implementation of Discriminator Loss

def discriminator_loss(label_class_logits,
                       label_target,
                       unlabel_class_logits,
                       fake_class_logits):
    """Function that estimates how well the discriminator is able to distinguish real images from fakes.
    
    discriminator_loss = supervised loss + real unsupervised loss + fake unsupervised loss.
    
    Parameters
    ----------
    label_class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Classifier predicted class logits for labeled real data.
    label_target : `Tensor` of dtype `tf.int32`
        Shape `(batch_size,)`
        True categorical labels of `label_class_logits` for `n_way` classification.
    unlabel_class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Classifier predicted class logits for unlabeled real data.
    fake_class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        Classifier predicted class logits for unlabeled generated data.
    
    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
        
    """
    supervised_loss = labeled_loss(label_target, label_class_logits)
    real_unsupervised_loss = unlabeled_loss(unlabel_class_logits, real=True)
    fake_unsupervised_loss = unlabeled_loss(fake_class_logits, real=False)
    
    disc_loss = supervised_loss + real_unsupervised_loss + fake_unsupervised_loss
    return disc_loss

Implementation of Generator Loss

As described in the Improved Techniques for Training GANs paper, we use feature matching for the generator loss. Have a look at the authors’ description of feature matching: “Feature matching is the concept of penalizing the mean absolute error between the average value of some set of features on the training data and the average values of that set of features on the generated samples.” So, we take the average of the features extracted from the discriminator while a real training minibatch is being processed, and in the same way the average of the features for the generated samples. Here, the generator loss is the mean squared difference between the two averages.

def generator_loss(real_features, fake_features):
    """The generator's loss quantifies how well it was able to trick the discriminator.
    
    Parameters
    ----------
    real_features : `Tensor` of dtype `tf.float32`
        Shape : `(batch_size, flattened_size)`
        Flattened last convolutional layer output for real input data batch
    fake_features : `Tensor` of dtype `tf.float32`
        Shape : `(batch_size, flattened_size)`
        Flattened last convolutional layer output for the generated input data batch
    
    Returns
    ----------
    loss : Scalar `Tensor` of dtype `tf.float32`
        
    """
    real_mean_feature = tf.reduce_mean(real_features, axis=0)
    fake_mean_feature = tf.reduce_mean(fake_features, axis=0)
    gen_loss = tf.reduce_mean((real_mean_feature - fake_mean_feature)**2)
    return gen_loss
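As a sanity check of the arithmetic, the feature-matching loss reduces to the mean squared gap between the batch-mean feature vectors; a small NumPy sketch with made-up feature arrays:

```python
import numpy as np

# hypothetical flattened discriminator features: batch of 2, feature dim 2
real_features = np.array([[1.0, 2.0], [3.0, 4.0]])
fake_features = np.array([[0.0, 0.0], [2.0, 2.0]])

real_mean = real_features.mean(axis=0)  # per-dimension average over the batch
fake_mean = fake_features.mean(axis=0)
gen_loss = np.mean((real_mean - fake_mean) ** 2)
```

Here `real_mean` is `[2., 3.]` and `fake_mean` is `[1., 1.]`, so the loss is `2.5`; minimizing it pushes the generated batch statistics toward the real ones.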

Implementation of Accuracy, an evaluation metric

def accuracy(class_logits, labels):
    """Accuracy measure for given class logits and target labels.
    
    Parameters
    ----------
    class_logits : `Tensor` of dtype `tf.float32`
        Shape `(batch_size, n_way)`
        It's a logit tensor.
    labels : `Tensor` of dtype `tf.int64`
        Shape `(batch_size,)`
        True categorical labels of `class_logits` for `n_way` classification.
    
    Returns
    ----------
    accuracy : `Tensor` of dtype `tf.float32`
        Accuracy measure for given class logits and target labels.
    """
    label_predictions = tf.argmax(class_logits, axis=-1)
    equality = tf.equal(labels, label_predictions)
    return tf.reduce_mean(tf.cast(equality, tf.float32))

Training

Define an adaptation function.

def adaptation(learner, real_support, real_support_label, fake_support):
    """Given a learner model, perform `adaptation_number_of_steps`
    gradient-descent updates using the labeled support set, the generated
    fake support set, and, since we don't have unlabeled Omniglot images,
    the same labeled support set reused as the unsupervised support set.
    
    Parameters
    ----------
    learner : `tensorflow.keras.Model`
        A few shot classifier model which needs to be adapted to given task support set.
    real_support : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_support, height, width, channels)`
        Target task support images
    real_support_label : `Tensor` of dtype `tf.int64`
        Shape `(n_way*k_support,)`
        `n_way` labels of given real support set
    fake_support : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_support, height, width, channels)`
        Generated images conditioned on same task as input real support images task.
    
    Returns
    ----------
    support_loss : Scalar Tensor of dtype `tf.float32`
        Mean support discriminator loss over the adaptation steps of the given `learner`.
    """
    support_disc_losses = []
    for _ in range(adaptation_number_of_steps):
        with tf.GradientTape() as adaptation_tape:
            support_real_class_logits, _ = learner(real_support)
            support_fake_class_logits, _ = learner(fake_support)
            disc_loss = discriminator_loss(support_real_class_logits,
                                           real_support_label,
                                           support_real_class_logits,
                                           support_fake_class_logits)
        adaptation_grads = adaptation_tape.gradient(disc_loss, learner.trainable_variables)
        adaptation_optimizer.apply_gradients(zip(adaptation_grads, learner.trainable_variables))
        support_disc_losses.append(disc_loss)
    return tf.reduce_mean(support_disc_losses)
    
@tf.function
def meta_learn_step(support_taskbatch,
                    support_taskbatch_labels,
                    query_taskbatch,
                    query_taskbatch_labels):
    """Perform one step of metalearning given a batch of tasks.
    
    Parameters
    ----------
    support_taskbatch : `Tensor` of dtype `tf.float32`
        Shape `(task_batch, n_way*k_support, height, width, channels)`
        Support set images for task batch.
    support_taskbatch_labels : `Tensor` of dtype `tf.int64`
        Shape `(task_batch, n_way*k_support,)`
        Support set labels for the task batch.
    query_taskbatch : `Tensor` of dtype `tf.float32`
        Shape `(task_batch, n_way*k_query, height, width, channels)`
        Query set images for task batch.
    query_taskbatch_labels : `Tensor` of dtype `tf.int64`
        Shape `(task_batch, n_way*k_query,)`
        Query set labels for the task batch.
    
    Returns
    ----------
    taskbatch_query_discriminator_loss : Scalar Tensor of dtype `tf.float32`
        Average discriminator loss over task batch on query set 
    task_batch_query_generator_loss : Scalar Tensor of dtype `tf.float32`
        Average generator loss over task batch on query set 
    taskbatch_query_accuracy : Scalar Tensor of dtype `tf.float32`
        Average accuracy over task batch on query set 
    """
    number_of_tasks = support_taskbatch.shape[0]

    # Step 1. Store discriminator weights in another model,
    #         Both model should be built before executing this step
    for dup_wts, wts in zip(duplicate_discriminator.trainable_variables,
                            discriminator.trainable_variables):
        dup_wts.assign(wts)
    
    # Step 2. Initialize tensor to find total losses and accuracies on various tasks
    taskbatch_query_discriminator_loss = tf.constant(0.0)
    task_batch_query_generator_loss = tf.constant(0.0)
    taskbatch_query_accuracy = tf.constant(0.0)
    
    with tf.GradientTape() as meta_discriminator_tape, tf.GradientTape() as meta_generator_tape:
        ## Step 3. Repeat Steps 4-12 for all tasks in the current task batch.
        for task_no in range(number_of_tasks):
            # Step 4. For each task, find its representation vector using support set and TaskEncoder model.
            task_representation = task_encoder(support_taskbatch[task_no])
            
            # Step 5. Adapt discriminator model to the current task, call `adaptation` function passing discriminator
            #         and required support inputs
            with meta_discriminator_tape.stop_recording(), meta_generator_tape.stop_recording():
                # No need to record the adaptation operations for the meta updates
                
                # Generate fake support set with same number of samples as in real support set in current task
                support_noise = generate_noise((tf.shape(support_taskbatch[task_no])[0], noise_dim))
                support_fake = generator([support_noise, task_representation])
                
                support_loss = adaptation(discriminator,
                                          support_taskbatch[task_no],
                                          support_taskbatch_labels[task_no],
                                          support_fake)
            # Step 6. Generate fake query set
            query_noise = generate_noise((tf.shape(query_taskbatch[task_no])[0], noise_dim))
            query_fake = generator([query_noise, task_representation])
            
            # Step 7. Find discriminator feature and class logits for real and generated query set of current task
            query_real_class_logits, query_real_features = discriminator(query_taskbatch[task_no])
            query_fake_class_logits, query_fake_features = discriminator(query_fake)
            
            # Step 8. Find discriminator loss
            disc_loss = discriminator_loss(query_real_class_logits,
                                           query_taskbatch_labels[task_no],
                                           query_real_class_logits,
                                           query_fake_class_logits)
            # Step 9. Find generator loss
            gen_loss = generator_loss(query_real_features, query_fake_features)
            
            # Step 10. Add query discriminator loss and generator loss to recording variable.
            taskbatch_query_discriminator_loss += disc_loss
            task_batch_query_generator_loss += gen_loss
            
            # Step 11. Calculate query accuracy for current task and add to the sum variable
            query_accuracy = accuracy(query_real_class_logits, query_taskbatch_labels[task_no])
            taskbatch_query_accuracy += query_accuracy
            
            # Step 12. Recover discriminator weights before adaptation for next task adaptation.
            for dup_wts, wts in zip(duplicate_discriminator.trainable_variables,
                                    discriminator.trainable_variables):
                wts.assign(dup_wts)
    
    # Step 13. Find discriminator and generator gradients. The TaskEncoder is updated along with
    #          the generator, so from the total generator loss we take gradients w.r.t. both the
    #          generator variables and the task encoder variables.
    meta_discriminator_grads = meta_discriminator_tape.gradient(taskbatch_query_discriminator_loss,
                                                                discriminator.trainable_variables)
    meta_generator_grads = meta_generator_tape.gradient(task_batch_query_generator_loss,
                                                        task_encoder.trainable_variables + generator.trainable_variables)
    
    # Step 14. Using the respective meta optimizers, update the discriminator, generator and task encoder weights
    meta_discriminator_optimizer.apply_gradients(zip(meta_discriminator_grads,
                                                     discriminator.trainable_variables))
    meta_generator_optimizer.apply_gradients(zip(meta_generator_grads,
                                                 task_encoder.trainable_variables + generator.trainable_variables))
    # Find average metrics over the task batch to return
    avg_disc_loss = taskbatch_query_discriminator_loss/number_of_tasks
    avg_gen_loss = task_batch_query_generator_loss/number_of_tasks
    avg_accuracy = taskbatch_query_accuracy/number_of_tasks
    
    return avg_disc_loss, avg_gen_loss, avg_accuracy
# for a single task
@tf.function
def evaluation(support_images, support_labels, query_images, query_labels):
    """Perform finetuning/adaptation using support set of a task and
    returns evaluation metrics calculated using the query set of the same task.
    
    Parameters
    ----------
    support_images : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_support, height, width, channels)`
    support_labels : `Tensor` of dtype `tf.int64`
        Shape `(n_way*k_support,)`
    query_images : `Tensor` of dtype `tf.float32`
        Shape `(n_way*k_query, height, width, channels)`
    query_labels : `Tensor` of dtype `tf.int64`
        Shape `(n_way*k_query,)`
    
    Returns
    ----------
    query_discriminator_loss : Scalar Tensor of dtype `tf.float32`
        Classifier loss on query set of given task after adaptation on same task support set.
    query_accuracy : `Tensor` of dtype `tf.float32`
        Classifier accuracy on query set of given task after adaptation on same task support set.
    """
    # During evaluation no meta-learning step is performed, so we use the secondary discriminator
    # for adaptation and for computing evaluation metrics on the query set.
    
    # Step 1. Copy primary discriminator weights to secondary discriminator.
    for dup_wts, wts in zip(duplicate_discriminator.trainable_variables, discriminator.trainable_variables):
        dup_wts.assign(wts)
    
    # Step 2. Produce fake support set and adapt secondary discriminator to passed task. 
    task_representation = task_encoder(support_images)
    support_noise = generate_noise((tf.shape(support_images)[0], noise_dim))
    support_fake = generator([support_noise, task_representation])

    support_loss = adaptation(duplicate_discriminator,
                              support_images,
                              support_labels,
                              support_fake)
    
    # Step 3. Produce fake query set with same number of samples as in real query set
    query_noise = generate_noise((tf.shape(query_images)[0], noise_dim))
    query_fake = generator([query_noise, task_representation])
    
    # Step 4. Find class logits for the real and fake query sets. Since the generator
    #         is not being trained here, the generator loss is not computed.
    query_real_class_logits, _ = duplicate_discriminator(query_images)
    query_fake_class_logits, _ = duplicate_discriminator(query_fake)
    
    # Step 5. Find discriminator loss
    query_discriminator_loss = discriminator_loss(query_real_class_logits,
                                                  query_labels,
                                                  query_real_class_logits,
                                                  query_fake_class_logits)
    
    # Step 6. Find accuracy on query set
    query_accuracy = accuracy(query_real_class_logits, query_labels)

    return query_discriminator_loss, query_accuracy

Training Loop

import os
checkpoint_dir = './training_checkpoints'
if not os.path.exists(checkpoint_dir):
    os.mkdir(checkpoint_dir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(adaptation_optimizer=adaptation_optimizer,
                                 meta_discriminator_optimizer=meta_discriminator_optimizer,
                                 meta_generator_optimizer=meta_generator_optimizer,
                                 task_encoder=task_encoder,
                                 generator=generator,
                                 discriminator=discriminator)
# To share same weights between primary discriminator and secondary discriminator
# manually initialize their weights by calling the build function
discriminator.build((None, image_size[0], image_size[1], num_of_channels))
duplicate_discriminator.build((None, image_size[0], image_size[1], num_of_channels))

gen_losses = []
disc_losses = []
accuracies = []

test_losses = []
test_accuracies = []


for ep in range(EPOCHS):
    
    gen_loss = []
    disc_loss = []
    disc_accuracy = []
    
    ########################################### Training ################################################
    for task_batch_no, (support_images, support_labels, query_images, query_labels) in enumerate(train_dataset_task_grouped):
        d_loss, g_loss, acc = meta_learn_step(support_images,
                                              support_labels,
                                              query_images,
                                              query_labels)
        disc_loss.append(d_loss)
        gen_loss.append(g_loss)
        disc_accuracy.append(acc)

    disc_loss = tf.reduce_mean(disc_loss)
    gen_loss = tf.reduce_mean(gen_loss)
    disc_accuracy = tf.reduce_mean(disc_accuracy)
        
    disc_losses.append(disc_loss)
    gen_losses.append(gen_loss)
    accuracies.append(disc_accuracy)
    #####################################################################################################
    
    #################################### Evaluation and Logging ################################################
    if ep%10 == 0:  # Every 10 epochs
        test_loss = []
        test_accuracy = []
        
        for task_no, (support_images, support_labels, query_images, query_labels) in enumerate(test_dataset_nway_grouped):
            query_loss, query_accuracy = evaluation(support_images, support_labels, query_images, query_labels)
            
            test_loss.append(query_loss)
            test_accuracy.append(query_accuracy)
        
        test_loss = tf.reduce_mean(test_loss)
        test_accuracy = tf.reduce_mean(test_accuracy)

        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        
        tf.print('Epoch: ', ep,
                 'Train gen loss: ', gen_loss,
                 'Train disc loss: ', disc_loss,
                 'Train accuracy: ', disc_accuracy,
                 'Test loss: ', test_loss,
                 'Test accuracy: ', test_accuracy)
    
    else:
        tf.print('Epoch: ', ep,
                 'Train gen loss: ', gen_loss,
                 'Train disc loss: ', disc_loss,
                 'Train accuracy: ', disc_accuracy)
    #####################################################################################################

checkpoint.save(file_prefix = checkpoint_prefix)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

Evaluation on test dataset

test_loss = []
test_accuracy = []

for task_no, (support_images, support_labels, query_images, query_labels) in enumerate(test_dataset_nway_grouped):
    query_loss, query_accuracy = evaluation(support_images, support_labels, query_images, query_labels)
    
    test_loss.append(query_loss)
    test_accuracy.append(query_accuracy)

test_loss = tf.reduce_mean(test_loss)
test_accuracy = tf.reduce_mean(test_accuracy)

tf.print('Test loss: ', test_loss,
         'Test accuracy: ', test_accuracy)
Test loss:  1.80416536 Test accuracy:  0.628600478

As we said before, we don’t need the generator to be perfect. Let’s see what the generator has learned to generate.

test_iterator = tfds.as_numpy(test_dataset_nway_grouped)

Run the following cells multiple times to check generated images conditioned on different tasks.

test_task = next(test_iterator)
support_images, support_labels, query_images, query_labels = test_task
generated = generator([generate_noise((tf.shape(support_images)[0], noise_dim)), task_encoder(support_images)])

Real task images.

fig = plt.figure(figsize=(5,5))

for i in range(5*5):
    plt.subplot(5, 5, i+1)
    plt.imshow(support_images[i, :, :, -1], cmap='gray')
    plt.axis('off')
plt.show()

(image: 5×5 grid of real support images)

Generated images from generator conditioned on same task.

fig = plt.figure(figsize=(5,5))

for i in range(5*5):
    plt.subplot(5, 5, i+1)
    plt.imshow(generated[i, :, :, -1], cmap='gray')
    plt.axis('off')
plt.show()

(image: 5×5 grid of generated images)

]]>