<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="4.3.4">Jekyll</generator><link href="/feed.xml" rel="self" type="application/atom+xml" /><link href="/" rel="alternate" type="text/html" /><updated>2024-11-12T04:35:27+00:00</updated><id>/feed.xml</id><title type="html">Sven Elflein</title><author><name>Sven Elflein</name></author><entry><title type="html">A practical guide to Diffusion models</title><link href="/diffusion_practical_guide" rel="alternate" type="text/html" title="A practical guide to Diffusion models" /><published>2022-11-18T00:00:00+00:00</published><updated>2022-11-18T00:00:00+00:00</updated><id>/diffusion_practical_guide</id><content type="html" xml:base="/diffusion_practical_guide"><![CDATA[<p>The motivation of this blog post is to provide an intuition and a practical guide for training a (simple) diffusion model <a class="citation" href="#sohl2015deep">[Sohl-Dickstein et al. 2015]</a>, together with the corresponding PyTorch code. If you are interested in a more mathematical treatment with proofs, I highly recommend <a class="citation" href="#luoUnderstandingDiffusionModels2022a">[Luo 2022]</a>.</p>

<h2 id="diffusion">Diffusion</h2>
<p>In general, the goal of a diffusion model is to generate novel samples from a data distribution after being trained on samples from that distribution.</p>

<p>To keep this example as simple as possible, let’s consider a 2D toy dataset provided by <code class="language-plaintext highlighter-rouge">scikit-learn</code>:</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/diffusion_practical_guide_files/dataset.png" alt="Figure 1: Two Moons toy dataset used for our experiments." width="50%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 1: Two Moons toy dataset used for our experiments.
    </figcaption>
</figure>

<p>Diffusion models define a forward and backward process:</p>

<ul>
  <li>the forward process gradually adds noise to the data until the original data is indistinguishable from noise, i.e., one arrives at a standard normal distribution $\mathcal{N}(0, \mathbf{I})$</li>
  <li>the backward (reverse) process aims to invert the forward process, i.e., it starts from noise and gradually restores data</li>
</ul>

<p>To generate new samples by starting from random noise, one aims to learn the backward process.</p>

<p>Before we can train a model that learns this backward process, we first need to implement the forward process.</p>

<p>The forward process adds noise at every step $t$, controlled by a variance schedule \(\{\beta_t\}_{t=1, \dots, T}\) with \(\beta_{t-1} &lt; \beta_t\) and \(\beta_T = 1\):</p>

\[\begin{equation}
q(x_t \mid x_{t-1}) = \mathcal{N}(\sqrt{1 - \beta_t}\,x_{t-1}, \beta_t \mathbf{I})
\end{equation}\]

<p>As \(t \rightarrow T\), the distribution of \(x_t\) approaches a standard multivariate Gaussian \(\mathcal{N}(0, \mathbf{I})\).</p>

<p>So one starts with the original data samples $x_0$ and then gradually adds noise to them:</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/diffusion_practical_guide_files/forward_diffusion.png" alt="Figure 2: Forward diffusion process that gradually adds noise." width="110%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 2: Forward diffusion process that gradually adds noise.
    </figcaption>
</figure>
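<p>To make Figure 2 concrete, here is a minimal sketch that simulates the forward chain literally, drawing one transition \(q(x_t \mid x_{t-1})\) at a time. The helper name <code>simulate_forward</code>, the toy schedule and the convention that <code>betas[t - 1]</code> holds \(\beta_t\) are assumptions made for this illustration only.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch


def simulate_forward(x_0, betas):
    """Run the forward chain step by step (hypothetical helper).

    Args:
        x_0: Clean samples of shape (N, D).
        betas: 1D tensor whose entry betas[t - 1] is beta_t for t = 1, ..., T.

    Returns:
        Tensor of shape (T + 1, N, D) holding x_0, x_1, ..., x_T.
    """
    trajectory = [x_0]
    x_t = x_0
    for beta_t in betas:
        # x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps, with fresh eps ~ N(0, I)
        x_t = (1.0 - beta_t) ** 0.5 * x_t + beta_t ** 0.5 * torch.randn_like(x_t)
        trajectory.append(x_t)
    return torch.stack(trajectory, dim=0)


# Toy example: 10 steps with an increasing (assumed) schedule
torch.manual_seed(0)
x_0 = torch.full((5000, 2), 1.0)
betas = torch.linspace(0.05, 0.9, 10)
trajectory = simulate_forward(x_0, betas)
print(trajectory.shape)  # torch.Size([11, 5000, 2])
# Roughly 0 and 1: the last step is already close to N(0, I)
print(trajectory[-1].mean().item(), trajectory[-1].std().item())
</code></pre></div></div>

<p>Note that this draws fresh noise at every one of the \(T\) steps, which is exactly the cost the closed form avoids.</p>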

<p>Because the noise is Gaussian, one does not have to simulate this forward process by iteratively sampling noise: there is a closed form for the distribution of $x_t$ at any step $t$ given the original data point $x_0$, so noise only needs to be sampled once:</p>

\[\begin{equation}
q(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\, x_0, (1 - \bar{\alpha}_t)\mathbf{I})
\end{equation}\]

<p>with $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s = 1}^t \alpha_s$.</p>
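<p>To see where this closed form comes from, write a single transition as \(x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1 - \alpha_t}\,\epsilon_t\) with independent \(\epsilon_t \sim \mathcal{N}(0, \mathbf{I})\) and unroll two steps. A sum of independent zero-mean Gaussians is again Gaussian, and the variances add up:</p>

\[\begin{gathered}
x_2 = \sqrt{\alpha_2}\,x_1 + \sqrt{1 - \alpha_2}\,\epsilon_2 = \sqrt{\alpha_2 \alpha_1}\,x_0 + \sqrt{\alpha_2 (1 - \alpha_1)}\,\epsilon_1 + \sqrt{1 - \alpha_2}\,\epsilon_2 \\
\operatorname{Cov}[x_2 \mid x_0] = \left(\alpha_2 (1 - \alpha_1) + 1 - \alpha_2\right)\mathbf{I} = (1 - \alpha_1 \alpha_2)\mathbf{I} = (1 - \bar{\alpha}_2)\mathbf{I}
\end{gathered}\]

<p>Repeating the same argument for \(t\) steps gives the mean \(\sqrt{\bar{\alpha}_t}\, x_0\) and the covariance \((1 - \bar{\alpha}_t)\mathbf{I}\) stated above.</p>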

<p>Let’s implement this:</p>

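<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span>

<span class="kn">import</span> <span class="nn">torch</span>


<span class="k">class</span> <span class="nc">ForwardProcess</span><span class="p">:</span>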
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">betas</span>

        <span class="n">self</span><span class="p">.</span><span class="n">alphas</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">betas</span>
        <span class="n">self</span><span class="p">.</span><span class="n">alpha_bar</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">cumprod</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">alphas</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>


    <span class="k">def</span> <span class="nf">get_x_t</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x_0</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">LongTensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">]:</span>
        <span class="sh">"""</span><span class="s">Forward diffusion process given the unperturbed sample x_0.
        
        Args:
            x_0: Original, unperturbed samples.
            t: Target timestep of the diffusion process for each sample (in 1, ..., T).
        
        Returns:
            Noise added to original sample and perturbed sample.
        </span><span class="sh">"""</span>
        <span class="n">eps_0</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn_like</span><span class="p">(</span><span class="n">x_0</span><span class="p">).</span><span class="nf">to</span><span class="p">(</span><span class="n">x_0</span><span class="p">)</span>
        <span class="n">alpha_bar</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">alpha_bar</span><span class="p">[</span><span class="n">t</span><span class="p">,</span> <span class="bp">None</span><span class="p">]</span>
        <span class="n">mean</span> <span class="o">=</span> <span class="p">(</span><span class="n">alpha_bar</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">x_0</span>
        <span class="n">std</span> <span class="o">=</span> <span class="p">((</span><span class="mf">1.</span> <span class="o">-</span> <span class="n">alpha_bar</span><span class="p">)</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span>

        <span class="nf">return </span><span class="p">(</span><span class="n">eps_0</span><span class="p">,</span> <span class="n">mean</span> <span class="o">+</span> <span class="n">std</span> <span class="o">*</span> <span class="n">eps_0</span><span class="p">)</span>
</code></pre></div></div>
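<p>A note on indexing, inferred from the code rather than stated in the post: <code>get_x_t</code> looks up <code>alpha_bar[t]</code> for \(t \in \{1, \dots, T\}\), and <code>ReverseProcess</code> further below sets <code>T = len(betas) - 1</code>, so <code>betas</code> has \(T + 1\) entries. For <code>torch.cumprod</code> to reproduce \(\bar{\alpha}_t = \prod_{s=1}^t \alpha_s\), entry 0 must then be a placeholder with \(\beta_0 = 0\). The sketch below assumes this convention with a made-up schedule, mirrors the closed-form sampling of <code>get_x_t</code>, and checks it empirically: for \(x_0 = 0\), the sampled \(x_t\) should have variance \(1 - \bar{\alpha}_t\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math

import torch

T = 10
# Assumed schedule: entry 0 is a placeholder (beta_0 = 0), entries 1..T hold beta_t
betas = torch.cat([torch.zeros(1), torch.linspace(0.01, 0.5, T)])
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=-1)

# With beta_0 = 0, alpha_bar[t] equals the product over s = 1..t
t = 4
assert math.isclose(alpha_bar[t].item(), math.prod(alphas[1:t + 1].tolist()), rel_tol=1e-5)

# Closed-form sampling for a batch with one timestep per sample (as in get_x_t)
torch.manual_seed(0)
x_0 = torch.zeros(20000, 2)
t_batch = torch.full((len(x_0),), t, dtype=torch.long)
a = alpha_bar[t_batch, None]  # shape (N, 1), broadcasts over the 2 coordinates
x_t = a ** 0.5 * x_0 + (1.0 - a) ** 0.5 * torch.randn_like(x_0)

# The empirical variance should be close to 1 - alpha_bar[t]
print(x_t.var().item(), 1.0 - alpha_bar[t].item())
</code></pre></div></div>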

<h3 id="training">Training</h3>
<p>Next, we want to train a model that reverses that process.</p>

<p>For this, one can show that there is also a closed form for the less noisy sample $x_{t-1}$ given the noisier sample $x_t$ and the original sample $x_0$:</p>

\[\tag{1}\label{eq:reverse}
\begin{equation}
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(\mu(x_t, x_0), \sigma_t^2\mathbf{I})
\end{equation}\]

<p>where</p>

\[\begin{equation}
\sigma_t^2 = \frac{(1 - \alpha_t)(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}, \quad \mu(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_0\right)
\end{equation}\]

<p>and $\epsilon_0 \sim \mathcal{N}(0, \mathbf{I})$ is the noise drawn to perturb the original data $x_0$<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>.</p>
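<p>As a numerical sanity check of these formulas, the sketch below plugs an arbitrary schedule (with a \(\beta_0 = 0\) placeholder at index 0, an assumption made for this example) into \(\sigma_t^2\) and into the \(\epsilon_0\)-form of \(\mu\), and compares the latter with the equivalent form in terms of \(x_0\) and \(x_t\) given by <a class="citation" href="#ho2020denoising">[Ho et al. 2020]</a>: \(\mu = \frac{\sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t)}{1 - \bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t\). The two agree once \(x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon_0\) is substituted.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

torch.manual_seed(0)
# Arbitrary example schedule with a beta_0 = 0 placeholder at index 0
betas = torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4], dtype=torch.float64)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)

t = 3
x_0 = torch.randn(8, 2, dtype=torch.float64)
eps_0 = torch.randn_like(x_0)
x_t = alpha_bar[t] ** 0.5 * x_0 + (1.0 - alpha_bar[t]) ** 0.5 * eps_0

# Posterior variance sigma_t^2
sigma_sq = (1.0 - alphas[t]) * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t])

# Posterior mean in the epsilon_0 parameterization used in this post
mu_eps = (x_t - (1.0 - alphas[t]) / (1.0 - alpha_bar[t]) ** 0.5 * eps_0) / alphas[t] ** 0.5

# The same mean written in terms of x_0 and x_t
mu_x0 = (
    alpha_bar[t - 1] ** 0.5 * (1.0 - alphas[t]) / (1.0 - alpha_bar[t]) * x_0
    + alphas[t] ** 0.5 * (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t]) * x_t
)

print(torch.allclose(mu_eps, mu_x0))  # True
print(sigma_sq.item())
</code></pre></div></div>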

<p>We cannot use this directly to generate new data, since it requires the original data point $x_0$ in the first place. However, <strong>we can use it to generate the ground truth for training a model that does not rely on $\mathbf{x}_0$ and instead predicts $\epsilon_0$ from the noisy data $\mathbf{x}_t$ and $t$ alone</strong><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup>.</p>

<p>Let’s define a small neural network $\epsilon_{\mathbf{\theta}}(\mathbf{x}_t, t)$ with parameters $\mathbf{\theta}$ that does just that:</p>

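<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>


<span class="k">class</span> <span class="nc">NoisePredictor</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>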

    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">T</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">T</span> <span class="o">=</span> <span class="n">T</span>
        <span class="n">self</span><span class="p">.</span><span class="n">t_encoder</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        
        <span class="n">self</span><span class="p">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Sequential</span><span class="p">(</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">2</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span>   <span class="c1"># Input: Noisy data x_t and t
</span>            <span class="n">nn</span><span class="p">.</span><span class="nc">LeakyReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">LeakyReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">100</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">LeakyReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">20</span><span class="p">),</span>
            <span class="n">nn</span><span class="p">.</span><span class="nc">LeakyReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
            <span class="c1"># Output: Predicted noise that was added to the original data point
</span>            <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x_t</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
        <span class="c1"># Encode the time index t as one-hot and then use one layer to encode
</span>        <span class="c1"># into a single value
</span>        <span class="n">t_embedding</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">t_encoder</span><span class="p">(</span>
            <span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="nf">one_hot</span><span class="p">(</span><span class="n">t</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">self</span><span class="p">.</span><span class="n">T</span><span class="p">).</span><span class="nf">to</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nb">float</span><span class="p">)</span>
        <span class="p">)</span>
        
        <span class="n">inp</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">cat</span><span class="p">([</span><span class="n">x_t</span><span class="p">,</span> <span class="n">t_embedding</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">self</span><span class="p">.</span><span class="nf">model</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>
</code></pre></div></div>

<p>Here, we encode the timestep $t$ of the diffusion process as a one-hot vector, map it to a single scalar with one linear layer, and concatenate this value with the noisy data.</p>
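<p>To make the shapes concrete, the sketch below isolates the time conditioning from <code>NoisePredictor.forward</code>: timesteps \(t \in \{1, \dots, T\}\) are shifted to class indices \(0, \dots, T - 1\), one-hot encoded, mapped to one scalar by <code>nn.Linear(T, 1)</code>, and concatenated with the 2D sample, giving the 3 input features expected by the first layer. The batch size and \(T\) are arbitrary here.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn

T = 10
t_encoder = nn.Linear(T, 1)

t = torch.tensor([1, 5, 10])  # one timestep per sample, in 1..T
x_t = torch.randn(3, 2)       # noisy 2D samples

one_hot = nn.functional.one_hot(t - 1, num_classes=T).to(torch.float)
t_embedding = t_encoder(one_hot)
inp = torch.cat([x_t, t_embedding], dim=1)

print(one_hot.shape, t_embedding.shape, inp.shape)
# torch.Size([3, 10]) torch.Size([3, 1]) torch.Size([3, 3])
</code></pre></div></div>

<p>Because a linear layer applied to a one-hot vector just selects one column of its weight matrix (plus the bias), this encoding amounts to a learned lookup table with one scalar per timestep, similar to <code>nn.Embedding(T, 1)</code>.</p>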

<p><strong>Next up</strong>: Training the model to predict the noise. 
For this, one can simply sample timesteps $t$, use the forward process to generate the noisy samples $x_t$ together with the noise $\epsilon_0$, and train the model to minimize the mean squared error between the predicted noise and $\epsilon_0$. Here, each iteration uses the whole (small) dataset as a single batch.</p>

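<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">trange</span>

<span class="c1"># X: training data of shape (N, 2), T: number of diffusion steps and
</span><span class="c1"># fp: a ForwardProcess instance, all assumed to be defined beforehand
</span><span class="n">model</span> <span class="o">=</span> <span class="nc">NoisePredictor</span><span class="p">(</span><span class="n">T</span><span class="o">=</span><span class="n">T</span><span class="p">)</span>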
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="nc">AdamW</span><span class="p">(</span><span class="n">params</span><span class="o">=</span><span class="n">model</span><span class="p">.</span><span class="nf">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">weight_decay</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">)</span>

<span class="n">N</span> <span class="o">=</span> <span class="n">X</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nf">trange</span><span class="p">(</span><span class="mi">5000</span><span class="p">):</span>
    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="nf">no_grad</span><span class="p">():</span>
        <span class="c1"># Sample random t's
</span>        <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randint</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">T</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">N</span><span class="p">,))</span>

        <span class="c1"># Get the noise added and the noisy version of the data using the forward
</span>        <span class="c1"># process given t
</span>        <span class="n">eps_0</span><span class="p">,</span> <span class="n">x_t</span> <span class="o">=</span> <span class="n">fp</span><span class="p">.</span><span class="nf">get_x_t</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">t</span><span class="o">=</span><span class="n">t</span><span class="p">)</span>
    
    <span class="c1"># Predict the noise added to x_0 from x_t
</span>    <span class="n">pred_eps</span> <span class="o">=</span> <span class="nf">model</span><span class="p">(</span><span class="n">x_t</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>

    <span class="c1"># Simplified objective without weighting with alpha terms (Ho et al, 2020)
</span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">functional</span><span class="p">.</span><span class="nf">mse_loss</span><span class="p">(</span><span class="n">pred_eps</span><span class="p">,</span> <span class="n">eps_0</span><span class="p">)</span>

    <span class="n">loss</span><span class="p">.</span><span class="nf">backward</span><span class="p">()</span>
    <span class="n">optimizer</span><span class="p">.</span><span class="nf">step</span><span class="p">()</span>
    <span class="n">optimizer</span><span class="p">.</span><span class="nf">zero_grad</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="inference">Inference</h3>
<p>After training the model to predict the noise $\epsilon_0$, we can generate samples by iteratively running the reverse process, i.e., sampling $\mathbf{x}_{t-1}$ given $\mathbf{x}_t$, starting from random noise $\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I})$. Each step uses the Gaussian defined in \eqref{eq:reverse}, where the unknown $\epsilon_0$ in the mean is replaced by the prediction of the network:</p>

\[\begin{equation}
\mu(x_t) = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_{\mathbf{\theta}}(\mathbf{x}_t, t) \right)
\end{equation}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ReverseProcess</span><span class="p">(</span><span class="n">ForwardProcess</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">betas</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">model</span><span class="p">:</span> <span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">(</span><span class="n">betas</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">model</span> <span class="o">=</span> <span class="n">model</span>
        <span class="n">self</span><span class="p">.</span><span class="n">T</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">betas</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>

        <span class="n">self</span><span class="p">.</span><span class="n">sigma</span> <span class="o">=</span> <span class="p">(</span>
            <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">alphas</span><span class="p">)</span>
            <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">torch</span><span class="p">.</span><span class="nf">roll</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">alpha_bar</span><span class="p">)</span>
        <span class="p">)</span> <span class="o">**</span> <span class="mf">0.5</span>
        <span class="n">self</span><span class="p">.</span><span class="n">sigma</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.</span>
    
    <span class="k">def</span> <span class="nf">get_x_t_minus_one</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x_t</span><span class="p">:</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="p">.</span><span class="n">Tensor</span><span class="p">:</span>
        <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="nf">no_grad</span><span class="p">():</span>
            <span class="n">t_vector</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">full</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">x_t</span><span class="p">),),</span> <span class="n">fill_value</span><span class="o">=</span><span class="n">t</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)</span>
            <span class="n">eps</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">model</span><span class="p">(</span><span class="n">x_t</span><span class="p">,</span> <span class="n">t</span><span class="o">=</span><span class="n">t_vector</span><span class="p">)</span>
        
        <span class="n">eps</span> <span class="o">*=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">alphas</span><span class="p">[</span><span class="n">t</span><span class="p">])</span> <span class="o">/</span> <span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="n">alpha_bar</span><span class="p">[</span><span class="n">t</span><span class="p">])</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span>
        <span class="n">mean</span> <span class="o">=</span>  <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">alphas</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">x_t</span> <span class="o">-</span> <span class="n">eps</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">mean</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="n">sigma</span><span class="p">[</span><span class="n">t</span><span class="p">]</span> <span class="o">*</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn_like</span><span class="p">(</span><span class="n">x_t</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">n_samples</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">full_trajectory</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="c1"># Initialize with X_T ~ N(0, I)
</span>        <span class="n">x_t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="n">n_samples</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">trajectory</span> <span class="o">=</span> <span class="p">[</span><span class="n">x_t</span><span class="p">.</span><span class="nf">clone</span><span class="p">()]</span>
        
        <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">T</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">):</span>
            <span class="n">x_t</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">get_x_t_minus_one</span><span class="p">(</span><span class="n">x_t</span><span class="p">,</span> <span class="n">t</span><span class="o">=</span><span class="n">t</span><span class="p">)</span>
            
            <span class="k">if</span> <span class="n">full_trajectory</span><span class="p">:</span>
                <span class="n">trajectory</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">x_t</span><span class="p">.</span><span class="nf">clone</span><span class="p">())</span>
        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="nf">stack</span><span class="p">(</span><span class="n">trajectory</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">if</span> <span class="n">full_trajectory</span> <span class="k">else</span> <span class="n">x_t</span>
</code></pre></div></div>
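<p>A quick way to convince oneself that this mean is the right update: if the network predicted the true noise exactly, a single reverse step from \(t = 1\) would recover \(x_0\) exactly. There, \(\sigma_1 = 0\) (the code above sets <code>sigma[1] = 0</code>, which also follows from the convention \(\bar{\alpha}_0 = 1\)), \(\bar{\alpha}_1 = \alpha_1\), and the coefficient \((1 - \alpha_1)/\sqrt{1 - \bar{\alpha}_1}\) reduces to \(\sqrt{1 - \alpha_1}\). The sketch below checks this with an oracle in place of \(\epsilon_{\mathbf{\theta}}\); the helper name <code>reverse_mean</code> and the value of \(\beta_1\) are made up for this illustration.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch


def reverse_mean(x_t, eps_pred, alpha_t, alpha_bar_t):
    """Mean of the reverse step given a noise prediction (hypothetical helper)."""
    coef = (1.0 - alpha_t) / (1.0 - alpha_bar_t) ** 0.5
    return (x_t - coef * eps_pred) / alpha_t ** 0.5


torch.manual_seed(0)
beta_1 = torch.tensor(0.02, dtype=torch.float64)
alpha_1 = 1.0 - beta_1
alpha_bar_1 = alpha_1  # product over s = 1..1

x_0 = torch.randn(4, 2, dtype=torch.float64)
eps_0 = torch.randn_like(x_0)
x_1 = alpha_bar_1 ** 0.5 * x_0 + (1.0 - alpha_bar_1) ** 0.5 * eps_0

# Oracle: pretend the network predicts the true noise eps_0
x_0_hat = reverse_mean(x_1, eps_0, alpha_1, alpha_bar_1)
print(torch.allclose(x_0_hat, x_0))  # True
</code></pre></div></div>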

<p>Now, let’s sample new data points and plot them:</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/diffusion_practical_guide_files/new_samples.png" alt="Figure 3: New samples generated from the trained diffusion model." width="50%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 3: New samples generated from the trained diffusion model.
    </figcaption>
</figure>

<p>We can also visualize the dynamics a sample follows during the reverse process as a vector field: for each position on a grid, we plot the (negative) direction of the predicted noise vector at a particular timestep $t$:</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/diffusion_practical_guide_files/vectorfield.png" alt="Figure 4: Vector field describing reverse process dynamics at different timesteps. The blue line shows the trajectory of a sample during the reverse process." width="50%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 4: Vector field describing reverse process dynamics at different timesteps. The blue line shows the trajectory of a sample during the reverse process.
    </figcaption>
</figure>

<p>One can see that as $t \rightarrow 0$, more fine-grained structure emerges that guides samples to the original data manifold. At $t=T$, samples are only guided coarsely towards the center, since the signal is still very noisy and the noise is hard for the network to predict.</p>
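<p>A minimal sketch of how such a vector field can be computed, assuming a noise predictor with the same call signature as <code>NoisePredictor</code> (<code>model(x_t, t)</code> with a batch of integer timesteps). The grid extent, the resolution and the helper name <code>noise_vector_field</code> are assumptions, and <code>StandInModel</code> is a hypothetical placeholder for a trained network that simply points every position towards the origin.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn


def noise_vector_field(model, t, lim=1.5, n=20):
    """Evaluate the negative predicted noise on an n x n grid (hypothetical helper)."""
    coords = torch.linspace(-lim, lim, n)
    xx, yy = torch.meshgrid(coords, coords, indexing="xy")
    grid = torch.stack([xx.flatten(), yy.flatten()], dim=1)  # (n * n, 2)
    t_vector = torch.full((len(grid),), t, dtype=torch.long)
    with torch.no_grad():
        eps = model(grid, t_vector)
    # The reverse step moves x_t against the predicted noise, so return -eps
    return grid, -eps


class StandInModel(nn.Module):
    """Placeholder for a trained NoisePredictor: predicts eps = x_t."""

    def forward(self, x_t, t):
        return x_t


grid, directions = noise_vector_field(StandInModel(), t=5)
print(grid.shape, directions.shape)  # torch.Size([400, 2]) torch.Size([400, 2])
</code></pre></div></div>

<p>The result can then be drawn with Matplotlib, e.g. <code>plt.quiver(grid[:, 0], grid[:, 1], directions[:, 0], directions[:, 1])</code>, once per timestep of interest.</p>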

<h2 id="insights">Insights</h2>

<p>Working on this small dataset already revealed some important things to consider when training diffusion models.
In particular, when I first implemented this from the paper description, a huge number of diffusion steps ($T=1000$) was required to yield good results.</p>

<p>Looking further into the literature and the appendices of the papers revealed two changes that brought the number of required diffusion steps down to $T=10$:</p>
<ul>
  <li>It is important to linearly scale the input data into the range $[-1, 1]$. Standardizing the input data (i.e., subtracting the mean and dividing by the standard deviation), as is usually done for neural networks, yielded worse results.</li>
  <li>The variance schedule \(\{\beta_t\}_t\) should ideally change only slowly towards $t=0$, so that the noise is not too strong for the network to reconstruct, i.e., so that it can learn fine-grained details of the data. This was already observed in <a class="citation" href="#nichol2021improved">[Nichol and Dhariwal 2021]</a>; however, it is interesting that this insight already shows up on a toy dataset, without training expensive image models. Fig. 5 shows how the variance of the forward process $1 - \bar{\alpha}_t$ evolves when $\beta_t$ is set linearly (left) or polynomially (right). The polynomial setting works much better in practice since the input is not perturbed too quickly.</li>
</ul>
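<p>A minimal sketch of the linear scaling from the first point, assuming per-feature min-max scaling computed on the training data (the helper names are hypothetical). Keeping the minimum and maximum around makes it possible to map generated samples back to the original coordinates.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch


def fit_unit_range(X):
    """Per-feature minimum and maximum of the training data (hypothetical helper)."""
    return X.min(dim=0).values, X.max(dim=0).values


def to_unit_range(X, x_min, x_max):
    """Linearly map each feature from [x_min, x_max] to [-1, 1]."""
    return 2.0 * (X - x_min) / (x_max - x_min) - 1.0


def from_unit_range(X_scaled, x_min, x_max):
    """Invert to_unit_range, e.g. for samples generated by the model."""
    return (X_scaled + 1.0) / 2.0 * (x_max - x_min) + x_min


torch.manual_seed(0)
X = torch.randn(1000, 2) * torch.tensor([3.0, 0.5]) + torch.tensor([1.0, -2.0])
x_min, x_max = fit_unit_range(X)
X_scaled = to_unit_range(X, x_min, x_max)
print(X_scaled.min(dim=0).values, X_scaled.max(dim=0).values)  # -1 and 1 per feature
</code></pre></div></div>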

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/diffusion_practical_guide_files/variance_schedule.png" alt="Figure 5: Different variance schedules for the diffusion process." width="80%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 5: Different variance schedules for the diffusion process.
    </figcaption>
</figure>
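<p>The two kinds of schedules in Fig. 5 can be sketched as follows. The exact polynomial used for the figure is not given in this post, so the form \(\beta_t = \beta_{\min} + (\beta_{\max} - \beta_{\min})\,(t/T)^p\), the exponent \(p\) and the endpoint values are assumptions for illustration. Since the code in this post indexes <code>betas</code> by \(t = 1, \dots, T\), entry 0 is assumed to be a placeholder set to 0, and \(\beta_{\max}\) is kept below 1 here because the reverse step divides by \(\sqrt{\alpha_t}\).</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch


def make_betas(T, beta_min=1e-3, beta_max=0.9, power=1.0):
    """beta_0 = 0 placeholder followed by beta_1..beta_T (hypothetical schedule).

    power=1 gives a linear schedule; power greater than 1 keeps beta_t small
    for small t and increases it faster towards t = T.
    """
    s = torch.arange(1, T + 1, dtype=torch.float64) / T
    betas = beta_min + (beta_max - beta_min) * s ** power
    return torch.cat([torch.zeros(1, dtype=torch.float64), betas])


def forward_variance(betas):
    """Variance 1 - alpha_bar_t of q(x_t | x_0) for t = 0..T."""
    return 1.0 - torch.cumprod(1.0 - betas, dim=0)


T = 10
linear = forward_variance(make_betas(T, power=1.0))
poly = forward_variance(make_betas(T, power=3.0))
for t in range(T + 1):
    print(t, round(linear[t].item(), 3), round(poly[t].item(), 3))
</code></pre></div></div>

<p>With <code>power=3.0</code>, every \(\beta_t\) is at most its linear counterpart, so the forward variance grows more slowly for small \(t\), which is the behavior the second point above argues for.</p>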

<p>Check out the full notebook which this blog post is based on <a href="https://gist.github.com/selflein/9bee0818a48966179b18d577a89f792a">here</a>.</p>

<h2 id="references">References</h2>
<ol class="bibliography"><li><span id="sohl2015deep"><span style="font-variant: small-caps">Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N., and Ganguli, S.</span> 2015. Deep unsupervised learning using nonequilibrium thermodynamics. <i>International Conference on Machine Learning</i>, PMLR, 2256–2265.</span></li>
<li><span id="luoUnderstandingDiffusionModels2022a"><span style="font-variant: small-caps">Luo, C.</span> 2022. Understanding Diffusion Models: A Unified Perspective.</span></li>
<li><span id="nichol2021improved"><span style="font-variant: small-caps">Nichol, A.Q. and Dhariwal, P.</span> 2021. Improved Denoising Diffusion Probabilistic Models. <i>International Conference on Machine Learning</i>, PMLR, 8162–8171.</span></li>
<li><span id="ho2020denoising"><span style="font-variant: small-caps">Ho, J., Jain, A., and Abbeel, P.</span> 2020. Denoising Diffusion Probabilistic Models. <i>Advances in Neural Information Processing Systems</i> <i>33</i>, 6840–6851.</span></li></ol>
<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>This is one possible parameterization of the mean, and the most effective one in the experiments of <a class="citation" href="#ho2020denoising">[Ho et al. 2020]</a>. <a class="citation" href="#luoUnderstandingDiffusionModels2022a">[Luo 2022]</a> summarizes two other parameterizations from the literature, e.g., regressing the mean directly. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>Here we treat the variances as fixed. <a class="citation" href="#nichol2021improved">[Nichol and Dhariwal 2021]</a> propose to learn these with an additional objective. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
  </ol>
</div>]]></content><author><name>Sven Elflein</name></author><category term="Generative models" /><category term="Deep Learning" /><summary type="html"><![CDATA[The motivation of this blog post is to provide an intuition and a practical guide to train a (simple) diffusion model [Sohl-Dickstein et al. 2015] together with the respective code leveraging PyTorch. If you are interested in a more mathematical description with proofs I can highly recommend [Luo 2022].]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/diffusion_practical_guide_files/vectorfield_thumbnail.png" /><media:content medium="image" url="/diffusion_practical_guide_files/vectorfield_thumbnail.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Learning distributions on compact support using Normalizing Flows</title><link href="/normalizing_flow_bounded_domain" rel="alternate" type="text/html" title="Learning distributions on compact support using Normalizing Flows" /><published>2022-01-10T00:00:00+00:00</published><updated>2022-01-10T00:00:00+00:00</updated><id>/normalizing_flow_bounded_domain</id><content type="html" xml:base="/normalizing_flow_bounded_domain"><![CDATA[<p>Normalizing Flows <a class="citation" href="#pmlr-v37-rezende15">[Rezende and Mohamed 2015]</a> are powerful density estimators that have been shown to learn complex distributions, e.g., of natural images <a class="citation" href="#NEURIPS2018-d139db6a">[Kingma and Dhariwal 2018]</a>.</p>

<p>Recently, I was interested in learning a distribution on line segments with compact support, i.e., the support is not all of $\mathbb{R}$ but only a compact interval $[a, b]$ along the line segment.</p>

<p>The vanilla formulation of Normalizing Flows <a class="citation" href="#pmlr-v37-rezende15">[Rezende and Mohamed 2015]</a> only considers distributions with support $\mathbb{R}$, and a quick literature search did not turn up any solutions to the problem. After thinking about the problem for a bit, I came up with a solution that carefully applies invertible transformations.</p>

<h2 id="the-idea">The idea</h2>

<p>Consider a vanilla normalizing flow stacking a set of invertible and differentiable transformations $\{f_1, \dots, f_n \}$. After applying common transformations (e.g., the radial <a class="citation" href="#pmlr-v37-rezende15">[Rezende and Mohamed 2015]</a> or affine coupling <a class="citation" href="#dinh2015nice">[Dinh et al. 2015]</a> transform), the support of the resulting distribution is still $\mathbb{R}$. This is visualized in Fig. 1.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/normalizing_flow_bounded_domain_files/nf.png" alt="Figure 1: Common normalizing flow definition transforming a latent Normal distribution into a more complex, target distribution." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 1: Common normalizing flow definition transforming a latent Normal distribution into a more complex, target distribution.
    </figcaption>
</figure>

<p>Now, to obtain a distribution with compact support, we need a function that is invertible and differentiable (to satisfy the constraints of normalizing flows) and that additionally has a bounded image. One such choice is the logistic function</p>

\[f_{n+1}: \mathbb{R} \to (0, 1), \quad f_{n+1}(x) = \frac{1}{1 + e^{-x}}\]

<p>which is visualized in the first part of Fig. 2.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/normalizing_flow_bounded_domain_files/compact_transform.png" alt="Figure 2: Two additional transforms to squash the distribution into the [0, 1] interval and scale and move it afterwards." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 2: Two additional transforms to squash the distribution into the [0, 1] interval and scale and move it afterwards.
    </figcaption>
</figure>

<p>After applying the logistic function, we can use a simple affine transformation $f_{n+2}(y) = a + (b - a)\,y$ to move and scale the support $[0, 1]$ to our desired interval $[a, b]$, as shown in the second part of Fig. 2.</p>
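<p>The change of variables behind these two extra layers can be made concrete with a minimal PyTorch sketch. It assumes a standard Normal latent and leaves out the base flow $f_1, \dots, f_n$ entirely, so it only models $x = a + (b - a)\,\sigma(z)$ with <code class="language-plaintext highlighter-rouge">loc</code> $= a$ and <code class="language-plaintext highlighter-rouge">scale</code> $= b - a$. The function name <code class="language-plaintext highlighter-rouge">bounded_log_prob</code> is hypothetical and not part of Pyro.</p>

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal


def bounded_log_prob(x: torch.Tensor, loc: float, scale: float) -> torch.Tensor:
    """Log-density of x = loc + scale * sigmoid(z) with z ~ N(0, 1).

    Only the two extra layers are modelled (no base flow in between), and
    x must lie strictly inside (loc, loc + scale) with scale positive.
    """
    u = (x - loc) / scale                  # invert the affine map: u in (0, 1)
    z = torch.log(u) - torch.log1p(-u)     # invert the sigmoid: z = logit(u)
    # log |dx/dz| = log(scale) + log sigmoid(z) + log(1 - sigmoid(z))
    log_det = math.log(scale) + F.logsigmoid(z) + F.logsigmoid(-z)
    base = Normal(torch.zeros((), dtype=x.dtype), torch.ones((), dtype=x.dtype))
    return base.log_prob(z) - log_det


# Sanity check on the support [1.0, 2.5] used later in the post
x = torch.linspace(1.0 + 1e-9, 2.5 - 1e-9, 200_001, dtype=torch.float64)
density = bounded_log_prob(x, loc=1.0, scale=1.5).exp()
print(torch.trapezoid(density, x).item())  # should be close to 1
```

<p>Since the density of $x$ is $p_Z(z) \,/\, |dx/dz|$, the log-determinant is subtracted here. In the inverse parameterization used below, the same quantity appears with a plus sign, because there the Jacobian is taken in the $x \to z$ direction.</p>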

<h2 id="action">Action!</h2>

<p>Next, we are going to use the <a href="https://pyro.ai/">Pyro</a> library which itself is based on PyTorch to implement our idea and test the implementation by learning a simple 1D distribution with compact support.</p>

<p>To learn the parameters of the normalizing flow efficiently using maximum likelihood, we <strong>need to be able to evaluate the likelihood of individual samples of our dataset</strong>. Therefore, we use the <em>inverse</em> parameterization. It transforms a training sample backwards through the transformations shown in Fig. 1 and 2, so that its density can be evaluated under the latent distribution.</p>

<p>Now let us define the model in code:</p>

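<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">torch</span>
<span class="kn">import</span> <span class="n">torch.distributions</span> <span class="k">as</span> <span class="n">tdist</span>
<span class="kn">from</span> <span class="n">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="n">pyro.distributions.transforms</span> <span class="kn">import</span> <span class="n">Radial</span>

<span class="c1"># InvAffineTransformModule and InvSigmoidTransform are small helper modules
</span><span class="c1"># wrapping the inverse affine and inverse sigmoid transforms; they are not
</span><span class="c1"># shown in this post (see the full notebook linked below).
</span>
<span class="k">class</span> <span class="nc">NormalizingFlowDensity</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>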
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span>
        <span class="n">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">flow_length</span><span class="p">,</span> <span class="n">flow_type</span><span class="o">=</span><span class="sh">"</span><span class="s">radial_flow</span><span class="sh">"</span><span class="p">,</span> <span class="n">loc</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">1</span>
    <span class="p">):</span>
        <span class="nf">super</span><span class="p">(</span><span class="n">NormalizingFlowDensity</span><span class="p">,</span> <span class="n">self</span><span class="p">).</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span>
        <span class="n">self</span><span class="p">.</span><span class="n">flow_length</span> <span class="o">=</span> <span class="n">flow_length</span>
        <span class="n">self</span><span class="p">.</span><span class="n">flow_type</span> <span class="o">=</span> <span class="n">flow_type</span>

        <span class="n">self</span><span class="p">.</span><span class="n">mean</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nf">zeros</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">dim</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">cov</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nf">eye</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">dim</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        
        <span class="n">modules</span> <span class="o">=</span> <span class="p">[</span>
            <span class="c1"># Affine transformation of the [0, 1] interval 
</span>            <span class="nc">InvAffineTransformModule</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">loc</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">scale</span><span class="p">),</span>
            <span class="c1"># Squeeze R into [0, 1] interval
</span>            <span class="nc">InvSigmoidTransform</span><span class="p">()</span>
        <span class="p">]</span>
        <span class="k">if</span> <span class="n">self</span><span class="p">.</span><span class="n">flow_type</span> <span class="o">==</span> <span class="sh">"</span><span class="s">radial_flow</span><span class="sh">"</span><span class="p">:</span>
            <span class="c1"># Base flow: flow_length radial transforms (list.extend works in place)
</span>            <span class="n">modules</span><span class="p">.</span><span class="nf">extend</span><span class="p">(</span>
                <span class="p">[</span><span class="nc">Radial</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">flow_length</span><span class="p">)]</span>
            <span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">raise</span> <span class="nb">NotImplementedError</span>
        
        <span class="n">self</span><span class="p">.</span><span class="n">transforms</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">ModuleList</span><span class="p">(</span><span class="n">modules</span><span class="p">)</span>
        
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
        <span class="n">sum_log_jacobians</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="k">for</span> <span class="n">transform</span> <span class="ow">in</span> <span class="n">self</span><span class="p">.</span><span class="n">transforms</span><span class="p">:</span>
            <span class="n">z_next</span> <span class="o">=</span> <span class="nf">transform</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
            <span class="n">sum_log_jacobians</span> <span class="o">+=</span> <span class="n">transform</span><span class="p">.</span><span class="nf">log_abs_det_jacobian</span><span class="p">(</span>
                <span class="n">z</span><span class="p">,</span> <span class="n">z_next</span>
            <span class="p">)</span>
            <span class="n">z</span> <span class="o">=</span> <span class="n">z_next</span>
        <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="n">sum_log_jacobians</span>

    <span class="k">def</span> <span class="nf">log_prob</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">z</span><span class="p">,</span> <span class="n">sum_log_jacobians</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">log_prob_z</span> <span class="o">=</span> <span class="n">tdist</span><span class="p">.</span><span class="nc">MultivariateNormal</span><span class="p">(</span>
            <span class="n">self</span><span class="p">.</span><span class="n">mean</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">cov</span><span class="p">).</span><span class="nf">log_prob</span><span class="p">(</span><span class="n">z</span><span class="p">)</span>
        <span class="n">log_prob_x</span> <span class="o">=</span> <span class="n">log_prob_z</span> <span class="o">+</span> <span class="n">sum_log_jacobians</span>
        <span class="k">return</span> <span class="n">log_prob_x</span>
</code></pre></div></div>

<p>Note that since we use the <em>inverse</em> parameterization, the list contains the inverses of the transforms in reverse order, so that it maps a data point to the latent space. Further, the parameters <code class="language-plaintext highlighter-rouge">loc</code> and <code class="language-plaintext highlighter-rouge">scale</code> set the support of the distribution to $[\text{loc}, \text{loc} + \text{scale}]$.</p>

<p>For training, we use the <code class="language-plaintext highlighter-rouge">log_prob</code> method. It takes a data point $x$, maps it to the latent space with the inverse transformation, and returns its log-likelihood $\log p(x)$.</p>
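<p>The change-of-variables formula that <code class="language-plaintext highlighter-rouge">log_prob</code> evaluates can be written out in our own notation. Let $g_1, \dots, g_{n+2}$ denote the transforms in the order they are stored in the list (i.e., $g_i = f_{n+3-i}^{-1}$), with $h_0 = x$, $h_i = g_i(h_{i-1})$ and $z = h_{n+2}$. Then</p>

\[\log p(x) = \log \mathcal{N}(z; \mathbf{0}, \mathbf{I}) + \sum_{i=1}^{n+2} \log \left| \det \frac{\partial h_i}{\partial h_{i-1}} \right|\]

<p>where each summand corresponds to one call of <code class="language-plaintext highlighter-rouge">log_abs_det_jacobian</code> inside <code class="language-plaintext highlighter-rouge">forward</code>.</p>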

<p>Now, consider the 1D example distribution shown in Fig. 3, a piecewise uniform distribution with support $[1.0, 2.5]$.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/normalizing_flow_bounded_domain_files/normalizing_flow_bounded_domain_11_1.png" alt="Figure 3: The target distribution we are aiming to learn on data." width="85%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 3: The target distribution we are aiming to learn on data.
    </figcaption>
</figure>
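<p>The training code below iterates over a <code class="language-plaintext highlighter-rouge">dataloader</code> whose construction is not shown in the post. One possible way to build it is sketched here, assuming the training samples are stored in a tensor of shape <code class="language-plaintext highlighter-rouge">(N, 1)</code>. As a hypothetical stand-in for the piecewise uniform target, it simply draws from a plain uniform distribution on $[1.0, 2.5]$. Wrapping the tensor in a <code class="language-plaintext highlighter-rouge">TensorDataset</code> makes every batch a tuple, which is why the loop reads <code class="language-plaintext highlighter-rouge">batch[0]</code>.</p>

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 5000 samples from U(1.0, 2.5), shape (N, 1)
num_samples = 5000
samples = 1.0 + 1.5 * torch.rand(num_samples, 1)

# TensorDataset yields 1-tuples, hence `batch[0]` in the training loop
dataloader = DataLoader(TensorDataset(samples), batch_size=256, shuffle=True)

batch = next(iter(dataloader))
print(batch[0].shape)  # torch.Size([256, 1])
```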

<p>To learn the parameters $\theta$ of the normalizing flow, we simply maximize the log-likelihood of the $N$ training samples</p>

\[\arg\max_{\theta} \sum_{i=1}^{N} \log p_{\theta}(x_i)\]

<p>which corresponds to the following training code, which minimizes the mean negative log-likelihood over mini-batches. To match the support $[1.0, 2.5]$ of the target distribution, we set the <code class="language-plaintext highlighter-rouge">loc</code> parameter to $1$ and the <code class="language-plaintext highlighter-rouge">scale</code> parameter to $1.5$.</p>

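<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="n">torch</span>
<span class="kn">from</span> <span class="n">tqdm</span> <span class="kn">import</span> <span class="n">trange</span>

<span class="c1"># `dataloader` is assumed to yield batches of 1D samples from the target
</span><span class="c1"># distribution (its construction is not shown in this post).
</span>
<span class="n">net</span> <span class="o">=</span> <span class="nc">NormalizingFlowDensity</span><span class="p">(</span>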
    <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">flow_length</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">flow_type</span><span class="o">=</span><span class="sh">"</span><span class="s">radial_flow</span><span class="sh">"</span><span class="p">,</span> <span class="n">loc</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mf">1.5</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="nc">Adam</span><span class="p">(</span><span class="n">net</span><span class="p">.</span><span class="nf">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">5e-2</span><span class="p">)</span>

<span class="n">epochs</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">device</span> <span class="o">=</span> <span class="sh">"</span><span class="s">cpu</span><span class="sh">"</span>
<span class="n">net</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>

<span class="n">epoch_iter</span> <span class="o">=</span> <span class="nf">trange</span><span class="p">(</span><span class="n">epochs</span><span class="p">)</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="n">epoch_iter</span><span class="p">:</span>
    <span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">batch</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
        <span class="n">log_prob</span> <span class="o">=</span> <span class="n">net</span><span class="p">.</span><span class="nf">log_prob</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>

        <span class="n">loss</span> <span class="o">=</span> <span class="o">-</span><span class="n">log_prob</span><span class="p">.</span><span class="nf">mean</span><span class="p">()</span>
        <span class="n">losses</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">loss</span><span class="p">.</span><span class="nf">item</span><span class="p">())</span>

        <span class="n">loss</span><span class="p">.</span><span class="nf">backward</span><span class="p">()</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="nf">step</span><span class="p">()</span>
        <span class="n">optimizer</span><span class="p">.</span><span class="nf">zero_grad</span><span class="p">()</span>
    
    <span class="n">epoch_iter</span><span class="p">.</span><span class="nf">set_description</span><span class="p">(</span><span class="sa">f</span><span class="sh">"</span><span class="s">Loss: </span><span class="si">{</span><span class="n">np</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span><span class="si">:</span><span class="p">.</span><span class="mi">03</span><span class="n">f</span><span class="si">}</span><span class="sh">"</span><span class="p">)</span>

</code></pre></div></div>

<p>Finally, we plot the learned density in Fig. 4. Note that the density is only defined on the interval $[1.0, 2.5]$; points outside the interval evaluate to $0$ due to clamping by Pyro.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/normalizing_flow_bounded_domain_files/normalizing_flow_bounded_domain_16_1.png" alt="Figure 4: The distribution learned by our normalizing flow model." width="85%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 4: The distribution learned by our normalizing flow model.
    </figcaption>
</figure>

<p>While the result is not perfect, this gives a flexible framework for learning distributions with compact support. The result could likely be improved considerably by using more expressive transformations in the “base” flow (in fact, any existing invertible and differentiable transformation can be used!) and by increasing the depth of the flow.</p>

<h2 id="conclusion">Conclusion</h2>
<p>Overall, we have seen how one can learn a distribution with compact support using normalizing flows by leveraging some simple transformations in the final layers and demonstrated some proof-of-concept results on a 1D toy example.</p>

<p>Feel free to check out the full notebook this blog post is based on <a href="https://gist.github.com/selflein/d8ff4b40142b5b8c4b32775fd04d8797">here</a>.</p>

<h2 id="references">References</h2>
<ol class="bibliography"><li><span id="pmlr-v37-rezende15"><span style="font-variant: small-caps">Rezende, D. and Mohamed, S.</span> 2015. Variational Inference with Normalizing Flows. <i>Proceedings of the 32nd International Conference on Machine Learning</i>, PMLR, 1530–1538.</span></li>
<li><span id="NEURIPS2018-d139db6a"><span style="font-variant: small-caps">Kingma, D.P. and Dhariwal, P.</span> 2018. Glow: Generative Flow with Invertible 1x1 Convolutions. <i>Advances in Neural Information Processing Systems</i>, Curran Associates, Inc.</span></li>
<li><span id="dinh2015nice"><span style="font-variant: small-caps">Dinh, L., Krueger, D., and Bengio, Y.</span> 2015. NICE: Non-linear Independent Components Estimation. <i>arXiv preprint arXiv:1410.8516</i>.</span></li></ol>]]></content><author><name>Sven Elflein</name></author><category term="Density Estimation" /><category term="Deep Learning" /><summary type="html"><![CDATA[Normalizing Flows [Rezende and Mohamed 2015] are powerful density estimators that have been shown to learn complex distributions, e.g., of natural images [Kingma and Dhariwal 2018].]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/normalizing_flow_bounded_domain_files/embedded_compact_dist_thumb.png" /><media:content medium="image" url="/normalizing_flow_bounded_domain_files/embedded_compact_dist_thumb.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry><entry><title type="html">Why did my Neural Network do that?</title><link href="/restricting-the-flow-review" rel="alternate" type="text/html" title="Why did my Neural Network do that?" /><published>2020-08-12T00:00:00+00:00</published><updated>2020-08-12T00:00:00+00:00</updated><id>/restricting-the-flow-review</id><content type="html" xml:base="/restricting-the-flow-review"><![CDATA[<p>This is a blog post about the paper “Restricting the Flow: Information Bottlenecks for Attribution” by Karl Schulz, Leon Sixt, Federico Tombari and Tim Landgraf published at ICLR 2020.</p>

<h2 id="introduction">Introduction</h2>
<p>With Neural Networks being applied to more and more domains, the explainability of these models is receiving increasing attention. More traditional machine learning approaches like decision trees and Random Forests offer some interpretability in terms of the input features. Today’s Deep Neural Networks, in contrast, rely on high-dimensional embeddings that are hardly interpretable by humans. The line of research grouped under the term “Attribution” therefore tries to relate the final output of a Neural Network back to its input by identifying the parts of the input most relevant for the model’s decision.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/intro-figure.png" alt="Figure 1: Attribution map obtained for classification network." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 1: Attribution map obtained for classification network.
    </figcaption>
</figure>

<p>The explainability of neural networks is especially relevant for safety-critical applications, e.g., in medicine, where one needs to make sure that the model relies on features that are actually relevant for diagnosing a condition. Further, employing models in finance creates a need for justifiable predictions, where the decisions of the model can be attributed to specific parts of the input data. Finally, attribution maps can help to identify cases where a model bases its decision on features that do not generalize, indicating a bias in the dataset. For example, a model might use the presence of rails as the most distinctive feature for classifying an object as a train. This will not generalize to images of trains without rails. Attribution methods can uncover such behavior by highlighting the parts of the input the model focuses on most.</p>

<h2 id="background">Background</h2>
<p>The method uses information theory to measure the flow of information through the network. This makes it possible to judge which features are most relevant for the output of the model. We therefore first give a short introduction to the notions of information theory needed to understand the paper.</p>

<h3 id="entropy">Entropy</h3>
<p>The central concept of information theory is the notion of entropy defined as</p>

\[H(X)=-\mathbb{E}_{P_X}(\log P_X)\]

<p>where \(X\) is a random variable and \(P_X\) the corresponding probability distribution. Entropy is usually introduced as a measure of the uncertainty of a random variable. It is low if a few values occur with high probability. It is high if many values occur with roughly equal probability, since then there is more uncertainty about which value the random variable will produce. When the logarithm with base 2 is used, the unit is bits. Equivalently, entropy can be seen as a measure of information: it is the minimum average number of bits needed to encode the values generated by the random variable.</p>
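<p>These properties can be illustrated with a small Python sketch that computes the entropy in bits of a few discrete distributions, each given as a list of probabilities. The helper name <code class="language-plaintext highlighter-rouge">entropy_bits</code> is hypothetical and not part of the paper.</p>

```python
import math


def entropy_bits(probs):
    """Entropy H(X) = -sum_x p(x) log2 p(x) of a discrete distribution, in bits."""
    # Terms with p(x) = 0 contribute nothing (0 * log 0 is taken to be 0)
    return -sum(p * math.log2(p) for p in probs if p > 0)


print(entropy_bits([0.5, 0.5]))                # fair coin: 1.0 bit
print(entropy_bits([0.25, 0.25, 0.25, 0.25]))  # four equally likely values: 2.0 bits
print(entropy_bits([0.97, 0.01, 0.01, 0.01]))  # one dominant value: about 0.24 bits
```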

<h3 id="mutual-information">Mutual Information</h3>
<p>The mutual information is based on the concept of entropy and measures how much information two random variables have in common. It is defined as</p>

\[\begin{aligned}
I(X,Y) &amp;= H(X)-H(X|Y) \\
       &amp;= D_{KL}(P_{XY}(x, y) || P_X(x) P_Y(y))
\end{aligned}\]

<p>where \(X\), \(Y\) are random variables with joint distribution \(P_{XY}\) and marginal distributions \(P_X\), \(P_Y\). Both lines of the formula match the intuition. The entropy of \(X\) minus the entropy of \(X\) conditioned on \(Y\) is the reduction in uncertainty about \(X\) gained by observing \(Y\). The Kullback-Leibler divergence between the joint distribution and the product of the marginal distributions measures how far \(X\) and \(Y\) are from being independent. An example explaining the concept can be found in Figure 2.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/conditional-entropy.png" alt="Figure 2: Example for the drop in uncertainty when observing a second variable indicating that the mutual information between both random variables is high. In other words, observing \(Y\) tells us something about the value \(X\) is going to take." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 2: Example for the drop in uncertainty when observing a second variable indicating that the mutual information between both random variables is high. In other words, observing \(Y\) tells us something about the value \(X\) is going to take.
    </figcaption>
</figure>
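<p>Both expressions in the definition can be checked numerically for discrete random variables. The following sketch takes a joint distribution \(P_{XY}\) as a 2D list and evaluates \(H(X) - H(X|Y)\) and \(D_{KL}(P_{XY} \Vert P_X P_Y)\). The two results should agree up to floating-point error. The helper <code class="language-plaintext highlighter-rouge">mutual_information_bits</code> is our own, not part of the paper.</p>

```python
import math


def mutual_information_bits(joint):
    """I(X, Y) in bits for a joint distribution given as a 2D list joint[x][y].

    Returns the two forms of the definition: (H(X) - H(X|Y), D_KL(P_XY || P_X P_Y)).
    """
    p_x = [sum(row) for row in joint]
    p_y = [sum(col) for col in zip(*joint)]
    cells = [(i, j) for i in range(len(p_x)) for j in range(len(p_y)) if joint[i][j] > 0]

    h_x = -sum(p * math.log2(p) for p in p_x if p > 0)
    # H(X|Y) = -sum_{x,y} p(x, y) log2 p(x | y), with p(x | y) = p(x, y) / p(y)
    h_x_given_y = -sum(joint[i][j] * math.log2(joint[i][j] / p_y[j]) for i, j in cells)
    kl = sum(joint[i][j] * math.log2(joint[i][j] / (p_x[i] * p_y[j])) for i, j in cells)
    return h_x - h_x_given_y, kl


# Y determines X: observing Y removes all uncertainty about X
print(mutual_information_bits([[0.5, 0.0], [0.0, 0.5]]))      # (1.0, 1.0)
# X and Y independent: observing Y tells us nothing about X
print(mutual_information_bits([[0.25, 0.25], [0.25, 0.25]]))  # (0.0, 0.0)
```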

<h3 id="information-bottleneck">Information Bottleneck</h3>

<p>Let \(X\), \(Y\) be random variables, e.g., over images and labels, and let \(Z\) be a random variable introduced in between. The information bottleneck then states the following optimization problem</p>

\[\max I(Y,Z) - \beta I(X,Z).\]

<p>This corresponds to minimizing the information \(X\) and \(Z\) share while maximizing the information \(Z\) and \(Y\) share, with \(\beta\) controlling the trade-off between both terms (Figure 3). Overall, this restricts the flow of information between \(X\) and \(Y\) and pushes \(Z\) to “extract” the information of \(X\) that is most relevant to \(Y\).</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/information-bottleneck-concept.png" alt="Figure 3: Conceptual sketch of the information bottleneck where \(Z\) is the random variable used to introduce the bottleneck between \(X\) and \(Y\). \(I(Y,Z)\) is to be maximized and \(I(X,Z)\) to be minimized, as indicated by the thickness of the arrows between the random variables." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 3: Conceptual sketch of the information bottleneck where \(Z\) is the random variable used to introduce the bottleneck between \(X\) and \(Y\). \(I(Y,Z)\) is to be maximized and \(I(X,Z)\) to be minimized, as indicated by the thickness of the arrows between the random variables.
    </figcaption>
</figure>

<h2 id="methodology">Methodology</h2>
<p>Now that we have the necessary background to understand the paper, let’s move on to how the information bottleneck is integrated into Neural Networks to obtain attribution maps.</p>

<h3 id="information-bottleneck-in-neural-network">Information Bottleneck in Neural Network</h3>
<p>To generate attribution maps, the method adapts the information bottleneck to general Neural Network architectures. For this, a pretrained Neural Network is considered, e.g., for image classification. The information bottleneck is then injected at some layer by perturbing the output features of that layer with noise, thus restricting the flow of information through the network. This process is visualized for a Convolutional Neural Network in Figure 4.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/pipeline-overview.png" alt="Figure 4: Construction of the random variable \(Z\) in the information bottleneck within a Neural Network." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 4: Construction of the random variable \(Z\) in the information bottleneck within a Neural Network.
    </figcaption>
</figure>

<p>\(Z \in \mathbb{R}^{h \times w \times c}\) is constructed as an element-wise convex combination of the output feature map \(R\) and Gaussian noise \(\epsilon \sim N(\mu_R, \Sigma_R)\), where \(\mu_R, \Sigma_R\) are the estimated means and variances of the feature map. This combination is controlled by parameters \(\lambda \in \mathbb{R}^{h \times w \times c}\) with entries in \([0, 1]\), of the same dimensionality as the feature map. Thus, for \(\lambda = 0\) the input to the subsequent layers is pure noise, while for \(\lambda = 1\) the feature map is forwarded unperturbed. Since \(\lambda\) has the same dimensionality as the feature map, each feature can be blanked out individually. This later allows us to estimate the relevance of that particular feature.</p>
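<p>The construction of \(Z\) can be sketched in PyTorch as follows. This is a minimal sketch, not the authors’ implementation, and several details are assumptions. The feature maps are random stand-ins for the activations of a pretrained network. \(\mu_R\) and the diagonal \(\Sigma_R\) (represented by standard deviations) are estimated per element over that batch. The helper name <code class="language-plaintext highlighter-rouge">bottleneck</code> is hypothetical. Tensors use PyTorch’s channels-first layout \((c, h, w)\) rather than \(h \times w \times c\).</p>

```python
import torch

torch.manual_seed(0)

# Hypothetical feature maps R of shape (batch, c, h, w) from some layer of a pretrained network
features = torch.randn(64, 8, 7, 7) * 2.0 + 1.0

# Per-element mean and standard deviation of R, estimated over the batch
mu_r = features.mean(dim=0)
std_r = features.std(dim=0)


def bottleneck(r, lam, mu_r, std_r):
    """Z = lam * R + (1 - lam) * eps with eps ~ N(mu_R, diag(std_R^2)), element-wise."""
    eps = mu_r + std_r * torch.randn_like(r)
    return lam * r + (1.0 - lam) * eps


r = features[0]
z_open = bottleneck(r, torch.ones_like(r), mu_r, std_r)     # lambda = 1: R passes unchanged
z_closed = bottleneck(r, torch.zeros_like(r), mu_r, std_r)  # lambda = 0: pure noise
```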

<h3 id="optimizing-the-bottleneck">Optimizing the bottleneck</h3>

<p>This completes the construction of the bottleneck within the Neural Network. The question now becomes how to optimize the information bottleneck \(\max I(Y,Z) - \beta I(X,Z)\).</p>

<p>The above objective consists of two mutual information terms. The first one, \(I(Y,Z)\), measures the mutual information between the noisy feature map and the output of the neural network. In the classification setting considered here, maximizing \(I(Y,Z)\) corresponds to minimizing the cross-entropy loss \(L_{CE}(y, \hat{\tilde{y}})\), where \(\hat{\tilde{y}}\) is the prediction of the network using the perturbed feature map \(Z\) and \(y\) is the ground-truth label.</p>

<p>The second mutual information term \(I(X,Z)\) is in general intractable to compute exactly. This becomes clear when expanding the term using the definition of mutual information:</p>

\[I(R, Z) = \mathbb{E}_R\left[ D_{KL}(P(Z|R) \Vert P(Z)) \right]\]

<p>where we can replace \(X\) with \(R\), since \(R\) is a deterministic function of \(X\) and \(Z\) depends on \(X\) only through \(R\). Then \(P(Z)\) has to be computed as</p>

\[p(z) = \int p(z|r)\, p(r)\, dr\]

<p>where we need to integrate over the distribution of feature maps, for which no analytic expression exists since it depends on the data distribution.</p>

<p>Thus, the paper proposes to approximate \(P(Z)\) with a variational distribution \(Q(Z)\), chosen to be a Normal distribution with mean \(\mu\) and diagonal covariance matrix \(\Sigma\). The authors justify this choice with the argument that activations within a neural network are usually approximately normally distributed. The derivation of the upper bound can be found in the appendix of the paper. With this, one can rewrite \(I(R,Z)\) in terms of \(Q(Z)\) as</p>

\[I(R, Z)= \mathbb{E}_R\left[ D_{KL}(P(Z|R) \Vert Q(Z)) \right] - D_{KL}(P(Z) \Vert Q(Z)).\]

<p>As the KL-divergence is always \(\geqslant 0\), the first term</p>

\[L_I = \mathbb{E}_R\left[ D_{KL}(P(Z|R) \Vert Q(Z)) \right]\]

<p>defines an upper bound on the actual mutual information \(I(R, Z)\). The paper optimizes \(L_I\), which overestimates the mutual information. Consequently, features that are estimated to carry zero information are guaranteed to have no influence on the network prediction, whereas regions attributed non-zero relevance might not actually contribute to the network output.</p>

<p>The overall loss function therefore can be written as</p>

\[L = L_{CE} + \beta L_I\]

<p>consisting of the cross-entropy loss and the upper-bound derived above.</p>
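<p>Combining both terms, a minimal NumPy sketch of \(L = L_{CE} + \beta L_I\) for a single sample could look as follows. It assumes the per-feature KL values are already available as an array (here a hypothetical constant <code>capacity_map</code>) and uses \(\beta = \frac{10}{k}\) with \(k = hwc\), one of the settings evaluated in the paper; <code>ib_loss</code> is a hypothetical name.</p>

```python
import numpy as np

def cross_entropy(logits, label):
    """Cross-entropy of one sample via a numerically stable log-softmax."""
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]

def ib_loss(logits, label, capacity_map, beta):
    """L = L_CE + beta * L_I, where L_I sums the per-feature KL over all h*w*c features."""
    return cross_entropy(logits, label) + beta * capacity_map.sum()

logits = np.array([2.0, 0.5, -1.0])        # prediction using the noisy feature map Z
capacity_map = np.full((2, 2, 3), 0.1)     # hypothetical per-feature KL values
k = capacity_map.size                      # k = h * w * c
loss = ib_loss(logits, 0, capacity_map, beta=10.0 / k)
```

<p>Dividing \(\beta\) by \(k\) keeps the information term on a scale comparable to the cross-entropy, since \(L_I\) grows with the number of features.</p>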

<h3 id="obtaining-the-attribution-map">Obtaining the attribution map</h3>
<p>Remember that the noisy feature map \(Z \in \mathbb{R}^{h \times w \times c}\) is defined as \(Z = \lambda R + (1-\lambda) \epsilon\), where \(h\) and \(w\) are the spatial dimensions of the feature map, \(c\) is the number of channels, and \(\lambda \in \mathbb{R}^{h \times w \times c}\) controls the flow of information. The goal is therefore to find the \(\lambda\) that minimizes the loss function \(L\). The paper proposes two methods for this task: the per-sample bottleneck and the readout bottleneck.</p>

<h4 id="per-sample-bottleneck">Per-sample bottleneck</h4>
<p>The per-sample bottleneck operates on each input image independently. To parameterize \(\lambda \in \mathbb{R}^{h \times w \times c}\), a sigmoid activation is applied to \(\alpha \in \mathbb{R}^{h \times w \times c}\); this keeps the values of \(\lambda\) in the required domain \([0, 1]\) and avoids clipping or projected gradient methods. Afterwards, the result is blurred with a fixed Gaussian kernel to prevent artifacts introduced by pooling operations within the network and to ensure a smooth attribution mask (Figure 5). \(L\) is then optimized by computing gradients w.r.t. \(\alpha\) and applying the Adam optimizer. While this allows the per-sample bottleneck to be inserted into an existing pre-trained network without additional training, the optimization of \(\alpha\) has to be run for every new input sample.</p>
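<p>The parameterization \(\lambda = \text{blur}(\text{sigmoid}(\alpha))\) can be sketched in NumPy as below. The kernel width <code>sigma</code>, the <code>radius</code>, the edge padding and the example value of \(\alpha\) are assumptions for illustration, not the paper's settings. The Adam optimization of \(\alpha\) itself would require automatic differentiation (e.g., PyTorch) and is not shown.</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gaussian_kernel_1d(sigma, radius):
    x = np.arange(-radius, radius + 1, dtype=float)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()                      # normalized, so blurring is a weighted average

def blur_spatial(m, sigma=1.0, radius=2):
    """Separable Gaussian blur over the two spatial axes of an (h, w, c) array."""
    k = gaussian_kernel_1d(sigma, radius)
    padded = np.pad(m, ((radius, radius), (radius, radius), (0, 0)), mode="edge")
    h, w, _ = m.shape
    # Blur along the height axis, then along the width axis (channels untouched).
    tmp = sum(k[i] * padded[i:i + h, :, :] for i in range(len(k)))
    return sum(k[j] * tmp[:, j:j + w, :] for j in range(len(k)))

def lambda_from_alpha(alpha, sigma=1.0, radius=2):
    """Sigmoid keeps values in [0, 1]; the blur then smooths the mask spatially."""
    return blur_spatial(sigmoid(alpha), sigma, radius)

alpha = np.full((8, 8, 4), 5.0)   # example value: lambda close to 1, bottleneck almost open
lam = lambda_from_alpha(alpha)
```

<p>Since the blur is a convex combination of values in \([0, 1]\), \(\lambda\) stays in \([0, 1]\) without any clipping.</p>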

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/per-sample-bottleneck.png" alt="Figure 5: Conceptual visualization of Per-sample bottleneck." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 5: Conceptual visualization of Per-sample bottleneck.
    </figcaption>
</figure>

<h4 id="readout-bottleneck">Readout bottleneck</h4>
<p>In the readout bottleneck, a second neural network is trained on the entire dataset to directly regress \(\alpha\), while the sigmoid activation and blurring remain identical to the per-sample bottleneck. The readout network takes multi-scale features from the pretrained network, resizes them to a common size, and processes them with a series of 1x1 convolutions, as visualized in Figure 6. This approach requires training a second network; however, no additional optimization is necessary at inference time.</p>
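<p>The data flow of such a readout head can be sketched in NumPy: resize the multi-scale features to a common spatial size, concatenate them along the channel axis, and apply 1x1 convolutions, which are simply per-pixel linear maps. Nearest-neighbour resizing, the ReLU between layers, the layer widths and the random (untrained) weights are all assumptions for illustration; the paper's exact readout architecture is not reproduced here.</p>

```python
import numpy as np

def resize_nearest(f, out_h, out_w):
    """Nearest-neighbour resize of an (h, w, c) feature map."""
    h, w, _ = f.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return f[rows][:, cols]

def conv1x1(x, weight, bias):
    """A 1x1 convolution is a per-pixel linear map: (h, w, c_in) -> (h, w, c_out)."""
    return np.einsum("hwi,io->hwo", x, weight) + bias

def readout_alpha(features, params, out_h, out_w):
    """Regress alpha from multi-scale features (hypothetical helper)."""
    x = np.concatenate([resize_nearest(f, out_h, out_w) for f in features], axis=-1)
    for i, (W, b) in enumerate(params):
        x = conv1x1(x, W, b)
        if i != len(params) - 1:
            x = np.maximum(x, 0.0)   # ReLU between layers (assumed)
    return x

rng = np.random.default_rng(0)
# Hypothetical features from three depths; the bottleneck layer has 8x8x4 features.
features = [rng.normal(size=(8, 8, 4)), rng.normal(size=(4, 4, 6)), rng.normal(size=(2, 2, 8))]
params = [(rng.normal(scale=0.1, size=(18, 16)), np.zeros(16)),
          (rng.normal(scale=0.1, size=(16, 4)), np.zeros(4))]
alpha = readout_alpha(features, params, 8, 8)   # same shape as the bottleneck feature map
```

<p>After training, a single forward pass of this head yields \(\alpha\) for a new input, which is then passed through the same sigmoid and blur as in the per-sample bottleneck.</p>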

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/readout-bottleneck.png" alt="Figure 6: Conceptual visualization of Readout bottleneck." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 6: Conceptual visualization of Readout bottleneck.
    </figcaption>
</figure>

<p>Finally, after having obtained \(\lambda\), the attribution map at each spatial position \((x, y)\) is obtained by summing the per-feature KL divergences over the channels:</p>

\[\sum_{i=1}^{c} D_{KL}\left(P(Z_{[x, y, i]} | R_{[x, y, i]}) \Vert Q(Z_{[x, y, i]})\right).\]

<p>This results in an attribution heatmap of size \((h, w)\), i.e., the same spatial dimensions as the feature map where the bottleneck was inserted, where every spatial location contains the summed relevance of the features along the channel dimension in bits. The attribution map is then resized to the original image dimensions.</p>
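<p>Assuming the per-feature KL values are available in nats as an \((h, w, c)\) array, the channel sum, the conversion to bits and the resize to image resolution can be sketched as follows. Nearest-neighbour upsampling is an assumption made for simplicity; the paper or its library may use a different interpolation.</p>

```python
import numpy as np

def attribution_map(capacity, image_h, image_w):
    """Sum per-feature KL (nats) over channels, convert to bits, upsample to image size."""
    bits = capacity.sum(axis=-1) / np.log(2.0)   # (h, w) map, relevance in bits
    h, w = bits.shape
    rows = np.arange(image_h) * h // image_h
    cols = np.arange(image_w) * w // image_w
    return bits[rows][:, cols]                   # nearest-neighbour resize

# Example: every feature of a 7x7x4 map carries ln(2) nats = 1 bit of information.
heatmap = attribution_map(np.full((7, 7, 4), np.log(2.0)), 224, 224)
```

<p>In this example each spatial location of the resulting \(224 \times 224\) heatmap holds 4 bits, one per channel.</p>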

<h2 id="results--discussion">Results &amp; Discussion</h2>
<p>For the experiments, two pretrained image classification networks, VGG-16 [2] and ResNet-50 [1], are chosen because together they cover a wide variety of architecture design choices, e.g., max-pooling, skip-connections, and a low and a high number of layers. The authors argue that this makes it less likely that the proposed method only works for a subset of neural network architectures. Evaluation is performed on the ImageNet dataset. For the per-sample bottleneck, 10 iterations of the Adam optimizer with learning rate 1 are used to optimize \(L\). To stabilize the optimization, a single sample is copied 10 times and different noise is applied to each copy. Further, the hyperparameter \(\beta\) weighting the two loss terms \(L_{CE}, L_I\) is scaled relative to \(k=hwc\), since \(L_I\) sums over all dimensions and is therefore significantly larger than \(L_{CE}\). For the evaluation of the per-sample bottleneck, values of \(\beta \in \{ \frac{1}{k},  \frac{10}{k},  \frac{100}{k} \} \) are considered. The readout bottleneck is trained with the best-performing setting \(\beta = \frac{10}{k}\) on the ILSVRC12 dataset for only 20 epochs.</p>

<p>For baselines the following approaches are considered:</p>
<ul>
  <li>Naive Baselines
    <ul>
      <li>Random Attribution</li>
      <li>Occlusion with patch sizes 8x8 and 14x14 [3]</li>
      <li>Gradient maps [4]</li>
    </ul>
  </li>
  <li>Gradient-based methods
    <ul>
      <li>SmoothGrad [5]</li>
      <li>Integrated Gradients [6]</li>
    </ul>
  </li>
  <li>Modified propagation rules
    <ul>
      <li>PatternAttribution [7]</li>
      <li>GuidedBP [8]</li>
      <li>LRP [9]</li>
    </ul>
  </li>
  <li>Other
    <ul>
      <li>GradCAM [10]</li>
      <li>GuidedGrad-CAM [11]</li>
    </ul>
  </li>
</ul>

<h3 id="qualitative-results">Qualitative Results</h3>

<p>The paper first qualitatively studies the effect of inserting the bottleneck at different depths within the network, which is visualized for VGG-16 in the bottom row of Figure 7. Since the attribution map has the same spatial size as the feature map and feature maps shrink at deeper layers, the attribution map upscaled to the original image resolution becomes less and less localized the later the bottleneck is inserted. Secondly, qualitative results for different settings of \(\beta\) are visualized in the top row of Figure 7. For higher values of \(\beta\), less information can pass through the bottleneck, and thus the attribution is more concentrated on the relevant input regions. However, if \(\beta\) is set too high, no information passes through and no attribution map can be obtained. Figure 8 shows a visual comparison of the attribution maps produced by different methods.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/comparision_hyperparameters.png" alt="Figure 7: Top row compares attribution maps for different settings of the hyperparameter \(\beta\). Bottom row shows comparison of attribution map when inserting the bottleneck at different depths. The unit of the scale on the right of each image is bits." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 7: Top row compares attribution maps for different settings of the hyperparameter \(\beta\). Bottom row shows comparison of attribution map when inserting the bottleneck at different depths. The unit of the scale on the right of each image is bits.
    </figcaption>
</figure>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/comparison_methods.png" alt="Figure 8: Comparison of attribution maps produced by baselines and the proposed approach with per-sample and readout bottleneck." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 8: Comparison of attribution maps produced by baselines and the proposed approach with per-sample and readout bottleneck.
    </figcaption>
</figure>

<h3 id="quantitative-results">Quantitative Results</h3>
<p>As there is no standard evaluation metric for attribution methods, the paper relies on a combination of existing metrics and proposes two new ones: the image degradation score and the bounding box evaluation. The main objective of these metrics is to measure how well the attribution maps generated by an attribution method explain the neural network output. In particular, this means that the attribution output should depend on the model, i.e., attribution maps should change when the network parameters change.</p>

<h4 id="sanity-check">Sanity Check</h4>
<p>To verify that the attribution maps produced by a method depend on the model parameters, the following sanity check is employed: the weights of the neural network are gradually re-initialized from the last layer to the first, and SSIM [14] is used to measure how much the attribution map of the partly re-initialized network differs from the original one. For a valid attribution method, the attribution should change when parts of the network are re-initialized. As the top row of Figure 10 shows, the proposed method passes the sanity check.</p>
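<p>The cascading re-initialization can be sketched in NumPy on a toy linear network with a simple gradient-times-input attribution. Two simplifications are assumptions of this sketch: SSIM is computed over the whole map as a single window (the SSIM of [14] averages over local windows), and the toy network and attribution stand in for a real model and attribution method. A valid method should produce scores clearly below 1 as more layers are re-initialized.</p>

```python
import numpy as np

def global_ssim(a, b, data_range=1.0):
    """SSIM over the whole map as one window (simplified, not the windowed SSIM)."""
    c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
    mu_a, mu_b = a.mean(), b.mean()
    cov = ((a - mu_a) * (b - mu_b)).mean()
    return ((2 * mu_a * mu_b + c1) * (2 * cov + c2)) / (
        (mu_a ** 2 + mu_b ** 2 + c1) * (a.var() + b.var() + c2))

def normalize(m):
    """Rescale a map to [0, 1] so maps of different networks are comparable."""
    m = m - m.min()
    return m / (m.max() + 1e-12)

def sanity_check(attribute, layers, reinit):
    """Re-initialize layers from last to first; return one SSIM score per step.

    attribute(layers) -> 2D attribution map; reinit(layer) -> freshly initialized layer.
    """
    reference = normalize(attribute(layers))
    current = list(layers)
    scores = []
    for idx in reversed(range(len(current))):
        current[idx] = reinit(current[idx])
        scores.append(global_ssim(reference, normalize(attribute(current))))
    return scores

# Toy linear network f(x) = v^T W_3 W_2 W_1 x, attributed with gradient times input.
rng = np.random.default_rng(0)
x = rng.normal(size=16)
v = rng.normal(size=16)
layers = [rng.normal(size=(16, 16)) for _ in range(3)]

def grad_times_input(ls):
    g = v
    for W in reversed(ls):
        g = W.T @ g   # backpropagate through each linear layer
    return np.abs(g * x).reshape(4, 4)

scores = sanity_check(grad_times_input, layers, lambda W: rng.normal(size=W.shape))
```

<p>An attribution method that ignores the model entirely would yield an SSIM of 1 at every step and therefore fail the check.</p>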

<h4 id="sensitivity-n">Sensitivity-N</h4>
<p>Sensitivity-N measures the correlation between the removed attribution mass and the drop in the target class score when randomly masking N pixels of the input image (as demonstrated conceptually in Figure 9). Intuitively, a well-performing attribution method assigns high attribution scores to regions relevant for the final output, so removing these regions should reduce the score more strongly than removing inputs with low attribution. The Sensitivity-N results for different amounts of removed pixels and tile sizes can be found in the bottom row of Figure 10. Since masking individual 1x1 tiles does not yield distinguishable results, the authors also report results when removing tiles of size 8x8. Here, the proposed methods outperform the baselines when more than 2% of the pixels are masked. The high performance of the Occlusion baseline for small amounts of replaced pixels results from this method using exactly this strategy, replacing the pixels of a small tile and measuring the drop in performance, to assess the attribution score of a region.</p>
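<p>A minimal NumPy sketch of Sensitivity-N for one image: repeatedly mask random tiles covering N pixels in total, and correlate the removed attribution with the drop in the model score. The single-channel toy image, the replacement value <code>baseline=0.0</code>, the toy linear model and the helper name <code>sensitivity_n</code> are assumptions for illustration.</p>

```python
import numpy as np

def sensitivity_n(model, image, attribution, n_tiles, tile, n_samples, rng, baseline=0.0):
    """Pearson correlation between removed attribution and score drop.

    Each sample masks n_tiles random tiles of size tile x tile, i.e. N = n_tiles * tile**2 pixels.
    """
    h, w = attribution.shape
    grid_h, grid_w = h // tile, w // tile
    full_score = model(image)
    removed_attr, score_drop = [], []
    for _ in range(n_samples):
        chosen = rng.choice(grid_h * grid_w, size=n_tiles, replace=False)
        mask = np.zeros((h, w), dtype=bool)
        for t in chosen:
            r, c = divmod(int(t), grid_w)
            mask[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile] = True
        masked = np.where(mask, baseline, image)
        removed_attr.append(attribution[mask].sum())
        score_drop.append(full_score - model(masked))
    return np.corrcoef(removed_attr, score_drop)[0, 1]

rng = np.random.default_rng(0)
weights = rng.normal(size=(16, 16))
image = rng.normal(size=(16, 16))
model = lambda img: float((weights * img).sum())   # toy linear "network"
exact_attr = weights * image                        # exact contributions for baseline 0
corr = sensitivity_n(model, image, exact_attr, n_tiles=4, tile=2, n_samples=50, rng=rng)
```

<p>For this linear toy model the attribution exactly matches each pixel's contribution, so the correlation is 1; for real networks and attribution methods it is lower.</p>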

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/sensitivity-n.png" alt="Figure 9: Example of two removed tiles from an input image on the left and the respective sum of removed attribution versus the change in logit score of the target class plotted on the right. Sensitivity-N measures the correlation between both quantities when removing tiles that cover N pixels in total." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 9: Example of two removed tiles from an input image on the left and the respective sum of removed attribution versus the change in logit score of the target class plotted on the right. Sensitivity-N measures the correlation between both quantities when removing tiles that cover N pixels in total.
    </figcaption>
</figure>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/paper-results-sanity-check-sens-n.png" alt="Figure 10: Top row shows results of the sanity check for ResNet-50 and VGG-16. The bottom row plots Sensitivity-N results for increasing amounts of replaced pixels with tile sizes of 1x1 and 8x8." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 10: Top row shows results of the sanity check for ResNet-50 and VGG-16. The bottom row plots Sensitivity-N results for increasing amounts of replaced pixels with tile sizes of 1x1 and 8x8.
    </figcaption>
</figure>

<h4 id="bounding-box">Bounding Box</h4>
<p>As an additional metric, the paper proposes to use the bounding boxes provided for parts of the dataset. The idea is that the highest attributed pixels should lie within the bounding box, as it delimits the region of the object to classify. The metric is computed by selecting the \(n\) highest attributed pixels in the image, where \(n\) is the number of pixels within the bounding box, and then calculating the fraction of these pixels that lie within the bounding box. This yields a score between 0 and 1, where 1 indicates attribution highly localized on the relevant object. As the third column of Table 1 shows, the proposed method outperforms other methods by a large margin. Here, the per-sample bottleneck outperforms the readout bottleneck.</p>
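<p>The bounding box score described above can be sketched in a few lines of NumPy. The box format <code>(row0, col0, row1, col1)</code> with exclusive end indices and the helper name <code>bbox_score</code> are assumptions; ties among equal attribution values are broken arbitrarily by the sort.</p>

```python
import numpy as np

def bbox_score(attribution, box):
    """Fraction of the n highest-attributed pixels inside the box, n = box area.

    box = (row0, col0, row1, col1) with exclusive end indices (assumed format).
    """
    r0, c0, r1, c1 = box
    inside = np.zeros(attribution.shape, dtype=bool)
    inside[r0:r1, c0:c1] = True
    n = int(inside.sum())
    top = np.argsort(attribution, axis=None)[::-1][:n]   # flat indices of the top-n pixels
    return inside.ravel()[top].mean()

attr = np.zeros((10, 10))
attr[2:5, 3:7] = 1.0                       # all attribution mass inside the box
score_inside = bbox_score(attr, (2, 3, 5, 7))
```

<p>Here the 12 highest attributed pixels coincide with the 12 box pixels, giving a score of 1. Note that the score only depends on the attribution map and the box, not on the network, which is the limitation discussed in the review below.</p>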

<h4 id="image-degradation">Image Degradation</h4>
<p>The image degradation metrics work similarly to Sensitivity-N but take a more structured approach to removing tiles from the image. The MoRF (Most Relevant First) curve (Figure 11) is computed by replacing tiles with the highest attribution first while monitoring the target class score. Ideally, the target class score drops strongly at the beginning, since the highest attributed regions are removed first.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/merf.png" alt="Figure 11: Visualization of the unnormalized MoRF curve on a single, gradually more masked input image on the left starting from the most attributed regions. A fast decay is expected for a sound attribution method." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 11: Visualization of the unnormalized MoRF curve on a single, gradually more masked input image on the left starting from the <b>most</b> attributed regions. A fast decay is expected for a sound attribution method.
    </figcaption>
</figure>

<p>However, MoRF is prone to producing out-of-distribution samples, so a drop in the score may not be caused by the removal of a relevant input region. To tackle this problem, the LeRF (Least Relevant First) curve (Figure 12) takes the opposite direction and removes the tiles with the lowest attribution scores first. Here, the optimal result is a slow decrease in the target class score, since regions deemed unimportant by the attribution method are removed first.</p>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/lerf.png" alt="Figure 12: Visualization of the unnormalized LeRF curve on a single, gradually more masked input image on the left starting from the least attributed regions. A slow decay is expected for a sound attribution method." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 12: Visualization of the unnormalized LeRF curve on a single, gradually more masked input image on the left starting from the <b>least</b> attributed regions. A slow decay is expected for a sound attribution method.
    </figcaption>
</figure>

<p>Based on both curves, the paper introduces the degradation score, defined as the area between the MoRF and LeRF curves (visualized in Figure 13), which yields a single scalar metric reported in Table 1. The approach outperforms the baselines on this metric for VGG-16 and is competitive for ResNet-50. The readout bottleneck performs worse than the per-sample bottleneck.</p>
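<p>The MoRF and LeRF curves and the area between them can be sketched in NumPy as follows. Tiles are ranked by their summed attribution, replaced one by one with a constant <code>baseline</code> value, and the model score is recorded after each step. The replacement value, the toy linear model, the helper names and the missing normalization of the curves (the figures above show unnormalized curves, and the paper may normalize them before computing the score) are all assumptions of this sketch.</p>

```python
import numpy as np

def degradation_curves(model, image, attribution, tile, baseline=0.0):
    """MoRF and LeRF curves: replace tiles in descending / ascending order of summed attribution."""
    h, w = attribution.shape
    gh, gw = h // tile, w // tile
    tile_attr = attribution[:gh * tile, :gw * tile].reshape(gh, tile, gw, tile).sum(axis=(1, 3))
    order = np.argsort(tile_attr, axis=None)   # least relevant tile first
    curves = {}
    for name, seq in (("morf", order[::-1]), ("lerf", order)):
        x = image.astype(float).copy()
        scores = [model(x)]
        for t in seq:
            r, c = divmod(int(t), gw)
            x[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile] = baseline
            scores.append(model(x))
        curves[name] = np.array(scores)
    return curves["morf"], curves["lerf"]

def degradation_score(morf, lerf):
    """Area between the LeRF and MoRF curves on a [0, 1] x-axis (trapezoidal rule)."""
    d = np.asarray(lerf) - np.asarray(morf)
    xs = np.linspace(0.0, 1.0, len(d))
    return float(np.sum(0.5 * (d[1:] + d[:-1]) * np.diff(xs)))

rng = np.random.default_rng(0)
weights = rng.uniform(0.1, 1.0, size=(8, 8))
image = np.ones((8, 8))
model = lambda img: float((weights * img).sum())   # toy linear "network"
morf, lerf = degradation_curves(model, image, weights * image, tile=2)
score = degradation_score(morf, lerf)
```

<p>With an attribution that exactly matches the toy model, MoRF falls fastest and LeRF slowest, so the area between them, and thus the score, is positive; a larger area indicates a better attribution.</p>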

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/lerf-morf-curve.png" alt="Figure 13: Visualization of LeRF and MoRF curve as well as the proposed degradation score as the area between both curves." width="50%" class="centerimg" />
    <figcaption style="text-align: center;">
        Figure 13: Visualization of LeRF and MoRF curve as well as the proposed degradation score as the area between both curves.
    </figcaption>
</figure>

<figure class="image" style="padding: 0 0.5em 0 0.5em;">
    <img src="/assets/img/restricting-the-flow-review/table-comparison.png" alt="Table 1: Evaluation results for image degradation metrics on ResNet-50 and VGG-16 for tile sizes of 8x8 and 14x14 in the first two columns and bounding box scores in third column." width="" class="centerimg" />
    <figcaption style="text-align: center;">
        Table 1: Evaluation results for image degradation metrics on ResNet-50 and VGG-16 for tile sizes of 8x8 and 14x14 in the first two columns and bounding box scores in third column.
    </figcaption>
</figure>

<h2 id="conclusion">Conclusion</h2>
<p>The paper proposes to leverage information theory to find the input regions that are most informative for the output of the model. For this, the information bottleneck is adopted and inserted into a neural network at a particular layer. Two options are proposed to parameterize the bottleneck: the per-sample bottleneck requires an optimization for each input sample but can be employed on an existing pretrained neural network without additional training. The readout bottleneck requires training a second network on the entire dataset; however, at inference time only a single forward pass is necessary to obtain the attribution map. The authors recommend the per-sample bottleneck, as it outperforms the readout bottleneck in the evaluation and is more flexible. Overall, the method guarantees that regions which are attributed zero relevance are not required for the network's classification, and it measures the information of regions in the well-established unit of bits. Unlike other attribution methods, it also makes no assumptions about the architecture or activations of the neural network.</p>

<h2 id="own-review">Own Review</h2>
<p>Overall, the paper is very well written, and the derivation of the method is easy to follow. Further, the approach is grounded in information theory, which is well studied and understood; this adds credibility to the produced attribution maps, where the relevance of features is measured in bits. The per-sample bottleneck is flexible and can be integrated into any existing pretrained neural network. For this, the authors have even created a library for easy use of the approach within existing PyTorch and TensorFlow code [12]. The code for the experiments is open-source as well [13]. Finally, the paper provides an extensive comparison with existing methods on a variety of metrics from the literature and proposes new metrics on top.</p>

<p>However, there are also some issues with the paper. Firstly, choosing a Gaussian distribution with a diagonal covariance matrix as the variational approximation assumes that the features within the feature map are independently distributed. This is a strong assumption for features produced by convolutional filters with low stride, whose neighboring outputs are computed from overlapping inputs, and it might lead to a poor variational approximation of the true distribution. A poor approximation in turn leads to a loose upper bound on the mutual information, so regions which are actually irrelevant for the decision of the network could be attributed relevance. As a result, the produced attribution maps still have to be handled with care in regions with high attribution scores. However, the paper's claim that regions which are attributed zero information are irrelevant for the decision of the model remains valid. Secondly, the proposed bounding box evaluation does not measure how well an attribution method explains the network, since the metric does not depend on the network at all. For example, a method trained to produce bounding boxes that distributes attribution scores randomly within the bounding box would perform well on the proposed metric without explaining the network at all. As a final negative point, the sanity check evaluation does not retrain the readout bottleneck after re-initializing parts of the base network. This is required because the readout network depends on the model parameters, so without retraining, the values obtained in the sanity check are of limited significance for the readout bottleneck.</p>

<p>Lastly, one interesting open question is whether the proposed method generalizes to other tasks, e.g., semantic segmentation, and to other input domains, e.g., text, as only image classification is evaluated in the paper.</p>

<p>This concludes my first blog post (yay! 🥳). I hope you found it interesting, and if there are any questions or discussion points left, feel free to comment at the bottom or message me directly! Thanks for reading!</p>

<h2 id="references">References</h2>
<p>[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 2016.<br />
[2] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.<br />
[3] Marco Ancona, Enea Ceolini, Cengiz Öztireli, and Markus Gross. Towards better understanding of gradient-based attribution methods for deep neural networks. In International Conference on Learning Representations, 2018.<br />
[4] David Baehrens, Timon Schroeter, Stefan Harmeling, Motoaki Kawanabe, Katja Hansen, and Klaus-Robert Müller. How to explain individual classification decisions. Journal of Machine Learning Research, 2010.<br />
[5] Daniel Smilkov, Nikhil Thorat, Been Kim, Fernanda Viégas, and Martin Wattenberg. SmoothGrad: removing noise by adding noise. arXiv:1706.03825 [cs, stat], 2017.<br />
[6] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. Axiomatic attribution for deep networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, 2017.<br />
[7] Pieter-Jan Kindermans, Kristof T. Schütt, Maximilian Alber, Klaus-Robert Müller, Dumitru Erhan, Been Kim, and Sven Dähne. Learning how to explain neural networks: PatternNet and PatternAttribution. In International Conference on Learning Representations, 2018.<br />
[8] Jost Tobias Springenberg, Alexey Dosovitskiy, Thomas Brox, and Martin Riedmiller. Striving for Simplicity: The All Convolutional Net. arXiv e-prints, 2014.<br />
[9] Sebastian Bach, Alexander Binder, Grégoire Montavon, Frederick Klauschen, Klaus-Robert Müller, and Wojciech Samek. On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. PLoS ONE, 2015.<br />
[10] Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra. Grad-CAM: Visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE International Conference on Computer Vision, 2017.<br />
[11] Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. "Why should I trust you?": Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, ACM, 2016.<br />
[12] https://github.com/BioroboticsLab/IBA<br />
[13] https://github.com/BioroboticsLab/IBA-paper-code<br />
[14] Zhou Wang, Alan C Bovik, Hamid R Sheikh, Eero P Simoncelli, et al. Image quality assessment: From error visibility to structural similarity. IEEE transactions on image processing, 2004.</p>]]></content><author><name>Sven Elflein</name></author><category term="Explainability" /><category term="Attribution methods" /><category term="Deep Learning" /><summary type="html"><![CDATA[This is a blog post about the paper “Restricting the Flow: Information Bottlenecks for Attribution” by Karl Schulz, Leon Sixt, Federico Tombari and Tim Landgraf published at ICLR 2020. Introduction With the current trend to applying Neural Networks to more and more domains, the question on the explainability of these models is getting more attention. While more traditional machine learning approaches like decision trees and Random Forest incorporate some kind of interpretability based on the input features, todays Deep Neural Networks rely on higher dimensional embeddings hardly interpretable by a human. The line of research which can be grouped under the “Attribution” term therefore tries to relate the final output of a Neural Network back to its input by identifying the parts most relevant for decision of the model.]]></summary><media:thumbnail xmlns:media="http://search.yahoo.com/mrss/" url="/restricting-the-flow-review/thumb2.png" /><media:content medium="image" url="/restricting-the-flow-review/thumb2.png" xmlns:media="http://search.yahoo.com/mrss/" /></entry></feed>