<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.9.3">Jekyll</generator><link href="https://omoindrot.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://omoindrot.github.io/" rel="alternate" type="text/html" /><updated>2023-11-21T11:41:17+00:00</updated><id>https://omoindrot.github.io/feed.xml</id><title type="html">Olivier Moindrot blog</title><subtitle>My personal page.</subtitle><entry><title type="html">Triplet Loss and Online Triplet Mining in TensorFlow</title><link href="https://omoindrot.github.io/triplet-loss" rel="alternate" type="text/html" title="Triplet Loss and Online Triplet Mining in TensorFlow" /><published>2018-03-19T00:00:00+00:00</published><updated>2018-03-19T00:00:00+00:00</updated><id>https://omoindrot.github.io/triplet-loss</id><content type="html" xml:base="https://omoindrot.github.io/triplet-loss">&lt;p&gt;In face recognition, triplet loss is used to learn good embeddings (or “encodings”) of faces.
If you are not familiar with triplet loss, you should first learn about it by watching this &lt;a href=&quot;https://www.coursera.org/learn/convolutional-neural-networks/lecture/HuUtN/triplet-loss&quot;&gt;Coursera video&lt;/a&gt; from Andrew Ng’s deep learning specialization.&lt;/p&gt;

&lt;p&gt;Triplet loss is known to be difficult to implement, especially if you add the constraints of building a computational graph in TensorFlow.&lt;/p&gt;

&lt;p&gt;In this post, I will define the triplet loss and the different strategies to sample triplets.
I will then explain how to correctly implement triplet loss with online triplet mining in TensorFlow.&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;About two years ago, I was working on face recognition during my internship at &lt;a href=&quot;http://www.reminiz.com&quot;&gt;Reminiz&lt;/a&gt;, and I answered a &lt;a href=&quot;https://stackoverflow.com/a/38270293/5098368&quot;&gt;question&lt;/a&gt; on Stack Overflow about implementing triplet loss in TensorFlow. I concluded by saying:&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;Clearly, implementing triplet loss in Tensorflow is hard, and there are ways to make it more efficient than sampling in python but explaining them would require a whole blog post !&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Two years later, here we go.&lt;/p&gt;

&lt;p&gt;&lt;em&gt;All the code can be found in this &lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss&quot;&gt;GitHub repository&lt;/a&gt;.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Table of contents&lt;/strong&gt;&lt;/p&gt;

&lt;ul id=&quot;markdown-toc&quot;&gt;
  &lt;li&gt;&lt;a href=&quot;#triplet-loss-and-triplet-mining&quot; id=&quot;markdown-toc-triplet-loss-and-triplet-mining&quot;&gt;Triplet loss and triplet mining&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#why-not-just-use-softmax&quot; id=&quot;markdown-toc-why-not-just-use-softmax&quot;&gt;Why not just use softmax?&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#definition-of-the-loss&quot; id=&quot;markdown-toc-definition-of-the-loss&quot;&gt;Definition of the loss&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#triplet-mining&quot; id=&quot;markdown-toc-triplet-mining&quot;&gt;Triplet mining&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#offline-and-online-triplet-mining&quot; id=&quot;markdown-toc-offline-and-online-triplet-mining&quot;&gt;Offline and online triplet mining&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#strategies-in-online-mining&quot; id=&quot;markdown-toc-strategies-in-online-mining&quot;&gt;Strategies in online mining&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#a-naive-implementation-of-triplet-loss&quot; id=&quot;markdown-toc-a-naive-implementation-of-triplet-loss&quot;&gt;A naive implementation of triplet loss&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#a-better-implementation-with-online-triplet-mining&quot; id=&quot;markdown-toc-a-better-implementation-with-online-triplet-mining&quot;&gt;A better implementation with online triplet mining&lt;/a&gt;    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#compute-the-distance-matrix&quot; id=&quot;markdown-toc-compute-the-distance-matrix&quot;&gt;Compute the distance matrix&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#batch-all-strategy&quot; id=&quot;markdown-toc-batch-all-strategy&quot;&gt;Batch all strategy&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#batch-hard-strategy&quot; id=&quot;markdown-toc-batch-hard-strategy&quot;&gt;Batch hard strategy&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#testing-our-implementation&quot; id=&quot;markdown-toc-testing-our-implementation&quot;&gt;Testing our implementation&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#experience-with-mnist&quot; id=&quot;markdown-toc-experience-with-mnist&quot;&gt;Experience with MNIST&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#conclusion&quot; id=&quot;markdown-toc-conclusion&quot;&gt;Conclusion&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#resources&quot; id=&quot;markdown-toc-resources&quot;&gt;Resources&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;triplet-loss-and-triplet-mining&quot;&gt;Triplet loss and triplet mining&lt;/h2&gt;

&lt;h3 id=&quot;why-not-just-use-softmax&quot;&gt;Why not just use softmax?&lt;/h3&gt;

&lt;p&gt;The triplet loss for face recognition was introduced by the paper &lt;a href=&quot;https://arxiv.org/abs/1503.03832&quot;&gt;&lt;em&gt;FaceNet: A Unified Embedding for Face Recognition and Clustering&lt;/em&gt;&lt;/a&gt; from Google.
They describe a new approach to train face embeddings using online triplet mining, which will be discussed in the &lt;a href=&quot;#offline-and-online-triplet-mining&quot;&gt;next section&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;Usually in supervised learning we have a fixed number of classes and train the network using the softmax cross-entropy loss.
However, in some cases we need to handle a variable number of classes.
In face recognition, for instance, we need to be able to compare two unknown faces and say whether they belong to the same person or not.&lt;/p&gt;

&lt;p&gt;Triplet loss in this case is a way to learn good embeddings for each face. In the embedding space, faces from the same person should be close together and form well-separated clusters.&lt;/p&gt;

&lt;h3 id=&quot;definition-of-the-loss&quot;&gt;Definition of the loss&lt;/h3&gt;

&lt;p&gt;&lt;img src=&quot;assets/triplet_loss/triplet_loss.png&quot; alt=&quot;triplet-loss-img&quot; /&gt;&lt;/p&gt;
&lt;center&gt;&lt;i&gt;Triplet loss on two positive faces (Obama) and one negative face (Macron)&lt;/i&gt;&lt;/center&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;The goal of the triplet loss is to make sure that:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;Two examples with the same label have their embeddings close together in the embedding space&lt;/li&gt;
  &lt;li&gt;Two examples with different labels have their embeddings far away.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;However, we don’t want to push the training embeddings of each label to collapse into very small clusters.
The only requirement is that given two positive examples of the same class and one negative example, the negative should be farther away than the positive by some margin.
This is very similar to the margin used in SVMs, and here we want the clusters of each class to be separated by the margin.&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;
To formalise this requirement, the loss will be defined over &lt;strong&gt;triplets&lt;/strong&gt; of embeddings:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;an &lt;strong&gt;anchor&lt;/strong&gt;&lt;/li&gt;
  &lt;li&gt;a &lt;strong&gt;positive&lt;/strong&gt; of the same class as the anchor&lt;/li&gt;
  &lt;li&gt;a &lt;strong&gt;negative&lt;/strong&gt; of a different class&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;For some distance $d$ on the embedding space, the loss of a triplet $(a, p, n)$ is:&lt;/p&gt;

\[\mathcal{L} = \max(d(a, p) - d(a, n) + margin, 0)\]

&lt;p&gt;We minimize this loss, which pushes $d(a, p)$ to $0$ and $d(a, n)$ to be greater than $d(a, p) + margin$. As soon as $n$ becomes an “easy negative”, the loss becomes zero.&lt;/p&gt;
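
&lt;p&gt;To make the formula concrete, here is a small NumPy sketch that evaluates the loss of a single triplet, taking the Euclidean distance as $d$. NumPy and the helper name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;triplet_loss_single&lt;/code&gt; are illustrative choices, not code from the repository. With $d(a, p) = 1$ and $d(a, n) = 1.5$, a margin of $1$ gives a loss of $0.5$, while with a margin of $0.2$ the negative is already an easy negative and the loss is $0$.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np


def triplet_loss_single(a, p, n, margin):
    # Euclidean distances d(a, p) and d(a, n)
    d_ap = np.linalg.norm(a - p)
    d_an = np.linalg.norm(a - n)
    # max(d(a, p) - d(a, n) + margin, 0)
    return max(d_ap - d_an + margin, 0.0)


a = np.array([0.0, 0.0])
p = np.array([1.0, 0.0])  # d(a, p) = 1.0
n = np.array([0.0, 1.5])  # d(a, n) = 1.5

print(triplet_loss_single(a, p, n, margin=1.0))  # 0.5
print(triplet_loss_single(a, p, n, margin=0.2))  # 0.0
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;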

&lt;h3 id=&quot;triplet-mining&quot;&gt;Triplet mining&lt;/h3&gt;

&lt;p&gt;Based on the definition of the loss, there are three categories of triplets:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;easy triplets&lt;/strong&gt;: triplets which have a loss of $0$, because $d(a, p) + margin &amp;lt; d(a,n)$&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;hard triplets&lt;/strong&gt;: triplets where the negative is closer to the anchor than the positive, i.e. $d(a,n) &amp;lt; d(a,p)$&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;semi-hard triplets&lt;/strong&gt;: triplets where the negative is not closer to the anchor than the positive, but which still have positive loss: $d(a, p) &amp;lt; d(a, n) &amp;lt; d(a, p) + margin$&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Each of these definitions depends on where the negative is, relative to the anchor and positive. We can therefore extend these three categories to the negatives: &lt;strong&gt;hard negatives&lt;/strong&gt;, &lt;strong&gt;semi-hard negatives&lt;/strong&gt; or &lt;strong&gt;easy negatives&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;The figure below shows the three corresponding regions of the embedding space for the negative.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;assets/triplet_loss/triplets.png&quot; alt=&quot;triplet-types-img&quot; /&gt;&lt;/p&gt;
&lt;center&gt;&lt;i&gt;The three types of negatives, given an anchor and a positive&lt;/i&gt;&lt;/center&gt;
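
&lt;p&gt;The three regions can be read directly from the two distances. The NumPy sketch below (the function name is hypothetical) labels each candidate negative given $d(a, p)$ and the margin. The handling of the boundaries is our own assumption: a negative with $d(a, n) = d(a, p)$ is counted as semi-hard, and one with $d(a, n) = d(a, p) + margin$, whose loss is exactly $0$, is counted as easy.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np


def negative_types(d_ap, d_an, margin):
    # d_ap: distance d(a, p); d_an: distances d(a, n) of the candidate negatives
    d_an = np.asarray(d_an, dtype=float)
    hard = np.less(d_an, d_ap)                    # negative closer than the positive
    easy = np.greater_equal(d_an, d_ap + margin)  # loss is 0
    # Everything else lies in [d(a, p), d(a, p) + margin): semi-hard
    return np.where(hard, 'hard', np.where(easy, 'easy', 'semi-hard'))


print(negative_types(1.0, [0.5, 1.2, 2.5], margin=0.5))
# ['hard' 'semi-hard' 'easy']
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;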

&lt;p&gt;&lt;br /&gt;
Choosing what kind of triplets we train on will greatly impact our metrics.
In the original FaceNet &lt;a href=&quot;https://arxiv.org/abs/1503.03832&quot;&gt;paper&lt;/a&gt;, they pick a random semi-hard negative for every pair of anchor and positive, and train on these triplets.&lt;/p&gt;

&lt;h3 id=&quot;offline-and-online-triplet-mining&quot;&gt;Offline and online triplet mining&lt;/h3&gt;

&lt;p&gt;We have defined a loss on triplets of embeddings, and have seen that some triplets are more useful than others. The question now is how to sample, or “mine”, these triplets.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Offline triplet mining&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The first way to produce triplets is to find them offline, at the beginning of each epoch for instance.
We compute all the embeddings on the training set, and then only select hard or semi-hard triplets.
We can then train one epoch on these triplets.&lt;/p&gt;

&lt;p&gt;Concretely, we would produce a list of triplets $(i, j, k)$.
We would then create batches of these triplets of size $B$, which means we will have to compute $3B$ embeddings to get the $B$ triplets, compute the loss of these $B$ triplets and then backpropagate into the network.&lt;/p&gt;
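
&lt;p&gt;As an illustration of the offline selection step, the sketch below assumes that the embeddings of the whole training set fit in memory as a NumPy array, and keeps every valid triplet $(i, j, k)$ with a positive loss, i.e. the hard and semi-hard ones. The function name is hypothetical. Building the full $N \times N \times N$ loss tensor is only reasonable for a small $N$; a realistic offline miner would likely need to process the data in chunks.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np


def mine_triplets_offline(embeddings, labels, margin):
    # embeddings: (N, D) array computed on the training set; labels: (N,) labels
    labels = np.asarray(labels)
    # Pairwise Euclidean distances, shape (N, N)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)

    same = labels[:, None] == labels[None, :]
    positive = np.logical_and(same, np.logical_not(np.eye(len(labels), dtype=bool)))
    negative = np.logical_not(same)
    # valid[i, j, k]: j is a positive and k a negative for anchor i
    valid = np.logical_and(positive[:, :, None], negative[:, None, :])

    # loss[i, j, k] = d(i, j) - d(i, k) + margin
    loss = dist[:, :, None] - dist[:, None, :] + margin
    # Keep hard and semi-hard triplets, i.e. valid triplets with positive loss
    keep = np.logical_and(valid, np.greater(loss, 0.0))
    return [tuple(int(x) for x in t) for t in np.argwhere(keep)]


emb = np.array([[0.0], [1.0], [0.5]])
print(mine_triplets_offline(emb, [0, 0, 1], margin=0.1))  # [(0, 1, 2), (1, 0, 2)]
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;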

&lt;p&gt;Overall, this technique is not very efficient since we need to do a full pass on the training set to generate triplets.
It also requires updating the offline-mined triplets regularly.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Online triplet mining&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Online triplet mining was introduced in &lt;em&gt;FaceNet&lt;/em&gt; and has been well described by Brandon Amos in his blog post &lt;a href=&quot;http://bamos.github.io/2016/01/19/openface-0.2.0/&quot;&gt;&lt;em&gt;OpenFace 0.2.0: Higher accuracy and halved execution time&lt;/em&gt;&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;The idea here is to compute useful triplets on the fly, for each batch of inputs.
Given a batch of $B$ examples (for instance $B$ images of faces), we compute the $B$ embeddings and can then form up to $B^3$ triplets.
Of course, most of these triplets are not &lt;strong&gt;valid&lt;/strong&gt; (i.e. they don’t consist of two distinct examples with the same label and one example with a different label).&lt;/p&gt;

&lt;p&gt;This technique gives you more triplets for a single batch of inputs, and doesn’t require any offline mining. It is therefore much more efficient. We will see an implementation of this in the last &lt;a href=&quot;#a-better-implementation-with-online-triplet-mining&quot;&gt;part&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;assets/triplet_loss/online_triplet_loss.png&quot; alt=&quot;online-triplet-loss-img&quot; /&gt;&lt;/p&gt;
&lt;center&gt;&lt;i&gt;Triplet loss with online mining: triplets are computed on the fly from a batch of embeddings&lt;/i&gt;&lt;/center&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;h3 id=&quot;strategies-in-online-mining&quot;&gt;Strategies in online mining&lt;/h3&gt;

&lt;p&gt;In online mining, we have computed a batch of $B$ embeddings from a batch of $B$ inputs.
Now we want to generate triplets from these $B$ embeddings.&lt;/p&gt;

&lt;p&gt;Whenever we have three indices $i, j, k \in [1, B]$, if examples $i$ and $j$ have the same label but are distinct, and example $k$ has a different label, we say that &lt;strong&gt;$(i, j, k)$ is a valid triplet&lt;/strong&gt;.
What remains is a good strategy for choosing, among the valid triplets, those on which to compute the loss.&lt;/p&gt;
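
&lt;p&gt;The definition translates into a simple check. The brute-force sketch below (a hypothetical helper; indices are 0-based in Python rather than in $[1, B]$) enumerates all $B^3$ index triples of a small batch and keeps the valid ones, which shows how few of them qualify.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;from itertools import product


def is_valid_triplet(labels, i, j, k):
    # i, j: distinct examples with the same label; k: an example with a different label
    return i != j and labels[i] == labels[j] and labels[k] != labels[i]


labels = [0, 0, 1, 1]
B = len(labels)
valid = [t for t in product(range(B), repeat=3) if is_valid_triplet(labels, *t)]
print(len(valid), B ** 3)  # 8 64
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;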

&lt;p&gt;A detailed explanation of two of these strategies can be found in section 2 of the paper &lt;a href=&quot;https://arxiv.org/abs/1703.07737&quot;&gt;&lt;em&gt;In Defense of the Triplet Loss for Person Re-Identification&lt;/em&gt;&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;They assume a batch of faces of size $B = PK$, composed of $P$ different persons with $K$ images each.
A typical value is $K = 4$.
The two strategies are:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;batch all&lt;/strong&gt;: select all the valid triplets, and average the loss on the hard and semi-hard triplets.
    &lt;ul&gt;
      &lt;li&gt;a crucial point here is to not take into account the easy triplets (those with loss $0$), as averaging on them would make the overall loss very small&lt;/li&gt;
      &lt;li&gt;this produces a total of $PK(K-1)(PK-K)$ triplets ($PK$ anchors, $K-1$ possible positives per anchor, $PK-K$ possible negatives)&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;batch hard&lt;/strong&gt;: for each anchor, select the hardest positive (biggest distance $d(a, p)$) and the hardest negative (smallest distance $d(a, n)$) among the batch
    &lt;ul&gt;
      &lt;li&gt;this produces $PK$ triplets&lt;/li&gt;
      &lt;li&gt;the selected triplets are the hardest among the batch&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ul&gt;
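
&lt;p&gt;To see what batch hard selects, here is a NumPy sketch operating on a precomputed $(B, B)$ distance matrix. It is an illustration with our own naming, not the TensorFlow implementation from the repository: for each anchor it takes the largest distance among the positives and the smallest among the negatives, which yields one triplet, and thus one loss value, per anchor.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np


def batch_hard_losses(labels, dist, margin):
    # labels: (B,) labels; dist: (B, B) pairwise distance matrix
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    positive = np.logical_and(same, np.logical_not(np.eye(len(labels), dtype=bool)))
    negative = np.logical_not(same)

    # Hardest positive: largest d(a, p); masked-out entries can never be selected
    hardest_pos = np.where(positive, dist, -np.inf).max(axis=1)
    # Hardest negative: smallest d(a, n)
    hardest_neg = np.where(negative, dist, np.inf).min(axis=1)
    return np.maximum(hardest_pos - hardest_neg + margin, 0.0)


# P = 2 persons, K = 2 images each, with 1D embeddings for readability
emb = np.array([0.0, 1.0, 1.5, 3.0])
dist = np.abs(emb[:, None] - emb[None, :])
print(batch_hard_losses([0, 0, 1, 1], dist, margin=0.5))  # one loss per anchor
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;If an anchor had no positive in the batch, its hardest positive would be $-\infty$ and its loss would be clamped to $0$; sampling $P$ persons with $K \geq 2$ images each avoids this case.&lt;/p&gt;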

&lt;p&gt;According to the &lt;a href=&quot;https://arxiv.org/abs/1703.07737&quot;&gt;paper&lt;/a&gt; cited above, the batch hard strategy yields the best performance:&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;Additionally, the selected triplets can be considered moderate triplets, since they are the hardest within a small subset of the data, which is exactly what is best for learning with the triplet loss.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;However, the best strategy really depends on your dataset and should be chosen by comparing performance on a dev set.&lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;a-naive-implementation-of-triplet-loss&quot;&gt;A naive implementation of triplet loss&lt;/h2&gt;

&lt;p&gt;In the &lt;a href=&quot;https://stackoverflow.com/a/38270293/5098368&quot;&gt;Stack Overflow answer&lt;/a&gt;, I gave a simple implementation of triplet loss for offline triplet mining:&lt;/p&gt;

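&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tf&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.2&lt;/span&gt;            &lt;span class=&quot;c1&quot;&gt;# margin hyperparameter (example value)
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;anchor_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# shape [None, 128]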
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;positive_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape [None, 128]
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;negative_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# shape [None, 128]
&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;d_pos&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;square&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;anchor_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;positive_output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;d_neg&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;square&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;anchor_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;negative_output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d_pos&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d_neg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The network is replicated three times (with shared weights) to produce the embeddings of $B$ anchors, $B$ positives and $B$ negatives.
We then simply compute the triplet loss on these embeddings.&lt;/p&gt;

&lt;p&gt;This is an easy implementation, but also a very inefficient one because it uses offline triplet mining.&lt;/p&gt;
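
&lt;p&gt;The snippet above elides the network itself. To illustrate the shared-weights setup, here is a NumPy sketch in which a single random matrix &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;W&lt;/code&gt; stands in for the network (a toy assumption) and is applied to $B$ anchors, $B$ positives and $B$ negatives, so that $3B$ embeddings are computed to get $B$ triplets. The inputs are random placeholders and the margin is an arbitrary example value; the loss follows the same squared-distance formula as the TensorFlow code.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

rng = np.random.default_rng(0)
B, input_dim, embed_dim = 4, 10, 128
margin = 0.2  # example value

# A toy linear 'network': one weight matrix shared by the three branches
W = rng.normal(size=(input_dim, embed_dim))


def embed(x):
    return x @ W


anchors, positives, negatives = (rng.normal(size=(B, input_dim)) for _ in range(3))
anchor_output = embed(anchors)      # shape (B, 128)
positive_output = embed(positives)  # shape (B, 128)
negative_output = embed(negatives)  # shape (B, 128)

# Squared Euclidean distances, as in the TensorFlow snippet
d_pos = np.sum(np.square(anchor_output - positive_output), axis=1)
d_neg = np.sum(np.square(anchor_output - negative_output), axis=1)

loss = np.mean(np.maximum(0.0, margin + d_pos - d_neg))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;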

&lt;hr /&gt;

&lt;h2 id=&quot;a-better-implementation-with-online-triplet-mining&quot;&gt;A better implementation with online triplet mining&lt;/h2&gt;

&lt;p&gt;All the relevant code is available on github in &lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/model/triplet_loss.py&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model/triplet_loss.py&lt;/code&gt;&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;em&gt;There is an existing implementation of triplet loss with semi-hard online mining in TensorFlow: &lt;a href=&quot;https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tf.contrib.losses.metric_learning.triplet_semihard_loss&lt;/code&gt;&lt;/a&gt;.
Here we will not follow this implementation and start from scratch.&lt;/em&gt;&lt;/p&gt;

&lt;h3 id=&quot;compute-the-distance-matrix&quot;&gt;Compute the distance matrix&lt;/h3&gt;

&lt;p&gt;As the final triplet loss depends on the distances $d(a, p)$ and $d(a, n)$, we first need to &lt;em&gt;efficiently&lt;/em&gt; compute the pairwise distance matrix.
We implement this for both the Euclidean distance and the squared Euclidean distance in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_pairwise_distances&lt;/code&gt; function:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_pairwise_distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;Compute the 2D matrix of distances between all the embeddings.

    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.

    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Get the dot product between all embeddings
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# shape (batch_size, batch_size)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;dot_product&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# This also provides more numerical stability (the diagonal of the result will be exactly 0).
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# shape (batch_size,)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;square_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;diag_part&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot_product&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Compute the pairwise distance matrix as we have:
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# ||a - b||^2 = ||a||^2  - 2 &amp;lt;a, b&amp;gt; + ||b||^2
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# shape (batch_size, batch_size)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;square_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;2.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dot_product&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;square_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Because of computation errors, some distances might be negative so we put everything &amp;gt;= 0.0
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# we need to add a small epsilon where distances == 0.0
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;equal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-16&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Correct the epsilon added: set the distances on the mask to be exactly 0.0
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;distances&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;To explain the code in more detail: we first compute the dot product between all pairs of embeddings, which has shape $(B, B)$.
The squared Euclidean norm of each embedding is contained in the diagonal of this matrix, so we extract it with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tf.diag_part&lt;/code&gt;.
Finally, we compute the distances using the formula:&lt;/p&gt;

\[\Vert a - b \Vert ^2 = \Vert a \Vert^2 - 2 \langle a, b \rangle + \Vert b \Vert ^2\]

&lt;p&gt;One tricky thing is that if &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;squared=False&lt;/code&gt;, we take the square root of the distance matrix.
First, we have to ensure that the distance matrix is non-negative.
Some values could be negative because of small inaccuracies in computation.
We just make sure that every negative value gets set to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;The second thing to take care of is the derivative of the square root, which is infinite at $0$: any element that is exactly &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt; (the diagonal should always be &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt;, for instance) would give a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nan&lt;/code&gt; gradient.
To handle this case, we add a small &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;epsilon = 1e-16&lt;/code&gt; to the values equal to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt;.
We then take the square root, and replace the resulting values $\sqrt{\epsilon}$ with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt;.&lt;/p&gt;
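
&lt;p&gt;One way to convince yourself that the expansion is correct is to compare it with a direct computation. The NumPy sketch below mirrors the steps of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_pairwise_distances&lt;/code&gt; (dot product, diagonal, expansion, clamping at &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt;) under a hypothetical name. NumPy computes no gradients, so the epsilon trick is not needed here and is left out. Because the norms are taken from the diagonal of the same matrix, the diagonal of the result comes out as exactly &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;0.0&lt;/code&gt;.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np


def pairwise_distances_np(embeddings, squared=False):
    # Dot products between all embeddings, shape (B, B)
    dot_product = embeddings @ embeddings.T
    # Squared L2 norm of each embedding: the diagonal, shape (B,)
    square_norm = np.diag(dot_product)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, broadcast to shape (B, B)
    distances = square_norm[None, :] - 2.0 * dot_product + square_norm[:, None]
    # Rounding errors can leave small negative values: clamp them to 0.0
    distances = np.maximum(distances, 0.0)
    if not squared:
        distances = np.sqrt(distances)
    return distances


rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 3))
# Direct computation of ||e_i - e_j|| for comparison
direct = np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)
print(np.allclose(pairwise_distances_np(emb), direct))  # True
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;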

&lt;h3 id=&quot;batch-all-strategy&quot;&gt;Batch all strategy&lt;/h3&gt;

&lt;p&gt;In this strategy, we want to compute the triplet loss on almost all triplets.
In the TensorFlow graph, we want to create a 3D tensor of shape $(B, B, B)$ where the element at index $(i, j, k)$ contains the loss for triplet $(i, j, k)$.&lt;/p&gt;
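
&lt;p&gt;The 3D tensor comes from broadcasting the $(B, B)$ distance matrix against itself, as the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tf.expand_dims&lt;/code&gt; calls in the code below do. Here is the same idea in NumPy on a toy distance matrix whose values are chosen arbitrarily; the clamping at $0$ is left for later.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np

margin = 0.2
# Toy symmetric distance matrix for B = 3 embeddings
dist = np.array([[0.0, 1.0, 2.0],
                 [1.0, 0.0, 1.5],
                 [2.0, 1.5, 0.0]])

anchor_positive_dist = dist[:, :, None]  # shape (B, B, 1): d(i, j)
anchor_negative_dist = dist[:, None, :]  # shape (B, 1, B): d(i, k)
# Broadcasting gives triplet_loss[i, j, k] = d(i, j) - d(i, k) + margin
triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

print(triplet_loss.shape)     # (3, 3, 3)
print(triplet_loss[0, 1, 2])  # d(0, 1) - d(0, 2) + margin
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;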

&lt;p&gt;We then get a 3D mask of the valid triplets with function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_get_triplet_mask&lt;/code&gt;.
Here, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;mask[i, j, k]&lt;/code&gt; is true iff $(i, j, k)$ is a valid triplet.&lt;/p&gt;

&lt;p&gt;Finally, we set the loss of the invalid triplets to $0$ and take the average over the triplets with a positive loss.&lt;/p&gt;
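
&lt;p&gt;The masking and averaging steps can be sketched in NumPy as follows. This is not the repository’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_get_triplet_mask&lt;/code&gt;; it is a hypothetical re-implementation of the same validity rule (distinct $i$ and $j$ with the same label, $k$ with a different label), followed by a mean taken only over the triplets whose loss is still positive. The small constant in the denominator is our own guard against division by zero when every triplet is easy.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;import numpy as np


def batch_all_loss_np(labels, dist, margin):
    # labels: (B,) labels; dist: (B, B) pairwise distance matrix
    labels = np.asarray(labels)
    same = labels[:, None] == labels[None, :]
    distinct = np.logical_not(np.eye(len(labels), dtype=bool))

    # mask[i, j, k] is True iff (i, j, k) is a valid triplet
    mask = np.logical_and(
        np.logical_and(same, distinct)[:, :, None],  # j: positive of anchor i
        np.logical_not(same)[:, None, :],            # k: negative of anchor i
    )

    loss = dist[:, :, None] - dist[:, None, :] + margin
    loss = np.maximum(loss * mask, 0.0)  # invalid and easy triplets get loss 0
    num_positive = np.count_nonzero(loss)
    return loss.sum() / (num_positive + 1e-16), num_positive


emb = np.array([0.0, 1.0, 1.5, 3.0])
dist = np.abs(emb[:, None] - emb[None, :])
mean_loss, num_positive = batch_all_loss_np([0, 0, 1, 1], dist, margin=0.5)
print(num_positive, mean_loss)  # 3 1.0
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Here 8 triplets are valid but only 3 have a positive loss; dividing by 3 rather than 8 is what keeps the easy triplets from diluting the loss.&lt;/p&gt;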

&lt;p&gt;Everything is implemented in function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch_all_triplet_loss&lt;/code&gt;:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;batch_all_triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;Build the triplet loss over a batch of embeddings.

    We generate all the valid triplets and average the loss over the positive ones.

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, use the pairwise squared euclidean distance matrix.
                 If false, use the pairwise euclidean distance matrix.

    Returns:
        triplet_loss: scalar tensor containing the triplet loss
        fraction_positive_triplets: scalar tensor, fraction of valid triplets with a positive loss
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Get the pairwise distance matrix
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_pairwise_distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;anchor_positive_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;anchor_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# and the 2nd (batch_size, 1, batch_size)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;anchor_positive_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;anchor_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Put to zero the invalid triplets
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# (where label(a) != label(p) or label(n) == label(a) or a == p)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_get_triplet_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;multiply&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Remove negative losses (i.e. the easy triplets)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Count number of positive triplets (where triplet_loss &amp;gt; 0)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;valid_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;greater&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_positive_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;valid_triplets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_valid_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;fraction_positive_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_positive_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_valid_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Get final mean triplet loss over the positive valid triplets
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_positive_triplets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fraction_positive_triplets&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The implementation of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_get_triplet_mask&lt;/code&gt; is straightforward, so I will not detail it.&lt;/p&gt;
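
&lt;p&gt;For reference, here is one possible NumPy sketch of such a mask. The function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;get_triplet_mask&lt;/code&gt; is a hypothetical NumPy counterpart of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_get_triplet_mask&lt;/code&gt;, not the code from the repository. It combines a “distinct indices” condition with a “valid labels” condition through broadcasting.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Sketch of a triplet mask in NumPy (hypothetical, not the repository code).
# mask[i, j, k] is True iff i, j, k are distinct, labels[i] == labels[j]
# and labels[i] != labels[k].
import numpy as np


def get_triplet_mask(labels):
    labels = np.asarray(labels)
    n = labels.shape[0]
    not_equal = np.logical_not(np.eye(n, dtype=bool))  # not_equal[i, j]: i != j

    # Distinct indices: i != j, i != k and j != k
    distinct_indices = np.logical_and(
        np.logical_and(not_equal[:, :, None], not_equal[:, None, :]),
        not_equal[None, :, :],
    )

    label_equal = np.equal(labels[:, None], labels[None, :])  # shape (n, n)
    # Valid labels: labels[i] == labels[j] and labels[i] != labels[k]
    valid_labels = np.logical_and(
        label_equal[:, :, None], np.logical_not(label_equal[:, None, :])
    )
    return np.logical_and(distinct_indices, valid_labels)


labels = [0, 0, 1, 1]
mask = get_triplet_mask(labels)
print(int(mask.sum()))  # 8
print(mask[0, 1, 2], mask[0, 2, 1])  # True False
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;With labels &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[0, 0, 1, 1]&lt;/code&gt;, each anchor has one positive and two negatives, hence $4 \times 1 \times 2 = 8$ valid triplets.&lt;/p&gt;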

&lt;h3 id=&quot;batch-hard-strategy&quot;&gt;Batch hard strategy&lt;/h3&gt;

&lt;p&gt;In this strategy, we want to find the hardest positive and negative for each anchor.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Hardest positive&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;To compute the hardest positive, we begin with the pairwise distance matrix.
We then build a 2D mask of the valid pairs $(a, p)$ (i.e. $a \neq p$ and $a$ and $p$ have the same label) and set every element outside the mask to $0$.&lt;/p&gt;

&lt;p&gt;The last step is to take the maximum distance over each row of this modified distance matrix. Since distances are non-negative and invalid elements are set to $0$, the maximum is the distance to a valid positive, as long as the anchor has at least one valid positive in the batch.&lt;/p&gt;
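
&lt;p&gt;Here is a NumPy sketch of this step on a hand-made example. The labels and the symmetric distance matrix are made up for illustration.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Sketch of the hardest-positive step in NumPy, on a hand-made distance matrix.
import numpy as np

labels = np.array([0, 0, 0, 1])
pairwise_dist = np.array([
    [0.0, 1.0, 3.0, 2.0],
    [1.0, 0.0, 2.5, 4.0],
    [3.0, 2.5, 0.0, 0.5],
    [2.0, 4.0, 0.5, 0.0],
])

# Valid (a, p) pairs: a != p and same label
same_label = np.equal(labels[:, None], labels[None, :])
mask_anchor_positive = np.logical_and(same_label, np.logical_not(np.eye(4, dtype=bool)))

# Set the invalid pairs to 0, then take the maximum over each row
anchor_positive_dist = mask_anchor_positive * pairwise_dist
hardest_positive_dist = anchor_positive_dist.max(axis=1, keepdims=True)  # shape (4, 1)
print(hardest_positive_dist.ravel().tolist())  # [3.0, 2.5, 3.0, 0.0]
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Anchor 3 is the only sample with label 1, so its whole row is masked and its hardest positive distance is $0$.&lt;/p&gt;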

&lt;p&gt;&lt;strong&gt;Hardest negative&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The hardest negative is similar but a bit trickier to compute.
Here we need the minimum distance over each row, so we cannot set the invalid pairs $(a, n)$ to $0$ (a pair is invalid if $a$ and $n$ have the same label): these zeros would always be picked as the minimum.&lt;/p&gt;

&lt;p&gt;Our trick is to add the maximum value of each row to the invalid pairs $(a, n)$ of that row.
We then take the minimum over each row.
Since every invalid element is now at least as large as the row maximum, the minimum is the distance to a valid negative, as long as the anchor has at least one valid negative in the batch.&lt;/p&gt;
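
&lt;p&gt;Here is the same trick as a NumPy sketch on a made-up distance matrix. It is checked against a reference that simply replaces the invalid pairs with $+\infty$.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Sketch of the hardest-negative trick in NumPy, on a hand-made distance matrix.
import numpy as np

labels = np.array([0, 0, 1, 1])
pairwise_dist = np.array([
    [0.0, 1.0, 3.0, 2.0],
    [1.0, 0.0, 2.5, 4.0],
    [3.0, 2.5, 0.0, 0.5],
    [2.0, 4.0, 0.5, 0.0],
])

# Valid (a, n) pairs: different labels
mask_anchor_negative = np.not_equal(labels[:, None], labels[None, :]).astype(float)

# Add each row maximum to the invalid pairs, then take the row minimum
max_anchor_negative_dist = pairwise_dist.max(axis=1, keepdims=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
hardest_negative_dist = anchor_negative_dist.min(axis=1, keepdims=True)  # shape (4, 1)

# Reference: ignore the invalid pairs by replacing them with +inf
reference = np.where(mask_anchor_negative == 1.0, pairwise_dist, np.inf).min(axis=1, keepdims=True)
print(hardest_negative_dist.ravel().tolist())  # [2.0, 2.5, 2.5, 2.0]
print(np.array_equal(hardest_negative_dist, reference))  # True
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Both approaches agree here. The row-maximum trick has the advantage of keeping every value in the matrix finite.&lt;/p&gt;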

&lt;p&gt;&lt;br /&gt;
The final step is to combine these into the triplet loss:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hardest_positive_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hardest_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Everything is implemented in function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch_hard_triplet_loss&lt;/code&gt;:&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;batch_hard_triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;Build the triplet loss over a batch of embeddings.

    For each anchor, we get the hardest positive and hardest negative to form a triplet.

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, use the pairwise squared euclidean distance matrix.
                 If false, use the pairwise euclidean distance matrix.

    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Get the pairwise distance matrix
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_pairwise_distances&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# For each anchor, get the hardest positive
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# First, we need to get a mask for every valid positive (they should have same label)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mask_anchor_positive&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_get_anchor_positive_triplet_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask_anchor_positive&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask_anchor_positive&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;anchor_positive_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;multiply&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask_anchor_positive&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# shape (batch_size, 1)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;hardest_positive_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;anchor_positive_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keepdims&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# For each anchor, get the hardest negative
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# First, we need to get a mask for every valid negative (they should have different labels)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mask_anchor_negative&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_get_anchor_negative_triplet_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mask_anchor_negative&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask_anchor_negative&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;max_anchor_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keepdims&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;anchor_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pairwise_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_anchor_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask_anchor_negative&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# shape (batch_size, 1)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;hardest_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;anchor_negative_dist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keepdims&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maximum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hardest_positive_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hardest_negative_dist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;margin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Get final mean triplet loss
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;triplet_loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;testing-our-implementation&quot;&gt;Testing our implementation&lt;/h3&gt;

&lt;p&gt;If you don’t trust that the implementation above works as expected, then you’re right!
The best way to catch bugs in the implementation is to write tests for every function in &lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/model/triplet_loss.py&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model/triplet_loss.py&lt;/code&gt;&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;This is especially important for tricky functions like these, which are difficult to implement in TensorFlow but much easier to write in plain Python, for instance with three nested for loops.
The tests are written in &lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/model/tests/test_triplet_loss.py&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model/tests/test_triplet_loss.py&lt;/code&gt;&lt;/a&gt;, and compare the result of our TensorFlow implementation with the results of a simple numpy implementation.&lt;/p&gt;
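
&lt;p&gt;To illustrate the idea, here is a self-contained sketch in which both sides are written in NumPy (in the repository, one side is the TensorFlow implementation). The names &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch_all_loop&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch_all_vectorized&lt;/code&gt; are hypothetical, and the random labels and embeddings are only test data.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Sketch of the testing idea: compare a vectorized batch-all loss with a
# plain triple loop over (anchor, positive, negative). Both sides use NumPy
# here; the function names and the random data are illustrative.
import numpy as np


def pairwise_distances(embeddings):
    # Euclidean distances computed directly from the differences
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))


def batch_all_loop(labels, embeddings, margin):
    dist = pairwise_distances(embeddings)
    n = len(labels)
    total, num_positive = 0.0, 0
    for a in range(n):
        for p in range(n):
            for neg in range(n):
                # Valid triplet: a != p, label(a) == label(p), label(a) != label(neg)
                if a == p or labels[a] != labels[p] or labels[a] == labels[neg]:
                    continue
                loss = max(dist[a, p] - dist[a, neg] + margin, 0.0)
                total += loss
                num_positive += int(np.greater(loss, 1e-16))
    return total / (num_positive + 1e-16)


def batch_all_vectorized(labels, embeddings, margin):
    labels = np.asarray(labels)
    dist = pairwise_distances(embeddings)
    loss = dist[:, :, None] - dist[:, None, :] + margin  # shape (B, B, B)
    same = np.equal(labels[:, None], labels[None, :])
    not_self = np.logical_not(np.eye(len(labels), dtype=bool))
    anchor_positive = np.logical_and(same, not_self)  # valid (a, p) pairs
    anchor_negative = np.logical_not(same)  # valid (a, n) pairs
    mask = np.logical_and(anchor_positive[:, :, None], anchor_negative[:, None, :])
    loss = np.maximum(mask * loss, 0.0)
    num_positive = np.sum(np.greater(loss, 1e-16))
    return np.sum(loss) / (num_positive + 1e-16)


rng = np.random.default_rng(42)
labels = rng.integers(0, 3, size=10)
embeddings = rng.normal(size=(10, 4))
loop_loss = batch_all_loop(labels, embeddings, margin=0.5)
vec_loss = batch_all_vectorized(labels, embeddings, margin=0.5)
print(np.isclose(loop_loss, vec_loss))  # True
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The triple loop is slow, but each line maps directly to the definition of a valid triplet, which makes it a convenient reference to compare against.&lt;/p&gt;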

&lt;p&gt;To check that the tests pass, run:&lt;/p&gt;
&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;pytest model/tests/test_triplet_loss.py
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;p&gt;(or just &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pytest&lt;/code&gt;)&lt;/p&gt;

&lt;p&gt;Here is a list of the tests performed:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_pairwise_distances()&lt;/code&gt;: compare the numpy and TensorFlow results for pairwise distances&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_pairwise_distances_are_positive()&lt;/code&gt;: make sure that the resulting distances are non-negative&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_gradients_pairwise_distances()&lt;/code&gt;: make sure that the gradients are not &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nan&lt;/code&gt;&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_triplet_mask()&lt;/code&gt;: compare numpy and tensorflow implementations&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_anchor_positive_triplet_mask()&lt;/code&gt;: compare numpy and tensorflow implementations&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_anchor_negative_triplet_mask()&lt;/code&gt;: compare numpy and tensorflow implementations&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_simple_batch_all_triplet_loss()&lt;/code&gt;: simple test where there is just one type of label&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_batch_all_triplet_loss()&lt;/code&gt;: full test of batch all strategy (compares with numpy)&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_batch_hard_triplet_loss()&lt;/code&gt;: full test of batch hard strategy (compares with numpy)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;hr /&gt;

&lt;h2 id=&quot;experience-with-mnist&quot;&gt;Experience with MNIST&lt;/h2&gt;

&lt;p&gt;Even with the tests above, it is easy to overlook some mistakes.
For instance, at first I implemented the pairwise distance without checking that the input to the square root was strictly greater than $0$.
All my tests passed, but the gradients became &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nan&lt;/code&gt; as soon as training started.
I therefore added &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;test_gradients_pairwise_distances&lt;/code&gt;, and corrected the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_pairwise_distances&lt;/code&gt; function.&lt;/p&gt;
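
&lt;p&gt;The issue is that the derivative of $\sqrt{x}$ is $\frac{1}{2\sqrt{x}}$, which is infinite at $x = 0$, and the squared distance of each embedding to itself is $0$ (up to rounding). Below is a NumPy sketch of one common guard. It is an assumption here, so check &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_pairwise_distances&lt;/code&gt; in the repository for the exact version: clamp the squared distances at $0$, add a tiny epsilon to the exact zeros before the square root, then set those entries back to $0$. NumPy has no gradients, so the sketch only shows the forward pass.&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Sketch: pairwise Euclidean distances via the dot-product expansion,
# with a guard so that the square root never sees an exact 0.
import numpy as np


def pairwise_distances(embeddings, squared=False):
    dot_product = embeddings @ embeddings.T
    square_norm = np.diag(dot_product)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    distances = square_norm[None, :] - 2.0 * dot_product + square_norm[:, None]
    # Rounding can make some entries slightly negative: clamp them to 0
    distances = np.maximum(distances, 0.0)
    if squared:
        return distances
    # sqrt has an infinite derivative at 0, so shift the zero entries by a tiny
    # epsilon before the square root and set them back to 0 afterwards
    zero_mask = np.equal(distances, 0.0).astype(float)
    distances = np.sqrt(distances + zero_mask * 1e-16)
    return distances * (1.0 - zero_mask)


embeddings = np.array([[0.0, 0.0], [3.0, 4.0], [3.0, 4.0]])
print(pairwise_distances(embeddings).tolist())
# [[0.0, 5.0, 5.0], [5.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;With automatic differentiation, the square root is evaluated at $10^{-16}$ rather than $0$ for the masked entries, and the final multiplication by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;1.0 - zero_mask&lt;/code&gt; sends a zero gradient to them.&lt;/p&gt;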

&lt;p&gt;To make things simple, we will test the triplet loss on MNIST.
The code can be found &lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;To train and evaluate the model, do:&lt;/p&gt;
&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;python train.py &lt;span class=&quot;nt&quot;&gt;--model_dir&lt;/span&gt; experiments/base_model
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;This will launch a new experiment (i.e. a training run) named &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;base_model&lt;/code&gt;.
The model directory (containing weights, summaries…) is located in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;experiments/base_model&lt;/code&gt;.
Here we use a json file &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;experiments/base_model/params.json&lt;/code&gt; that specifies all the hyperparameters in the model.
This file must be created for any new experiment.&lt;/p&gt;

&lt;p&gt;Once training is complete (or as soon as some weights are saved in the model directory), we can visualize the embeddings using TensorBoard.
To do this, run:&lt;/p&gt;
&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;python visualize_embeddings.py &lt;span class=&quot;nt&quot;&gt;--model_dir&lt;/span&gt; experiments/base_model
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;And run TensorBoard in the experiment directory:&lt;/p&gt;
&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;tensorboard &lt;span class=&quot;nt&quot;&gt;--logdir&lt;/span&gt; experiments/base_model
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p align=&quot;center&quot;&gt;
&lt;img src=&quot;assets/triplet_loss/embeddings.gif&quot; /&gt;
&lt;/p&gt;

&lt;p align=&quot;center&quot;&gt;
&lt;i&gt;Embeddings of the MNIST test images visualized with t-SNE (perplexity 25)&lt;/i&gt;
&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;These embeddings were obtained with the hyperparameters specified in the configuration file &lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss/blob/master/experiments/base_model/params.json&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;experiments/base_model/params.json&lt;/code&gt;&lt;/a&gt;.
It’s interesting to see which evaluation images end up in the wrong cluster: many of them would probably be misread by humans too.&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;TensorFlow doesn’t make it easy to implement triplet loss, but with a bit of effort we can build a good-looking version of triplet loss with online mining.&lt;/p&gt;

&lt;p&gt;The tricky part is mostly how to efficiently compute the distances between embeddings, and how to mask out the invalid / easy triplets.&lt;/p&gt;

&lt;p&gt;Finally, if you need to remember one thing: &lt;strong&gt;always test your code&lt;/strong&gt;, especially when it’s complex like triplet loss.&lt;/p&gt;

&lt;h2 id=&quot;resources&quot;&gt;Resources&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://github.com/omoindrot/tensorflow-triplet-loss&quot;&gt;Github repo&lt;/a&gt; for this blog post&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/abs/1503.03832&quot;&gt;FaceNet paper&lt;/a&gt; introducing online triplet mining&lt;/li&gt;
  &lt;li&gt;Detailed explanation of online triplet mining in &lt;a href=&quot;https://arxiv.org/abs/1703.07737&quot;&gt;&lt;em&gt;In Defense of the Triplet Loss for Person Re-Identification&lt;/em&gt;&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;Blog post by Brandon Amos on online triplet mining: &lt;a href=&quot;http://bamos.github.io/2016/01/19/openface-0.2.0/&quot;&gt;&lt;em&gt;OpenFace 0.2.0: Higher accuracy and halved execution time&lt;/em&gt;&lt;/a&gt;.&lt;/li&gt;
  &lt;li&gt;Source code for the built-in TensorFlow function for semi-hard online mining triplet loss: &lt;a href=&quot;https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss&quot;&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tf.contrib.losses.metric_learning.triplet_semihard_loss&lt;/code&gt;&lt;/a&gt;.&lt;/li&gt;
  &lt;li&gt;The &lt;a href=&quot;https://www.coursera.org/learn/convolutional-neural-networks/lecture/HuUtN/triplet-loss&quot;&gt;coursera lecture&lt;/a&gt; on triplet loss&lt;/li&gt;
&lt;/ul&gt;</content><author><name>Olivier Moindrot</name></author><summary type="html">Triplet loss is known to be difficult to implement, especially if you add the constraints of TensorFlow.</summary></entry></feed>