<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://iamycy.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://iamycy.github.io/" rel="alternate" type="text/html" /><updated>2026-03-04T02:09:31+00:00</updated><id>https://iamycy.github.io/feed.xml</id><title type="html">YCY / Kazaneya</title><subtitle>Researcher &amp; Software/Algorithm Engineer</subtitle><author><name>Chin-Yun Yu</name><email>chin-yun.yu@qmul.ac.uk</email></author><entry><title type="html">Block-based Fast Differentiable IIR in PyTorch</title><link href="https://iamycy.github.io/posts/2025/06/28/unroll-ssm/" rel="alternate" type="text/html" title="Block-based Fast Differentiable IIR in PyTorch" /><published>2025-06-28T00:00:00+00:00</published><updated>2025-06-28T00:00:00+00:00</updated><id>https://iamycy.github.io/posts/2025/06/28/unroll-ssm</id><content type="html" xml:base="https://iamycy.github.io/posts/2025/06/28/unroll-ssm/"><![CDATA[<p>I recently came across a presentation by Andres Ezequiel Viso from GPU Audio at ADC 2022, in which he talked about how they accelerate IIR filters on the GPU.
The approach they use is to formulate the IIR filter as a state-space model (SSM) and augment the transition matrix so that each step processes multiple samples at once.
The primary speedup stems from the fact that GPUs are very good at performing large matrix multiplications, and the SSM formulation enables us to leverage this capability.</p>

<iframe width="1024px" height="576px" src="https://www.youtube.com/embed/UmYnoFo0Bb8?start=1356" allowfullscreen="">
</iframe>
<p><br /></p>

<p>Speeding up IIR filters while maintaining differentiability has always been my interest.
The most recent method I worked on is from my recent <a href="https://arxiv.org/abs/2504.14735">submission</a> to DAFx 25, where my co-author Ben proposed using parallel associative scan to speed up the recursion on the GPU.
Nevertheless, since PyTorch does not have a built-in associative scan operator (in contrast to JAX), we must implement custom kernels for it, which is non-trivial.
It also requires that the filter have distinct poles so that the state-space transition matrix is diagonalisable.
The method GPU Audio presented appears to be implementable using only the PyTorch Python API and doesn’t have the restrictions I mentioned; thus, I decided to benchmark it and see how it performs.</p>

<p>Since it’s just a proof of concept, the filter I’m going to test is a <strong>time-invariant all-pole IIR filter</strong>, which is the minimal case of a recursive filter.
This allows us to leverage some special optimisations that won’t work with time-varying general IIR filters, but that won’t affect the main idea I’m going to present here.</p>

<h2 id="naive-implementation-of-an-all-pole-iir-filter">Naive implementation of an all-pole IIR filter</h2>

<p>The difference equation of an \(M\)-th order all-pole IIR filter is given by:</p>

\[y[n] = x[n] -\sum_{m=1}^{M} a_m y[n-m].\]

<p>Let’s implement this in PyTorch:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">Tensor</span>

<span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">naive_allpole</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="s">"""
    Naive all-pole filter implementation.
    
    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.
        
    Returns:
        Tensor: Filtered output signal.
    """</span>
    <span class="k">assert</span> <span class="n">x</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s">"Input signal must be a 2D tensor (batch_size, signal_length)"</span>
    <span class="k">assert</span> <span class="n">a</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"All-pole coefficients must be a 1D tensor"</span>

    <span class="c1"># list to store output at each time step
</span>    <span class="n">output</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="c1"># assume initial condition is zero
</span>    <span class="n">zi</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>

    <span class="k">for</span> <span class="n">xt</span> <span class="ow">in</span> <span class="n">x</span><span class="p">.</span><span class="n">unbind</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span>
        <span class="c1"># use addmv for efficient matrix-vector multiplication
</span>        <span class="n">yt</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">addmv</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">zi</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="mf">1.0</span><span class="p">)</span>
        <span class="n">output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">yt</span><span class="p">)</span>

        <span class="c1"># update the state for the next time step
</span>        <span class="n">zi</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">yt</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">zi</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p>In this implementation, I didn’t use any in-place operations for speedup, since they would break the differentiability of the function.
This naive implementation is not very efficient, as <code class="language-plaintext highlighter-rouge">torch.addmv</code> and <code class="language-plaintext highlighter-rouge">torch.cat</code> are called at each time step.
Audio signals are typically hundreds of thousands of samples long, resulting in a significant amount of function-call overhead.
For details, please take a look at my <a href="https://intro2ddsp.github.io/filters/iir_torch.html">tutorial on differentiable IIR filters</a> at ISMIR 2023.</p>

<p>Notice that I used <code class="language-plaintext highlighter-rouge">torch.jit.script</code> to compile the function for some slight speedup.
I tried the newer compilation feature <code class="language-plaintext highlighter-rouge">torch.compile</code>, but it didn’t work: the compilation hangs forever, and I don’t know why…
I have never found <code class="language-plaintext highlighter-rouge">torch.compile</code> to be useful in my research projects, whereas <code class="language-plaintext highlighter-rouge">torch.jit.*</code> has proven to be far more reliable.</p>

<p>Let’s benchmark its speed on my Ubuntu machine with an Intel i7-7700K.
We’ll use a batch size of 8, a signal length of 16384, and \(M=2\), which is a reasonable setting for audio processing.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">torch.utils.benchmark</span> <span class="kn">import</span> <span class="n">Timer</span>

<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">signal_length</span> <span class="o">=</span> <span class="mi">16384</span>
<span class="n">order</span> <span class="o">=</span> <span class="mi">2</span>

<span class="k">def</span> <span class="nf">order2a</span><span class="p">(</span><span class="n">order</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">order</span><span class="p">)</span>
    <span class="c1"># simple way to ensure stability
</span>    <span class="n">a</span> <span class="o">=</span> <span class="n">a</span> <span class="o">/</span> <span class="n">a</span><span class="p">.</span><span class="nb">abs</span><span class="p">().</span><span class="nb">sum</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">a</span>

<span class="n">a</span> <span class="o">=</span> <span class="n">order2a</span><span class="p">(</span><span class="n">order</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">signal_length</span><span class="p">)</span>

<span class="n">naive_allpole_t</span> <span class="o">=</span> <span class="n">Timer</span><span class="p">(</span>
    <span class="n">stmt</span><span class="o">=</span><span class="s">"naive_allpole(x, a)"</span><span class="p">,</span>
    <span class="nb">globals</span><span class="o">=</span><span class="p">{</span><span class="s">"naive_allpole"</span><span class="p">:</span> <span class="n">naive_allpole</span><span class="p">,</span> <span class="s">"x"</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="s">"a"</span><span class="p">:</span> <span class="n">a</span><span class="p">},</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"naive_allpole"</span><span class="p">,</span>
    <span class="n">description</span><span class="o">=</span><span class="s">"Naive All-Pole Filter"</span><span class="p">,</span>
    <span class="n">num_threads</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">naive_allpole_t</span><span class="p">.</span><span class="n">blocked_autorange</span><span class="p">(</span><span class="n">min_run_time</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;torch.utils.benchmark.utils.common.Measurement object at 0x7f5b4423b260&gt;
naive_allpole
Naive All-Pole Filter
  Median: 168.93 ms
  IQR:    0.54 ms <span class="o">(</span>168.57 to 169.11<span class="o">)</span>
  6 measurements, 1 runs per measurement, 4 threads
</code></pre></div></div>

<p>168.93 ms is relatively slow, but that is expected, given the Python-level loop over 16384 time steps.</p>
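<p>As a quick sanity check (my own addition, not from the original post), we can verify the recursion against <code class="language-plaintext highlighter-rouge">scipy.signal.lfilter</code>, which computes the same difference equation when given <code class="language-plaintext highlighter-rouge">[1, a_1, ..., a_M]</code> as the denominator polynomial. A minimal, non-jitted sketch of the same filter:</p>

```python
import numpy as np
import torch
from scipy.signal import lfilter

def naive_allpole(x, a):
    # same recursion as above, without torch.jit, for easy testing
    zi = x.new_zeros(x.size(0), a.size(0))
    output = []
    for xt in x.unbind(1):
        yt = torch.addmv(xt, zi, a, alpha=-1.0)  # y[n] = x[n] - zi @ a
        output.append(yt)
        zi = torch.cat([yt.unsqueeze(1), zi[:, :-1]], dim=1)
    return torch.stack(output, dim=1)

torch.manual_seed(0)
a = torch.randn(2)
a = a / a.abs().sum()   # same crude stability trick as the benchmark setup
x = torch.randn(3, 64)

y = naive_allpole(x, a).numpy()
# lfilter takes the full denominator polynomial [1, a_1, ..., a_M]
y_ref = lfilter([1.0], np.concatenate(([1.0], a.numpy())), x.numpy(), axis=-1)
print(np.allclose(y, y_ref, atol=1e-5))  # True
```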

<h2 id="state-space-model-formulation">State-space model formulation</h2>

<p>Before showing the sample-unrolling trick, let’s first introduce the state-space model (SSM) formulation of the all-pole IIR filter.
The model is similar to the one used in my previous blogpost on <a href="https://iamycy.github.io/posts/2025/04/differentiable-tdf-ii/">TDF-II filter</a>:</p>

\[\begin{align}
\mathbf{h}[n] &amp;= \begin{bmatrix}
    -a_1 &amp; -a_2 &amp; \cdots &amp; -a_{M-1} &amp; -a_M \\
    1 &amp; 0 &amp;\cdots &amp; 0 &amp; 0 \\
    0 &amp; 1 &amp; \cdots &amp; 0 &amp; 0 \\
    \vdots &amp;  \vdots &amp; \ddots &amp; \vdots &amp; \vdots \\
    0 &amp; 0 &amp; \cdots &amp; 1 &amp; 0 \\
\end{bmatrix} \mathbf{h}[n-1] + \begin{bmatrix}
    1 \\
    0 \\
    0 \\
    \vdots \\
    0 \\
\end{bmatrix} x[n] \\
&amp;= \mathbf{A} \mathbf{h}[n-1] + \mathbf{B} x[n] \\
y[n] &amp;= \mathbf{B}^\top \mathbf{h}[n].
\end{align}\]

<p>Here, I simplified the original SSM by omitting the direct path, as it can be derived from the state vector (for the all-pole filter only).
Below is the PyTorch implementation of it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">a2companion</span><span class="p">(</span><span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="s">"""
    Convert all-pole coefficients to a companion matrix.

    Args:
        a (Tensor): All-pole coefficients.

    Returns:
        Tensor: Companion matrix.
    """</span>
    <span class="k">assert</span> <span class="n">a</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"All-pole coefficients must be a 1D tensor"</span>
    <span class="n">order</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">diag</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">new_ones</span><span class="p">(</span><span class="n">order</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">c</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="o">-</span><span class="n">a</span>
    <span class="k">return</span> <span class="n">c</span>

<span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">state_space_allpole</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="s">"""
    State-space implementation of all-pole filtering.

    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.

    Returns:
        Tensor: Filtered output signal.
    """</span>
    <span class="k">assert</span> <span class="n">x</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s">"Input signal must be a 2D tensor (batch_size, signal_length)"</span>
    <span class="k">assert</span> <span class="n">a</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"All-pole coefficients must be a 1D tensor"</span>

    <span class="n">c</span> <span class="o">=</span> <span class="n">a2companion</span><span class="p">(</span><span class="n">a</span><span class="p">).</span><span class="n">T</span>

    <span class="n">output</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="c1"># assume initial condition is zero
</span>    <span class="n">h</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>

    <span class="c1"># B * x
</span>    <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span>
        <span class="p">[</span><span class="n">x</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span>
    <span class="p">)</span>

    <span class="k">for</span> <span class="n">xt</span> <span class="ow">in</span> <span class="n">x</span><span class="p">.</span><span class="n">unbind</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span>
        <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">addmm</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
        <span class="c1"># B^T @ h
</span>        <span class="n">output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p><code class="language-plaintext highlighter-rouge">a2companion</code> converts the all-pole coefficients to a <a href="https://en.wikipedia.org/wiki/Companion_matrix">companion matrix</a>, which is \(\mathbf{A}\) in the SSM formulation.</p>
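<p>For a concrete picture (a worked example of my own, not from the original post), with \(M = 2\) the companion matrix is \(\begin{bmatrix} -a_1 &amp; -a_2 \\ 1 &amp; 0 \end{bmatrix}\):</p>

```python
import torch

def a2companion(a):
    # same helper as above, without torch.jit
    order = a.size(0)
    c = torch.diag(a.new_ones(order - 1), -1)
    c[0, :] = -a
    return c

a = torch.tensor([0.5, -0.25])
A = a2companion(a)
print(A)
# tensor([[-0.5000,  0.2500],
#         [ 1.0000,  0.0000]])
```

<p>The eigenvalues of the companion matrix are exactly the poles of the filter, so the stability constraint (all poles inside the unit circle) translates directly into a spectral-radius condition on \(\mathbf{A}\).</p>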

<p>Before we benchmark the speed of this implementation, let’s predict how fast it will be.
Intuitively, since a vector dot product costs \(O(M)\) per output sample while a matrix-vector multiplication costs \(O(M^2)\), the SSM implementation uses more computational resources, so it should be slower than the naive implementation.
Let’s benchmark its speed:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">state_space_allpole_t</span> <span class="o">=</span> <span class="n">Timer</span><span class="p">(</span>
    <span class="n">stmt</span><span class="o">=</span><span class="s">"state_space_allpole(x, a)"</span><span class="p">,</span>
    <span class="nb">globals</span><span class="o">=</span><span class="p">{</span><span class="s">"state_space_allpole"</span><span class="p">:</span> <span class="n">state_space_allpole</span><span class="p">,</span> <span class="s">"x"</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="s">"a"</span><span class="p">:</span> <span class="n">a</span><span class="p">},</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"state_space_allpole"</span><span class="p">,</span>
    <span class="n">description</span><span class="o">=</span><span class="s">"State-Space All-Pole Filter"</span><span class="p">,</span>
    <span class="n">num_threads</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">state_space_allpole_t</span><span class="p">.</span><span class="n">blocked_autorange</span><span class="p">(</span><span class="n">min_run_time</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;torch.utils.benchmark.utils.common.Measurement object at 0x7f5a02eaf4a0&gt;
state_space_allpole
State-Space All-Pole Filter
  Median: 118.41 ms
  IQR:    1.17 ms <span class="o">(</span>117.79 to 118.96<span class="o">)</span>
  9 measurements, 1 runs per measurement, 4 threads
</code></pre></div></div>

<p>Interestingly, the SSM implementation is approximately 50 ms faster.</p>

<p>Using <code class="language-plaintext highlighter-rouge">torch.profiler.profile</code>, I found that, in the naive implementation, the <code class="language-plaintext highlighter-rouge">torch.cat</code> call that updates the last \(M\) outputs accounts for a significant portion of the total time (~20%), while the actual computation, <code class="language-plaintext highlighter-rouge">torch.addmv</code>, takes only about 10% of the time.
Regarding memory usage, the most memory-intensive operation is <code class="language-plaintext highlighter-rouge">torch.addmv</code>, which consumes approximately 512 KB.
In contrast, the SSM implementation uses more memory (&gt; 1 MB) due to the matrix multiplication, but roughly 38% of the time is spent on the actual filtering, since it no longer has to call <code class="language-plaintext highlighter-rouge">torch.cat</code> at each time step:
the state vector (i.e., the last \(M\) outputs) is updated automatically as part of the matrix multiplication.</p>

<p><strong>Conclusion</strong>: Tensor concatenation (including <code class="language-plaintext highlighter-rouge">torch.cat</code> and <code class="language-plaintext highlighter-rouge">torch.stack</code>) is computationally expensive, and it is advisable to avoid calling it inside tight loops whenever possible.</p>

<h2 id="unrolling-the-ssm">Unrolling the SSM</h2>

<p>Now we can apply the unrolling trick to the SSM implementation.
The idea is to divide the input signal into blocks of size \(T\) and perform the recursion block-by-block instead of sample-by-sample.
Each recursion takes the last vector state \(\mathbf{h}[n-1]\) and predicts the next \(T\) states \([\mathbf{h}[n], \mathbf{h}[n+1], \ldots, \mathbf{h}[n+T-1]]^\top\) at once.
To see how to calculate these states, let’s unroll the SSM recursion for \(T\) steps:</p>

\[\begin{align}
\mathbf{h}[n] &amp;= \mathbf{A} \mathbf{h}[n-1] + \mathbf{B} x[n] \\
\mathbf{h}[n+1] &amp;= \mathbf{A} \mathbf{h}[n] + \mathbf{B} x[n+1] \\
&amp;= \mathbf{A} (\mathbf{A} \mathbf{h}[n-1] + \mathbf{B} x[n]) + \mathbf{B} x[n+1] \\
&amp;= \mathbf{A}^2 \mathbf{h}[n-1] + \mathbf{A} \mathbf{B} x[n] + \mathbf{B} x[n+1] \\
\mathbf{h}[n+2] &amp;= \mathbf{A} \mathbf{h}[n+1] + \mathbf{B} x[n+2] \\
&amp;= \mathbf{A} (\mathbf{A}^2 \mathbf{h}[n-1] + \mathbf{A} \mathbf{B} x[n] + \mathbf{B} x[n+1]) + \mathbf{B} x[n+2] \\
&amp;= \mathbf{A}^3 \mathbf{h}[n-1] + \mathbf{A}^2 \mathbf{B} x[n] + \mathbf{A} \mathbf{B} x[n+1] + \mathbf{B} x[n+2] \\
&amp; \vdots \\
\mathbf{h}[n+T-1] &amp;= \mathbf{A}^{T} \mathbf{h}[n-1] + \sum_{t=0}^{T-1} \mathbf{A}^{T - t -1} \mathbf{B} x[n+t] \\
\end{align}\]

<p>We can rewrite the above equation in matrix form as follows:</p>

\[\begin{align}
\begin{bmatrix}
    \mathbf{h}[n] \\
    \mathbf{h}[n+1] \\
    \vdots \\
    \mathbf{h}[n+T-1]
\end{bmatrix} &amp;= \begin{bmatrix}
    \mathbf{A} \\
    \mathbf{A}^2 \\
    \vdots \\
    \mathbf{A}^T \\
\end{bmatrix} \mathbf{h}[n-1]
+ \begin{bmatrix}
    \mathbf{I} &amp; 0 &amp; \cdots &amp; 0 \\
    \mathbf{A} &amp; \mathbf{I} &amp; \cdots &amp; 0 \\
    \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
    \mathbf{A}^{T-1} &amp; \mathbf{A}^{T-2} &amp; \cdots &amp; \mathbf{I}
\end{bmatrix}
\begin{bmatrix}
    \mathbf{B}x[n] \\
    \mathbf{B}x[n+1] \\
    \vdots \\
    \mathbf{B}x[n+T-1]
\end{bmatrix} \\
&amp; = \begin{bmatrix}
    \mathbf{A} \\
    \mathbf{A}^2 \\
    \vdots \\
    \mathbf{A}^T \\
\end{bmatrix} \mathbf{h}[n-1]
+ \begin{bmatrix}
    \mathbf{I}_{.1} &amp; 0 &amp; \cdots &amp; 0 \\
    \mathbf{A}_{.1} &amp; \mathbf{I}_{.1} &amp; \cdots &amp; 0 \\
    \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
    \mathbf{A}_{.1}^{T-1} &amp; \mathbf{A}_{.1}^{T-2} &amp; \cdots &amp; \mathbf{I}_{.1}
\end{bmatrix}
\begin{bmatrix}
    x[n] \\
    x[n+1] \\
    \vdots \\
    x[n+T-1]
\end{bmatrix} \\
&amp; = \mathbf{M} \mathbf{h}[n-1] + \mathbf{V} \begin{bmatrix}
    x[n] \\
    x[n+1] \\
    \vdots \\
    x[n+T-1]
\end{bmatrix} \\
\end{align}\]

<p>Notice that in the second line, I utilised the fact that \(\mathbf{B}\) has only one non-zero entry to simplify the matrix.
(This is not possible if the filter is not strictly all-pole.)
\(\mathbf{I}_{.1}\) denotes the first column of the identity matrix, \(\mathbf{A}_{.1}^{k}\) the first column of \(\mathbf{A}^{k}\), and so on.</p>
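<p>This simplification rests on the identity \(\mathbf{A}^k \mathbf{B} = \mathbf{A}_{.1}^{k}\), which holds because \(\mathbf{B}\) is the first standard basis vector; a quick numerical check (my own sketch, using the definitions above):</p>

```python
import torch

torch.manual_seed(0)
M = 4
a = torch.randn(M)
A = torch.diag(a.new_ones(M - 1), -1)
A[0, :] = -a            # companion matrix
B = torch.zeros(M)
B[0] = 1.0              # B is the first standard basis vector e1

for k in range(5):
    Ak = torch.linalg.matrix_power(A, k)
    # multiplying by e1 just selects the first column of A^k
    assert torch.allclose(Ak @ B, Ak[:, 0])
print("ok")
```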

<p>Now the number of autoregressive steps is reduced from \(N\) to \(\frac{N}{T}\), and the matrix multiplications within each block of \(T\) samples are performed in parallel.
There is an added cost for pre-computing the transition matrix \(\mathbf{M}\) and the input matrix \(\mathbf{V}\), though.
However, as long as this extra cost is small compared to the cost of the \(N - \frac{N}{T}\) autoregressive steps we save, we should observe a speedup.</p>
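<p>To double-check the unrolled recursion above (a verification sketch of my own, not from the original post), we can iterate the one-step update \(T\) times and compare against the closed-form expression for \(\mathbf{h}[n+T-1]\):</p>

```python
import torch

torch.manual_seed(0)
M, T = 2, 8

# companion matrix A and B = e1, as in the SSM formulation
a = torch.randn(M)
a = a / a.abs().sum()
A = torch.diag(a.new_ones(M - 1), -1)
A[0, :] = -a
B = torch.zeros(M, 1)
B[0, 0] = 1.0

h = torch.randn(M, 1)   # h[n-1]
x = torch.randn(T)      # x[n], ..., x[n+T-1]

# iterate the recursion h <- A h + B x[t] for T steps
h_step = h.clone()
for t in range(T):
    h_step = A @ h_step + B * x[t]

# closed form: A^T h[n-1] + sum_t A^{T-t-1} B x[n+t]
h_closed = torch.linalg.matrix_power(A, T) @ h
for t in range(T):
    h_closed = h_closed + torch.linalg.matrix_power(A, T - t - 1) @ B * x[t]

print(torch.allclose(h_step, h_closed, atol=1e-5))  # True
```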

<p>Here’s the PyTorch implementation of the unrolled SSM:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">state_space_allpole_unrolled</span><span class="p">(</span>
    <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">unroll_factor</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="s">"""
    Unrolled state-space implementation of all-pole filtering.

    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.
        unroll_factor (int): Factor by which to unroll the loop.

    Returns:
        Tensor: Filtered output signal.
    """</span>
    <span class="k">if</span> <span class="n">unroll_factor</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">state_space_allpole</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">)</span>
    <span class="k">elif</span> <span class="n">unroll_factor</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="s">"Unroll factor must be &gt;= 1"</span><span class="p">)</span>

    <span class="k">assert</span> <span class="n">x</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s">"Input signal must be a 2D tensor (batch_size, signal_length)"</span>
    <span class="k">assert</span> <span class="n">a</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"All-pole coefficients must be a 1D tensor"</span>
    <span class="k">assert</span> <span class="p">(</span>
        <span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">unroll_factor</span> <span class="o">==</span> <span class="mi">0</span>
    <span class="p">),</span> <span class="s">"Signal length must be divisible by unroll factor"</span>

    <span class="n">c</span> <span class="o">=</span> <span class="n">a2companion</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>

    <span class="c1"># create an initial identity matrix
</span>    <span class="n">initial</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">c</span><span class="p">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">c</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
    <span class="n">c_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">initial</span><span class="p">]</span>
    <span class="c1"># TODO: use parallel scan to improve speed
</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">unroll_factor</span><span class="p">):</span>
        <span class="n">c_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">c_list</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">@</span> <span class="n">c</span><span class="p">)</span>

    <span class="c1"># c_list = [I c c^2 ... c^unroll_factor]
</span>    <span class="n">M</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span><span class="n">c_list</span><span class="p">[</span><span class="mi">1</span><span class="p">:],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">).</span><span class="n">T</span>
    <span class="n">flatten_c_list</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span>
        <span class="p">[</span><span class="n">c</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">unroll_factor</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))]</span>
        <span class="o">+</span> <span class="p">[</span><span class="n">xx</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">xx</span> <span class="ow">in</span> <span class="n">c_list</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
        <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">flatten_c_list</span><span class="p">.</span><span class="n">unfold</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">unroll_factor</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)).</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

    <span class="c1"># divide the input signal into blocks of size unroll_factor
</span>    <span class="n">unrolled_x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">unflatten</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">unroll_factor</span><span class="p">))</span> <span class="o">@</span> <span class="n">V</span>

    <span class="n">output</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="c1"># assume initial condition is zero
</span>    <span class="n">h</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">xt</span> <span class="ow">in</span> <span class="n">unrolled_x</span><span class="p">.</span><span class="n">unbind</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span>
        <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">addmm</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">M</span><span class="p">)</span>
        <span class="c1"># B^T @ h
</span>        <span class="n">output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">[:,</span> <span class="p">::</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)])</span>
        <span class="n">h</span> <span class="o">=</span> <span class="n">h</span><span class="p">[</span>
            <span class="p">:,</span> <span class="o">-</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="p">:</span>
        <span class="p">]</span>  <span class="c1"># take the last state vector as the initial condition for the next step
</span>    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">unroll_factor</code> parameter controls the number of samples to process in parallel.
If it is set to 1, the function is the original SSM implementation.</p>
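<p>Before benchmarking, it helps to see the trick in miniature. The sketch below (plain NumPy rather than the PyTorch code above; all names are illustrative) applies the same block recursion to a one-pole filter, where <code class="language-plaintext highlighter-rouge">W</code> plays the role of \(\mathbf{V}\) and the scalar carry weights play the role of \(\mathbf{M}\):</p>

```python
import numpy as np

def naive_onepole(x, a):
    # sample-by-sample recursion: h[n] = a * h[n-1] + x[n]
    h, out = 0.0, []
    for xn in x:
        h = a * h + xn
        out.append(h)
    return np.array(out)

def blocked_onepole(x, a, T):
    # advance T samples per step; within a block,
    # h[k] = a^(k+1) * h_prev + sum_{j<=k} a^(k-j) * x[j]
    assert len(x) % T == 0
    # lower-triangular Toeplitz matrix of powers of a (the scalar analogue of V)
    W = np.array([[a ** (i - j) if i >= j else 0.0 for j in range(T)]
                  for i in range(T)])
    carry = a ** np.arange(1, T + 1)  # weights on the previous block's last state
    out, h = [], 0.0
    for blk in x.reshape(-1, T):
        states = W @ blk + h * carry
        h = states[-1]  # last state seeds the next block
        out.append(states)
    return np.concatenate(out)

rng = np.random.default_rng(0)
x = rng.standard_normal(64)
assert np.allclose(naive_onepole(x, 0.9), blocked_onepole(x, 0.9, T=8))
```

<p>The per-block work is a single matrix product plus a short recursion, which is the same trade the PyTorch version makes.</p>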

<p>Now let’s benchmark the speed of the unrolled SSM implementation.
We’ll use <code class="language-plaintext highlighter-rouge">unroll_factor=128</code> since I already tested that it is the optimal value :)</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">state_space_allpole_unrolled_t</span> <span class="o">=</span> <span class="n">Timer</span><span class="p">(</span>
    <span class="n">stmt</span><span class="o">=</span><span class="s">"state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)"</span><span class="p">,</span>
    <span class="nb">globals</span><span class="o">=</span><span class="p">{</span>
        <span class="s">"state_space_allpole_unrolled"</span><span class="p">:</span> <span class="n">state_space_allpole_unrolled</span><span class="p">,</span>
        <span class="s">"x"</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span>
        <span class="s">"a"</span><span class="p">:</span> <span class="n">a</span><span class="p">,</span>
        <span class="s">"unroll_factor"</span><span class="p">:</span> <span class="mi">128</span><span class="p">,</span>
    <span class="p">},</span>
    <span class="n">label</span><span class="o">=</span><span class="s">"state_space_allpole_unrolled"</span><span class="p">,</span>
    <span class="n">description</span><span class="o">=</span><span class="s">"State-Space All-Pole Filter Unrolled"</span><span class="p">,</span>
    <span class="n">num_threads</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">state_space_allpole_unrolled_t</span><span class="p">.</span><span class="n">blocked_autorange</span><span class="p">(</span><span class="n">min_run_time</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;torch.utils.benchmark.utils.common.Measurement object at 0x7f5a01d75160&gt;
state_space_allpole_unrolled
State-Space All-Pole Filter Unrolled
  Median: 1.89 ms
  IQR:    0.08 ms <span class="o">(</span>1.88 to 1.96<span class="o">)</span>
  6 measurements, 100 runs per measurement, 4 threads
</code></pre></div></div>
<p>1.89 ms! What sorcery is this? That’s a whopping 60x speedup compared to the standard SSM implementation!</p>

<p>A closer look at the profiling results shows that in total, 38% of the time is spent on matrix multiplication and addition.
The speedup comes with a cost of increased memory usage, requiring more than 2 MB for filtering.
Not a significant cost for modern hardware.</p>

<p>For convenience, I ran the above benchmarks using the CPU, which has very limited parallelism compared to the GPU.
Thus, the significant speedup we observe indicates that function call overhead is the major bottleneck for running recursions.</p>

<h2 id="more-comparison">More comparison</h2>

<p>Since \(T\) is an essential parameter for the unrolled SSM, I did some benchmarks to see how it affects the speed.</p>

<h3 id="varying-sequence-length">Varying sequence length</h3>

<p>In this benchmark, I fixed the batch size to 8 and the order to 2, and varied the sequence length from 4096 to 262144.
The results suggest that the best unroll factor increases as the sequence length increases, and it’s very likely to be \(\sqrt{N}\).
Additionally, the longer the sequence length, the greater the speedup we achieve from the unrolled SSM.</p>

<p><img src="/images/unroll-ssm/benchmark_seq_len.png" alt="" /></p>

<h3 id="varying-filter-order">Varying filter order</h3>

<p>To examine the impact of filter order on speed, I set the batch size to 8 and the sequence length to 16384, and then varied the filter order from 2 to 16.
It appears that my hypothesis that the best factor is \(\sqrt{N}\) still holds, but the peak gradually shifts to the left as the order increases.
Moreover, the speedup is less significant for higher orders, which is expected as the \(\mathbf{V}\) matrix becomes larger.</p>

<p><img src="/images/unroll-ssm/benchmark_order.png" alt="" /></p>

<h3 id="varying-batch-size">Varying batch size</h3>

<p>The speedup diminishes as the batch size increases, which is expected.
However, the peak of the best unroll factor also shifts slightly to the left as the batch size increases.</p>

<p><img src="/images/unroll-ssm/benchmark_batch.png" alt="" /></p>

<h3 id="memory-usage">Memory usage</h3>

<p>To observe how memory usage changes in a differentiable training context, I ran the unrolled SSM on a 5060 Ti, allowing me to use <code class="language-plaintext highlighter-rouge">torch.cuda.max_memory_allocated()</code> to measure memory usage.
When batch size is 1, as expected, the memory usage grows quadratically with the unroll factor, due to the creation of the \(\mathbf{V}\) matrix.</p>

<p><img src="/images/unroll-ssm/mem_batch_1.png" alt="" /></p>
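<p>A quick back-of-envelope check is consistent with this: assuming \(\mathbf{V}\) has shape \(T \times MT\) as produced by the <code class="language-plaintext highlighter-rouge">unfold</code> above, its footprint grows quadratically with \(T\) (an illustrative sketch, not measured numbers):</p>

```python
def v_bytes(M: int, T: int, itemsize: int = 4) -> int:
    # V has T rows of M*T float32 entries (assumed shape, following the unfold above)
    return T * (M * T) * itemsize

assert v_bytes(2, 128) == 128 * 256 * 4        # 128 KiB for M = 2, T = 128
assert v_bytes(2, 256) == 4 * v_bytes(2, 128)  # doubling T quadruples the memory
```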

<p>When using a larger batch size (32 in this case), this cost becomes less significant compared to the more memory used for the input signal.</p>

<p><img src="/images/unroll-ssm/mem_batch_32.png" alt="" /></p>

<h2 id="discussion">Discussion</h2>

<p>So far, we have seen that the unrolled SSM can achieve a significant speedup for IIR filtering in PyTorch.
However, how to determine the best unrolling factor automatically remains unclear.
From the benchmarks I did on an i7 CPU, it seems that the optimal \(T^*\) is \(\sqrt{N}\alpha\), where \(0 &lt; \alpha \leq 1\) is a function of the filter order and batch size.
Since I also observe similar behaviour on the GPU, it is likely that this hypothesis holds true for other hardware as well.</p>
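<p>If one wanted a starting point for picking \(T\) automatically, a hypothetical helper following the \(\sqrt{N}\) trend could look like the following. It ignores the \(\alpha\) correction for filter order and batch size, so treat it as a rough default rather than an optimum:</p>

```python
import math

def suggest_unroll_factor(seq_len: int) -> int:
    # hypothetical heuristic: the power of two nearest to sqrt(N)
    return 2 ** max(0, round(math.log2(math.sqrt(seq_len))))

assert suggest_unroll_factor(16384) == 128   # sqrt(16384) = 128 exactly
assert suggest_unroll_factor(262144) == 512
```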

<p>One thing I didn’t mention is numerical accuracy.
If \(|\mathbf{A}|\) is very small, the precomputed exponentials \(\mathbf{A}^T \to \mathbf{0}\) may not be accurately represented in floating point, especially in deep learning applications, where single precision is the norm.
This is less of a problem for the standard SSM, since at each time step, the input is mixed with the state vector, which could help cancel out the numerical errors.</p>
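<p>The underflow is easy to reproduce: with a pole magnitude of \(0.1\), \(0.1^{64} = 10^{-64}\) is far below the smallest float32 subnormal (about \(1.4 \times 10^{-45}\)), so the accumulated product flushes to zero in single precision while double precision still represents it. A standalone NumPy sketch:</p>

```python
import numpy as np

r32 = np.float32(1.0)
for _ in range(64):
    r32 = np.float32(r32 * np.float32(0.1))  # emulate A^T for a scalar pole

assert r32 == 0.0       # single precision underflows to exactly zero
assert 0.1 ** 64 > 0.0  # double precision still represents 1e-64
```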

<p>The idea should still apply when the filter has zeros.
\(\mathbf{B}\) will no longer be a simple one-hot vector, so \(\mathbf{V}\) has to be a full \(MT\times MT\) square matrix.
Time-varying filters will benefit less from the unrolling trick since \(\mathbf{V}\) will also be time-varying, and computing \(\frac{N}{T}\) such matrices in advance increases the cost.</p>

<h2 id="conclusion--thoughts">Conclusion &amp; Thoughts</h2>

<p>In this post, I demonstrate that the unrolling trick can significantly accelerate differentiable IIR filtering in PyTorch.
The extra memory cost is less of a problem for large batch sizes.
Although the filter I tested is a simple all-pole filter, it’s trivial to extend the idea to general IIR filters.</p>

<p>This method might help address one of the issues for future TorchAudio, after the Meta developers <a href="https://github.com/pytorch/audio/issues/3902">announced</a> their future plan for it.
In the next major release, all the specialised kernels written in C++, including the <code class="language-plaintext highlighter-rouge">lfilter</code> I contributed years ago, will be removed from TorchAudio.
The filter I presented here is written entirely in Python and can serve as a straightforward drop-in replacement for the current compiled <code class="language-plaintext highlighter-rouge">lfilter</code> implementation.</p>

<h2 id="notes">Notes</h2>

<p>The complete code is available in the Jupyter notebook version of this post on <a href="https://gist.github.com/yoyolicoris/b67407ffb56fa168c59275aea548fe96">Gist</a>.</p>

<h2 id="update-2962025">Update (29.6.2025)</h2>

<p>I realised that the <code class="language-plaintext highlighter-rouge">state_space_allpole_unrolled</code> function I made is very close to a two-level <a href="https://en.wikipedia.org/wiki/Prefix_sum">parallel scan</a>, and with some modifications, we can squeeze a bit more performance out of it.
Instead of computing all the \(T\) states at once per block, we can just compute the last state, which is the only one we need for the next block.
Thus, the matrix size for the multiplication is reduced from \(\mathbf{M} \in \mathbb{R}^{MT\times M}\) to \(\mathbf{A}^T \in \mathbb{R}^{M\times M}\).
The first \(T-1\) states of each block can be computed later in parallel.
The algorithm (parallel scan) is as follows:</p>

<p>Firstly, compute the input to the last state in the block:</p>

\[\mathbf{z}[n+T-1] = 
\begin{bmatrix}
    \mathbf{A}_{.1}^{T-1} &amp; \mathbf{A}_{.1}^{T-2} &amp; \cdots &amp; \mathbf{I}_{.1}
\end{bmatrix}
\begin{bmatrix}
    x[n] \\
    x[n+1] \\
    \vdots \\
    x[n+T-1]
\end{bmatrix}.\]

<p>Then, compute the last state in each block recursively as follows:</p>

\[\mathbf{h}[n+T-1] = \mathbf{A}^{T} \mathbf{h}[n-1] + \mathbf{z}[n+T-1].\]

<p>Lastly, compute the remaining states in parallel:</p>

\[\begin{bmatrix}
    \mathbf{h}[n] \\
    \mathbf{h}[n+1] \\
    \vdots \\
    \mathbf{h}[n+T-2]
\end{bmatrix} =
\begin{bmatrix}
    \mathbf{A}  &amp; \mathbf{I}_{.1} &amp; 0 &amp; \cdots &amp; 0 \\
    \mathbf{A}^2 &amp;  \mathbf{A}_{.1} &amp; \mathbf{I}_{.1} &amp; \cdots &amp; 0 \\
    \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
    \mathbf{A}^{T-1} &amp; \mathbf{A}_{.1}^{T-2} &amp; \mathbf{A}_{.1}^{T-3} &amp; \cdots &amp; \mathbf{I}_{.1}
\end{bmatrix}
\begin{bmatrix}
    \mathbf{h}[n-1] \\
    x[n] \\
    x[n+1] \\
    \vdots \\
    x[n+T-2]
\end{bmatrix}.\]

<p>The following code implements this algorithm, modified from the previous <code class="language-plaintext highlighter-rouge">state_space_allpole_unrolled</code> function.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">torch</span><span class="p">.</span><span class="n">jit</span><span class="p">.</span><span class="n">script</span>
<span class="k">def</span> <span class="nf">state_space_allpole_unrolled_v2</span><span class="p">(</span>
    <span class="n">x</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">a</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">unroll_factor</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tensor</span><span class="p">:</span>
    <span class="s">"""
    Unrolled state-space implementation of all-pole filtering.

    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.
        unroll_factor (int): Factor by which to unroll the loop.

    Returns:
        Tensor: Filtered output signal.
    """</span>
    <span class="k">if</span> <span class="n">unroll_factor</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">state_space_allpole</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">)</span>
    <span class="k">elif</span> <span class="n">unroll_factor</span> <span class="o">&lt;</span> <span class="mi">1</span><span class="p">:</span>
        <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="s">"Unroll factor must be &gt;= 1"</span><span class="p">)</span>

    <span class="k">assert</span> <span class="n">x</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">2</span><span class="p">,</span> <span class="s">"Input signal must be a 2D tensor (batch_size, signal_length)"</span>
    <span class="k">assert</span> <span class="n">a</span><span class="p">.</span><span class="n">dim</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"All-pole coefficients must be a 1D tensor"</span>
    <span class="k">assert</span> <span class="p">(</span>
        <span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="n">unroll_factor</span> <span class="o">==</span> <span class="mi">0</span>
    <span class="p">),</span> <span class="s">"Signal length must be divisible by unroll factor"</span>

    <span class="n">c</span> <span class="o">=</span> <span class="n">a2companion</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>

    <span class="c1"># create an initial identity matrix
</span>    <span class="n">I</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">c</span><span class="p">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">c</span><span class="p">.</span><span class="n">dtype</span><span class="p">)</span>
    <span class="n">c_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">I</span><span class="p">]</span>
    <span class="c1"># TODO: use parallel scan to improve speed
</span>    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">unroll_factor</span><span class="p">):</span>
        <span class="n">c_list</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">c_list</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">@</span> <span class="n">c</span><span class="p">)</span>

    <span class="c1"># c_list = [I c c^2 ... c^unroll_factor]
</span>    <span class="n">flatten_c_list</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span>
        <span class="p">[</span><span class="n">c</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">unroll_factor</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))]</span>
        <span class="o">+</span> <span class="p">[</span><span class="n">xx</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">xx</span> <span class="ow">in</span> <span class="n">c_list</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span>
        <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">V</span> <span class="o">=</span> <span class="n">flatten_c_list</span><span class="p">.</span><span class="n">unfold</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">unroll_factor</span><span class="p">,</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)).</span><span class="n">flip</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>

    <span class="c1"># divide the input signal into blocks of size unroll_factor
</span>    <span class="n">unrolled_x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">unflatten</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">unroll_factor</span><span class="p">))</span>

    <span class="c1"># input contribution to the last state of each block (last c.size(0) columns of V)
</span>    <span class="n">last_x</span> <span class="o">=</span> <span class="n">unrolled_x</span> <span class="o">@</span> <span class="n">V</span><span class="p">[:,</span> <span class="o">-</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="p">:]</span>

    <span class="c1"># initial condition
</span>    <span class="n">zi</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">new_zeros</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>

    <span class="c1"># transition matrix on the block level
</span>    <span class="n">AT</span> <span class="o">=</span> <span class="n">c_list</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">].</span><span class="n">T</span>
    <span class="n">block_output</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">h</span> <span class="o">=</span> <span class="n">zi</span>
    <span class="c1"># block level recursion
</span>    <span class="k">for</span> <span class="n">xt</span> <span class="ow">in</span> <span class="n">last_x</span><span class="p">.</span><span class="n">unbind</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span>
        <span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">addmm</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">AT</span><span class="p">)</span>
        <span class="n">block_output</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>

    <span class="c1"># stack the accumulated last outputs of the blocks as initial conditions for the intermediate steps
</span>    <span class="n">initials</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">zi</span><span class="p">]</span> <span class="o">+</span> <span class="n">block_output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># prepare the augmented matrix and input for all the remaining steps
</span>    <span class="n">aug_x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">initials</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">unrolled_x</span><span class="p">[...,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="n">aug_A</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">(</span>
        <span class="p">[</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">stack</span><span class="p">([</span><span class="n">c</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">c_list</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
            <span class="n">V</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:</span> <span class="o">-</span><span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="p">:</span> <span class="n">c</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)],</span>
        <span class="p">],</span>
        <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">aug_x</span> <span class="o">@</span> <span class="n">aug_A</span>

    <span class="c1"># concat the first T - 1 outputs of each block with the last one
</span>    <span class="n">output</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">output</span><span class="p">,</span> <span class="n">initials</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:,</span> <span class="p">:</span><span class="mi">1</span><span class="p">]],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">output</span><span class="p">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s benchmark it!</p>

<div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>&lt;torch.utils.benchmark.utils.common.Measurement object at 0x78d297b8b290&gt;
state_space_allpole_unrolled_v2
State-Space All-Pole Filter Unrolled
  Median: 1.40 ms
  IQR:    0.01 ms <span class="o">(</span>1.40 to 1.41<span class="o">)</span>
  7 measurements, 100 runs per measurement, 4 threads
</code></pre></div></div>

<p>1.40 ms! That’s approximately 1.35 times faster than the previous version.
It might be worth redoing the earlier benchmarks with this version, but I’m too lazy to do it now :D
The results should be similar to the previous ones. 
I’ll upload benchmark results to Gist soon.</p>]]></content><author><name>Chin-Yun Yu</name><email>chin-yun.yu@qmul.ac.uk</email></author><category term="differentiable IIR" /><category term="scientific computing" /><category term="pytorch" /><category term="state-space model" /><summary type="html"><![CDATA[I recently came across a presentation by Andres Ezequiel Viso from GPU Audio at ADC 2022, in which he talked about how they accelerate IIR filters on the GPU. The approach they use is to formulate the IIR filter as a state-space model (SSM) and augment the transition matrix so that each step processes multiple samples at once. The primary speedup stems from the fact that GPUs are very good at performing large matrix multiplications, and the SSM formulation enables us to leverage this capability.]]></summary></entry><entry><title type="html">Notes on Differentiable TDF-II Filter</title><link href="https://iamycy.github.io/posts/2025/04/differentiable-tdf-ii/" rel="alternate" type="text/html" title="Notes on Differentiable TDF-II Filter" /><published>2025-04-26T00:00:00+00:00</published><updated>2025-04-26T00:00:00+00:00</updated><id>https://iamycy.github.io/posts/2025/04/differentiable-tdf-ii</id><content type="html" xml:base="https://iamycy.github.io/posts/2025/04/differentiable-tdf-ii/"><![CDATA[<p>This blog is a continuation of some of my early calculations for propagating gradients through general IIR filters, including direct-form and transposed-direct-form.</p>

<h2 id="back-story">Back story</h2>

<p>In early 2021, I implemented a differentiable <code class="language-plaintext highlighter-rouge">lfilter</code> function for <code class="language-plaintext highlighter-rouge">torchaudio</code> (a few core details were published two years later <a href="/publications/2023-11-4-golf">here</a>). 
The basic idea is to implement the backpropagation of gradients in C++ for optimal performance.
The implementation was based on Direct-Form-I (DF-I).
This differs from the popular implementation of SciPy’s <code class="language-plaintext highlighter-rouge">lfilter</code>, which is based on Transposed-Direct-Form-II (TDF-II) and is more numerically stable<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>.</p>

<p>Implementing it in this form would be better, but… at the time, my knowledge base was insufficient to generalise the idea to TDF-II.
In DF-I/II, the gradients of FIR and all-pole filters can be treated independently, so I worked only on the recursive part of the filter (the all-pole).</p>

<div style="display: flex; justify-content: space-between;">
  <div style="width: 55%;">
    <img src="https://ccrma.stanford.edu/~jos/filters/img1127_2x.png" alt="TDF-II" style="width: 100%; height: auto;" />
  </div>
  <div style="width: 45%;">
    <img src="https://ccrma.stanford.edu/~jos/filters/img1144_2x.png" alt="DF-I" style="width: 100%; height: auto;" />
  </div>
</div>

<p>However, in TDF-II, the two parts are combined and the registers are shared, so my previous approach does not work.
I left this as a TODO for the future<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup>.</p>

<p><img src="https://ccrma.stanford.edu/~jos/filters/img1147_2x.png" alt="DF-I" style="width: 50%; height: auto;" /></p>

<p>Many things have changed since then.
I started my PhD in 2022 and now have more time to think through the problem thoroughly.
My understanding of filters has improved after exploring related ideas in a few publications.
It’s time to revisit the problem: a differentiable TDF-II filter.</p>

<p><strong>TL;DR</strong>, <em>the backpropagation of TDF-II filter is a DF-II filter, and vice versa.</em></p>

<p>The following calculation considers the general case when the filter parameters are <strong>time-varying</strong>.
Time-invariant systems are a special case and are trivial once we have the time-varying results.</p>

<h2 id="transposed-direct-form-ii">(Transposed-)Direct-Form-II</h2>
<p>Given time-varying coefficients \(\{b_0[n], b_1[n],\dots,b_M[n]\}\) and \(\{a_1[n],\dots,a_N[n]\}\), the TDF-II filter can be expressed as:</p>

\[y[n] = s_1[n] + b_0[n] x[n]\]

\[s_1[n+1] = s_2[n] + b_1[n] x[n] - a_1[n] y[n]\\\]

\[s_2[n+1] = s_3[n] + b_2[n] x[n] - a_2[n] y[n]\]

\[\vdots\]

\[s_M[n+1] = b_M[n] x[n] - a_M[n] y[n].\]
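<p>For the time-invariant special case, the recurrence above can be transcribed almost verbatim. The plain-Python sketch below (illustrative names only) checks it against a direct evaluation of the difference equation \(y[n] = \sum_{i=0}^{M} b_i x[n-i] - \sum_{i=1}^{M} a_i y[n-i]\), which coincides with TDF-II when the coefficients are constant:</p>

```python
def tdf2(x, b, a):
    # b = [b0, ..., bM], a = [a1, ..., aM]; s[1..M] are the shared registers
    M = len(a)
    s = [0.0] * (M + 1)  # s[0] unused so indices match the equations
    y = []
    for xn in x:
        yn = s[1] + b[0] * xn
        for i in range(1, M):
            s[i] = s[i + 1] + b[i] * xn - a[i - 1] * yn
        s[M] = b[M] * xn - a[M - 1] * yn
        y.append(yn)
    return y

def direct(x, b, a):
    # reference: y[n] = sum_i b_i x[n-i] - sum_i a_i y[n-i]
    y = []
    for n in range(len(x)):
        acc = sum(b[i] * x[n - i] for i in range(len(b)) if n - i >= 0)
        acc -= sum(a[i - 1] * y[n - i] for i in range(1, len(a) + 1) if n - i >= 0)
        y.append(acc)
    return y

x = [1.0, 0.5, -0.25, 2.0, 0.0, 1.0]
b = [0.2, 0.3, 0.1]
a = [-0.5, 0.25]
assert all(abs(u - v) < 1e-12 for u, v in zip(tdf2(x, b, a), direct(x, b, a)))
```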

<p>We can also write it in observable canonical form:</p>

\[\mathbf{s}[n+1]
=
\mathbf{A}[n] \mathbf{s}[n] + \mathbf{B}[n] x[n]\]

\[y[n] = \mathbf{C}\mathbf{s}[n] + b_0[n] x[n]\]

\[\mathbf{A}[n] =
\begin{bmatrix}
  -a_1[n] &amp; 1 &amp; 0 &amp; \cdots &amp; 0 \\
  -a_2[n] &amp; 0 &amp; 1 &amp; \cdots &amp; 0 \\
  \vdots &amp; \vdots &amp; \vdots &amp; \ddots &amp; \vdots \\
  -a_{M-1}[n] &amp; 0 &amp; 0 &amp; \cdots &amp; 1 \\
  -a_M[n] &amp; 0 &amp; 0 &amp; \cdots &amp; 0
\end{bmatrix}\]

\[\mathbf{C} =
\begin{bmatrix}
  1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 \\
\end{bmatrix}.\]

<p>The values of \(\mathbf{B}[n]\) can be found in Julius’ blog<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup>.</p>
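<p>As a concrete reference, the difference equations above translate into a short NumPy sketch (the function name and array layout are my own choices, not from any library):</p>

```python
import numpy as np

def tdf2(x, b, a):
    """Time-varying TDF-II filter.

    x: (T,) input signal
    b: (T, M + 1) feed-forward coefficients b_0[n]..b_M[n]
    a: (T, M) feedback coefficients a_1[n]..a_M[n]
    """
    T, M = a.shape
    s = np.zeros(M)                     # shared registers s_1..s_M
    y = np.empty(T)
    for n in range(T):
        y[n] = s[0] + b[n, 0] * x[n]
        # s_i[n+1] = s_{i+1}[n] + b_i[n] x[n] - a_i[n] y[n]  (with s_{M+1} := 0)
        s = np.append(s[1:], 0.0) + b[n, 1:] * x[n] - a[n] * y[n]
    return y
```

<p>For constant first-order coefficients this reproduces the familiar recursion \(y[n] = b_0 x[n] + b_1 x[n-1] - a_1 y[n-1]\).</p>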

<p>Regarding DF-II, its difference equations are:</p>

\[v[n] = x[n] - \sum_{i=1}^{M} a_i[n] v[n-i]\]

\[y[n] = \sum_{i=0}^{M} b_i[n] v[n-i].\]

<p>Similarly, it can be expressed as a state-space model using the controller canonical form:</p>

\[\mathbf{v}[n+1]
=
\begin{bmatrix}
  -a_1[n] &amp; -a_2[n] &amp; \cdots &amp; -a_{M-1}[n] &amp; -a_M[n] \\
  1 &amp; 0 &amp; \cdots &amp; 0 &amp; 0 \\
  0 &amp; 1 &amp; \cdots &amp; 0 &amp; 0 \\
  \vdots &amp; \vdots &amp; \ddots &amp; \vdots &amp; \vdots \\
  0 &amp; 0 &amp; \cdots &amp; 1 &amp; 0
\end{bmatrix}
\mathbf{v}[n] +
\begin{bmatrix}
  1 \\
  0 \\
  \vdots \\
  0
\end{bmatrix}
x[n]
= \mathbf{A}^\top[n]\mathbf{v}[n] + \mathbf{C}^\top x[n]\]

\[y[n] = \mathbf{B}^\top[n] \mathbf{v}[n] + b_0[n] x[n].\]

<p>As shown above, the two forms are very similar.
The transition matrix of TDF-II is the transpose of that of DF-II, and the vectors <strong>B</strong> and <strong>C</strong> are swapped.
(This is why it is called transposed-DF-II.)
Note that the resulting transfer functions are not the same in the time-varying case, due to the difference in computation order.
(They are the same if the coefficients are time-invariant!)
I will use the state-space form in the following sections for simplicity.</p>
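<p>The transpose relationship is easy to check numerically. The NumPy sketch below (all names are mine) runs both state-space recursions with the same constant coefficients; with time-invariant coefficients the two outputs coincide, as claimed above:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
M, T = 3, 64
a = 0.1 * rng.standard_normal(M)             # constant feedback coefficients
Bv = rng.standard_normal(M)                  # input vector B
b0 = 0.7
x = rng.standard_normal(T)

A = np.zeros((M, M))                         # observable canonical form (TDF-II)
A[:, 0] = -a                                 # first column holds -a_1..-a_M
A[np.arange(M - 1), np.arange(1, M)] = 1.0   # superdiagonal of ones
C = np.eye(M)[0]                             # C = [1, 0, ..., 0]

def ssm(A, Bv, Cv):
    """y[n] = Cv s[n] + b0 x[n],  s[n+1] = A s[n] + Bv x[n]."""
    s = np.zeros(M)
    y = np.empty(T)
    for n in range(T):
        y[n] = Cv @ s + b0 * x[n]
        s = A @ s + Bv * x[n]
    return y

y_tdf2 = ssm(A, Bv, C)       # TDF-II (observable form)
y_df2 = ssm(A.T, C, Bv)      # DF-II (controller form): transposed A, B and C swapped
```

<p>Here <code>np.allclose(y_tdf2, y_df2)</code> holds; if the coefficients were time-varying, the two outputs would generally differ.</p>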

<h2 id="backpropagation-through-tdf-ii">Backpropagation through TDF-II</h2>

<p>Suppose we have evaluated some loss function \(\mathcal{L}\) on the output of the filter \(y[n]\) and have the instantaneous gradients \(\frac{\partial \mathcal{L}}{\partial \mathbf{s}[n]}\).
We want to backpropagate the gradients through the filter to get the gradients of the input \(\frac{\partial \mathcal{L}}{\partial x[n]}\) and of the filter coefficients \(\frac{\partial \mathcal{L}}{\partial a_i[n]}\) and \(\frac{\partial \mathcal{L}}{\partial b_i[n]}\).
Let’s first denote \(\mathbf{z}[n] = \mathbf{B}[n] x[n]\); once we have the gradients of \(\mathbf{z}[n]\), the gradients of both follow easily from the chain rule.
The recursion in TDF-II state-space form becomes:</p>

\[\mathbf{s}[n+1] = \mathbf{A}[n] \mathbf{s}[n] + \mathbf{z}[n].\]

<p>If we unroll the recursion so that no <strong>s</strong> remains on the right-hand side, we get:</p>

\[\mathbf{s}[n+1] = \sum_{i=1}^{\infty} \left(\prod_{j=1}^{i} \mathbf{A}[n-j+1]\right) \mathbf{z}[n-i] + \mathbf{z}[n].\]

<p>The gradients for <strong>z</strong> can be computed as:</p>

\[\frac{\partial \mathbf{s}[n]}{\partial \mathbf{z}[i]} 
= 
\begin{cases}
  \prod_{j=1}^{n-i-1} \mathbf{A}[n-j] &amp; i &lt; n - 1 \\
  \mathbf{I} &amp; i = n - 1 \\
  \mathbf{0} &amp; i \geq n
\end{cases}\]

\[\frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]}
= \sum_{i=n+1}^{\infty} \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]} \frac{\partial \mathbf{s}[i]}{\partial \mathbf{z}[n]}\]

\[= \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]} + \sum_{i=n+2}^{\infty} \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]} \prod_{j=1}^{i-n-1} \mathbf{A}[i-j]\]

\[= \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]} + \sum_{i=n+2}^{\infty} \left( \prod_{j=i-n-1}^{1} \mathbf{A}^\top[i-j] \right) \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]}\]

\[= \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]} + \sum_{i=n+2}^{\infty} \left( \prod_{j=1}^{i-n-1} \mathbf{A}^\top[n+j] \right) \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]}\]

\[= \mathbf{A}^\top[n+1] \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n+1]} + \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]}.\]

<p>For simplicity, I omitted the transpose sign on the gradient vectors.
The last line is a recursion involving \(\mathbf{A}^\top\), which implies that, to backpropagate the gradients through the recursion of TDF-II, we need to run the <strong>recursion of DF-II, but in the opposite direction</strong>!
The roles would be swapped if we computed the gradients of DF-II using the same procedure, but I’ll leave that as an exercise for the reader :D</p>

<p>For completeness, the following are the procedures to compute the gradients of the input and filter coefficients.</p>

<h3 id="gradients-of-the-input">Gradients of the input</h3>

\[\frac{\partial \mathcal{L}}{\partial \mathbf{s}[n]} 
= \mathbf{C}^\top \frac{\partial \mathcal{L}}{\partial y[n]}\]

\[\frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]}
= \mathbf{A}^\top[n+1] \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n+1]} + \mathbf{C}^\top \frac{\partial \mathcal{L}}{\partial y[n+1]}\]

<p>(Note that the above recursion is the same as the one in DF-II! Only the input and output variables are changed.)</p>

\[\frac{\partial \mathcal{L}}{\partial x[n]}
=  \mathbf{B}^\top[n] \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} + b_0[n] \frac{\partial \mathcal{L}}{\partial y[n]}\]
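<p>The input-gradient recursion can be verified numerically. The NumPy sketch below (all names are my own) backpropagates \(\mathcal{L} = \sum_n y[n]\) through a random time-varying TDF-II filter and compares the result against central finite differences:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
T, M = 6, 2
x = rng.standard_normal(T)
a = 0.1 * rng.standard_normal((T, M))   # time-varying feedback coefficients
B = rng.standard_normal((T, M))         # time-varying input vectors B[n]
b0 = rng.standard_normal(T)
C = np.eye(M)[0]                        # C = [1, 0, ..., 0]

def A_mat(an):                          # observable-form transition matrix
    A = np.zeros((M, M))
    A[:, 0] = -an
    A[np.arange(M - 1), np.arange(1, M)] = 1.0
    return A

def forward(x):
    s = np.zeros(M)
    y = np.empty(T)
    for n in range(T):
        y[n] = C @ s + b0[n] * x[n]
        s = A_mat(a[n]) @ s + B[n] * x[n]
    return y

# Backward pass for L = sum(y): reversed recursion with the transposed A.
gy = np.ones(T)                         # dL/dy[n]
gz = np.zeros(M)                        # dL/dz[T-1] = 0, since s[T] is never used
gx = np.empty(T)
gx[T - 1] = b0[T - 1] * gy[T - 1]
for n in range(T - 2, -1, -1):
    gz = A_mat(a[n + 1]).T @ gz + C * gy[n + 1]
    gx[n] = B[n] @ gz + b0[n] * gy[n]

# Central finite differences for comparison.
eps = 1e-6
fd = np.array([(forward(x + eps * np.eye(T)[n]).sum()
                - forward(x - eps * np.eye(T)[n]).sum()) / (2 * eps)
               for n in range(T)])
```

<p>The two gradient estimates agree to within finite-difference error.</p>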

<h3 id="gradients-of-the-b-coefficients">Gradients of the <strong>b</strong> coefficients</h3>

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}[n]}
= \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} x[n]\]

\[\frac{\partial \mathcal{L}}{\partial b_0[n]}
= \frac{\partial \mathcal{L}}{\partial y[n]} x[n]\]

<h3 id="gradients-of-the-a-coefficients">Gradients of the <strong>a</strong> coefficients</h3>

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}[n]}
= \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} \mathbf{s}^\top[n]
\to \frac{\partial \mathcal{L}}{\partial a_i[n]} = -\frac{\partial \mathcal{L}}{\partial z_i[n]} s_1[n]\]

<h3 id="time-invariant-case">Time-invariant case</h3>
<p>In the time-invariant case, the parameters are constant.</p>

\[a_i[n] = a_i[m] \quad \forall n, m, \quad i = 1, \dots, M\]

\[b_i[n] = b_i[m] \quad \forall n, m, \quad i = 0, \dots, M\]

<p>In this case, we can just sum the gradients over time:</p>

\[\frac{\partial \mathcal{L}}{\partial a_i} = \sum_{n} \frac{\partial \mathcal{L}}{\partial a_i[n]}, \quad \frac{\partial \mathcal{L}}{\partial b_i} = \sum_{n} \frac{\partial \mathcal{L}}{\partial b_i[n]}.\]

<h2 id="summary">Summary</h2>

<p>The above findings suggest a way to compute the TDF-II filter’s gradients efficiently.
To do this, the following steps are needed:</p>

<ol>
  <li>Implement the recursions of TDF-II and DF-II filters in C++/CUDA/Metal/etc.</li>
  <li>After doing the forward pass of TDF-II, store \(s_1[n]\), \(\mathbf{a}[n]\), \(\mathbf{b}[n]\), and \(x[n]\).</li>
  <li>When doing backpropagation, filter the output gradients \(\frac{\partial \mathcal{L}}{\partial y[n]}\) through the DF-II filter’s recursions in the opposite direction using the same <strong>a</strong> coefficients.</li>
  <li>Compute the gradients of the input and filter coefficients using the equations above. Note that although \(\frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]}\) is a sequence of vectors, since the higher-order states in DF-II are just time-delayed versions of the first state (\(v_M[n] = v_{M-1}[n-1] = \cdots = v_1[n-M+1]\)), we can just store \(\frac{\partial \mathcal{L}}{\partial z_1[n]}\) for gradient computation, reducing the memory usage by a factor of \(M\).</li>
</ol>
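<p>Before writing any C++/CUDA, the recipe can be prototyped in pure PyTorch. The sketch below (class and variable names are mine) implements a first-order (\(M = 1\)), time-invariant TDF-II filter as a custom torch.autograd.Function, with the backward pass running the reversed DF-II recursion; the Python loops make it slow, so treat it as a correctness reference only:</p>

```python
import torch

class TDF2(torch.autograd.Function):
    """First-order, time-invariant TDF-II:
    y[n] = s1[n] + b0 x[n],  s1[n+1] = b1 x[n] - a1 y[n]."""

    @staticmethod
    def forward(ctx, x, b0, b1, a1):
        y = torch.empty_like(x)
        s1 = x.new_zeros(())
        for n in range(len(x)):
            y[n] = s1 + b0 * x[n]
            s1 = b1 * x[n] - a1 * y[n]
        ctx.save_for_backward(x, y, b0, b1, a1)
        return y

    @staticmethod
    def backward(ctx, gy):
        x, y, b0, b1, a1 = ctx.saved_tensors
        T = len(x)
        # dL/dz[n]: the reversed all-pole (DF-II) recursion driven by gy.
        gz = torch.zeros_like(x)        # gz[T-1] = 0, since s1[T] is never used
        for n in range(T - 2, -1, -1):
            gz[n] = -a1 * gz[n + 1] + gy[n + 1]
        gx = b0 * gy + (b1 - a1 * b0) * gz
        gb0 = (gy * x).sum() - a1 * (gz * x).sum()
        gb1 = (gz * x).sum()
        ga1 = -(gz * y).sum()
        return gx, gb0, gb1, ga1
```

<p>With double-precision inputs, this sketch passes torch.autograd.gradcheck, confirming the hand-derived backward pass.</p>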

<h2 id="final-thoughts">Final thoughts</h2>
<p>The procedure above can be applied to derive the gradients of the DF-II filter as well.
The resulting algorithm is identical, but the roles of TDF-II and DF-II are swapped.
Personally, I found the state-space formulation much easier, more straightforward, and more elegant than the <a href="/publications/2024-9-3-diffapf">derivation I did in 2024</a> for the gradients of time-varying all-pole filters, which is essentially the same problem.
(Man, I was basically brute-forcing it…)
Applying the method to TDF-I is straightforward: just set \(\mathbf{B}[n] = 0\).</p>

<p>Interestingly, since the backpropagation of TDF-II is a DF-II filter, it’s less numerically stable than TDF-II; in contrast, the backpropagation of DF-II is a TDF-II filter and is more stable.
We’ll always have this trade-off, so is TDF-II necessary if we want differentiability?
Probably yes: besides backpropagation, the gradients can also be computed using <strong>forward-mode</strong> automatic differentiation, which computes the Jacobian in the opposite direction.
In this way, the forwarded gradients are computed in the same way as the filter’s forward pass, and the math is much easier to show than the backpropagation derivation above. (I should have realised this earlier…)
Also, in the time-varying case with \(M &gt; 1\), neither of the two forms guarantees BIBO stability.
This is another interesting topic, but let’s leave it for now.
I hope this post is helpful for those who are interested in differentiable IIR filters.</p>

<h2 id="notes">Notes</h2>

<p>The figures are from <a href="https://ccrma.stanford.edu/~jos/filters/Implementation_Structures_Recursive_Digital.html">Julius O. Smith III</a> and the notations are adapted from his blog<sup id="fnref:3:1" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup>.
The algorithm is based on the following papers:</p>

<ol>
  <li>Singing Voice Synthesis Using Differentiable LPC and Glottal-Flow-Inspired Wavetables (doi: 10.5281/zenodo.13916489)</li>
  <li>Differentiable Time-Varying Linear Prediction in the Context of End-to-End Analysis-by-Synthesis (doi: 10.21437/Interspeech.2024-1187)</li>
  <li>Differentiable All-pole Filters for Time-varying Audio Systems</li>
  <li>GOLF: A Singing Voice Synthesiser with Glottal Flow Wavetables and LPC Filters (doi: 10.5334/tismir.210)</li>
</ol>

<hr />
<p><strong>References:</strong></p>

<div class="footnotes" role="doc-endnotes">
  <ol>
    <li id="fn:1" role="doc-endnote">
      <p>https://ccrma.stanford.edu/~jos/filters/Numerical_Robustness_TDF_II.html <a href="#fnref:1" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:2" role="doc-endnote">
      <p>https://github.com/pytorch/audio/pull/1310#issuecomment-790408467 <a href="#fnref:2" class="reversefootnote" role="doc-backlink">&#8617;</a></p>
    </li>
    <li id="fn:3" role="doc-endnote">
      <p>https://ccrma.stanford.edu/~jos/fp/Converting_State_Space_Form_Hand.html <a href="#fnref:3" class="reversefootnote" role="doc-backlink">&#8617;</a> <a href="#fnref:3:1" class="reversefootnote" role="doc-backlink">&#8617;<sup>2</sup></a></p>
    </li>
  </ol>
</div>]]></content><author><name>Chin-Yun Yu</name><email>chin-yun.yu@qmul.ac.uk</email></author><category term="differentiable IIR" /><category term="scientific computing" /><category term="pytorch" /><summary type="html"><![CDATA[This blog is a continuation of some of my early calculations for propagating gradients through general IIR filters, including direct-form and transposed-direct-form.]]></summary></entry><entry><title type="html">How to Train Deep NMF Model in PyTorch</title><link href="https://iamycy.github.io/posts/2021/02/torchnmf-algorithm/" rel="alternate" type="text/html" title="How to Train Deep NMF Model in PyTorch" /><published>2021-02-09T00:00:00+00:00</published><updated>2021-02-09T00:00:00+00:00</updated><id>https://iamycy.github.io/posts/2021/02/torchnmf-algorithm</id><content type="html" xml:base="https://iamycy.github.io/posts/2021/02/torchnmf-algorithm/"><![CDATA[<p>Recently I updated the implementation of PyTorch-NMF so that it can scale to large and complex NMF models. In this blog post I will briefly explain how this was achieved thanks to the automatic differentiation of PyTorch.</p>

<h1 id="multiplicative-update-rules-with-beta-divergence">Multiplicative Update Rules with Beta Divergence</h1>

<p>Multiplicative Update is a classic update method that has been widely used in many NMF applications. Its form is easy to derive, it guarantees a monotonic decrease of the loss value, and it ensures the nonnegativity of the parameter updates.</p>

<p>Below are the multiplicative update forms when using Beta-Divergence as our criterion:</p>

\[H \leftarrow H \cdot \frac{W^\top\left[(WH)^{\beta - 2} \cdot V\right]}{W^\top(WH)^{\beta - 1}}\]

<hr />

\[W \leftarrow W \cdot \frac{\left[(WH)^{\beta - 2} \cdot V\right]H^\top}{(WH)^{\beta - 1}H^\top}\]
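<p>These two updates translate almost line-for-line into NumPy. Below is a minimal sketch (mu_step is an illustrative name, not the torchnmf API; the small eps guards against division by zero):</p>

```python
import numpy as np

def mu_step(V, W, H, beta=2.0, eps=1e-12):
    """One multiplicative update of H and W for the beta-divergence."""
    WH = W @ H
    H = H * (W.T @ (WH ** (beta - 2) * V)) / (W.T @ WH ** (beta - 1) + eps)
    WH = W @ H                      # recompute with the updated H
    W = W * ((WH ** (beta - 2) * V) @ H.T) / (WH ** (beta - 1) @ H.T + eps)
    return W, H
```

<p>Since every factor in the updates is nonnegative, W and H stay nonnegative after each step, and the divergence decreases monotonically.</p>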

<h1 id="decoupling-the-derivative">Decoupling the Derivative</h1>

<p>The update weights are actually derived from the derivative of the chosen criterion with respect to the parameters (<code class="language-plaintext highlighter-rouge">H</code> and <code class="language-plaintext highlighter-rouge">W</code>). Due to a property of the Beta-Divergence, this derivative can be expressed as the difference of two nonnegative functions such that:</p>

<svg xmlns="http://www.w3.org/2000/svg" width="29.479ex" height="2.615ex" viewBox="0 -825.2 13029.9 1155.8" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" style=""><defs><path id="MJX-256-TEX-N-25BD" d="M59 480Q59 485 61 489T66 495T72 498L75 500H814Q828 493 828 480V474L644 132Q458 -210 455 -212Q451 -215 444 -215T433 -212Q429 -210 342 -49T164 282T64 466Q59 478 59 480ZM775 460H113Q113 459 278 153T444 -153T610 153T775 460Z"></path><path id="MJX-256-TEX-I-1D703" d="M35 200Q35 302 74 415T180 610T319 704Q320 704 327 704T339 705Q393 701 423 656Q462 596 462 495Q462 380 417 261T302 66T168 -10H161Q125 -10 99 10T60 63T41 130T35 200ZM383 566Q383 668 330 668Q294 668 260 623T204 521T170 421T157 371Q206 370 254 370L351 371Q352 372 359 404T375 484T383 566ZM113 132Q113 26 166 26Q181 26 198 36T239 74T287 161T335 307L340 324H145Q145 321 136 286T120 208T113 132Z"></path><path id="MJX-256-TEX-I-1D437" d="M287 628Q287 635 230 637Q207 637 200 638T193 647Q193 655 197 667T204 682Q206 683 403 683Q570 682 590 682T630 676Q702 659 752 597T803 431Q803 275 696 151T444 3L430 1L236 0H125H72Q48 0 41 2T33 11Q33 13 36 25Q40 41 44 43T67 46Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628ZM703 469Q703 507 692 537T666 584T629 613T590 629T555 636Q553 636 541 636T512 636T479 637H436Q392 637 386 627Q384 623 313 339T242 52Q242 48 253 48T330 47Q335 47 349 47T373 46Q499 46 581 128Q617 164 640 212T683 339T703 469Z"></path><path id="MJX-256-TEX-N-28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path><path id="MJX-256-TEX-N-29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path><path 
id="MJX-256-TEX-N-3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path><path id="MJX-256-TEX-N-2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path><path id="MJX-256-TEX-N-2212" d="M84 237T84 250T98 270H679Q694 262 694 250T679 230H98Q84 237 84 250Z"></path></defs><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="msub"><g data-mml-node="mo"><use xlink:href="#MJX-256-TEX-N-25BD"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -150) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-256-TEX-I-1D703"></use></g></g></g><g data-mml-node="mi" transform="translate(1270.6, 0)"><use xlink:href="#MJX-256-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(2098.6, 0)"><use xlink:href="#MJX-256-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(2487.6, 0)"><use xlink:href="#MJX-256-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(2956.6, 0)"><use xlink:href="#MJX-256-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(3623.4, 0)"><use xlink:href="#MJX-256-TEX-N-3D"></use></g><g data-mml-node="msubsup" transform="translate(4679.2, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-256-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-256-TEX-N-2B"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -323.5) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-256-TEX-I-1D703"></use></g></g></g><g data-mml-node="mi" transform="translate(6168.3, 0)"><use xlink:href="#MJX-256-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(6996.3, 
0)"><use xlink:href="#MJX-256-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(7385.3, 0)"><use xlink:href="#MJX-256-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(7854.3, 0)"><use xlink:href="#MJX-256-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(8465.5, 0)"><use xlink:href="#MJX-256-TEX-N-2212"></use></g><g data-mml-node="msubsup" transform="translate(9465.8, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-256-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-256-TEX-N-2212"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -323.5) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-256-TEX-I-1D703"></use></g></g></g><g data-mml-node="mi" transform="translate(10954.9, 0)"><use xlink:href="#MJX-256-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(11782.9, 0)"><use xlink:href="#MJX-256-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(12171.9, 0)"><use xlink:href="#MJX-256-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(12640.9, 0)"><use xlink:href="#MJX-256-TEX-N-29"></use></g></g></g></svg>

<p>Then, we can simply write:</p>

<svg xmlns="http://www.w3.org/2000/svg" width="16.335ex" height="6.18ex" viewBox="0 -1615.8 7220.1 2731.6" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" style=""><defs><path id="MJX-498-TEX-I-1D703" d="M35 200Q35 302 74 415T180 610T319 704Q320 704 327 704T339 705Q393 701 423 656Q462 596 462 495Q462 380 417 261T302 66T168 -10H161Q125 -10 99 10T60 63T41 130T35 200ZM383 566Q383 668 330 668Q294 668 260 623T204 521T170 421T157 371Q206 370 254 370L351 371Q352 372 359 404T375 484T383 566ZM113 132Q113 26 166 26Q181 26 198 36T239 74T287 161T335 307L340 324H145Q145 321 136 286T120 208T113 132Z"></path><path id="MJX-498-TEX-N-2190" d="M944 261T944 250T929 230H165Q167 228 182 216T211 189T244 152T277 96T303 25Q308 7 308 0Q308 -11 288 -11Q281 -11 278 -11T272 -7T267 2T263 21Q245 94 195 151T73 236Q58 242 55 247Q55 254 59 257T73 264Q121 283 158 314T215 375T247 434T264 480L267 497Q269 503 270 505T275 509T288 511Q308 511 308 500Q308 493 303 475Q293 438 278 406T246 352T215 315T185 287T165 270H929Q944 261 944 250Z"></path><path id="MJX-498-TEX-N-22C5" d="M78 250Q78 274 95 292T138 310Q162 310 180 294T199 251Q199 226 182 208T139 190T96 207T78 250Z"></path><path id="MJX-498-TEX-N-25BD" d="M59 480Q59 485 61 489T66 495T72 498L75 500H814Q828 493 828 480V474L644 132Q458 -210 455 -212Q451 -215 444 -215T433 -212Q429 -210 342 -49T164 282T64 466Q59 478 59 480ZM775 460H113Q113 459 278 153T444 -153T610 153T775 460Z"></path><path id="MJX-498-TEX-N-2212" d="M84 237T84 250T98 270H679Q694 262 694 250T679 230H98Q84 237 84 250Z"></path><path id="MJX-498-TEX-I-1D437" d="M287 628Q287 635 230 637Q207 637 200 638T193 647Q193 655 197 667T204 682Q206 683 403 683Q570 682 590 682T630 676Q702 659 752 597T803 431Q803 275 696 151T444 3L430 1L236 0H125H72Q48 0 41 2T33 11Q33 13 36 25Q40 41 44 43T67 46Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628ZM703 469Q703 507 692 537T666 584T629 613T590 629T555 636Q553 636 541 636T512 636T479 637H436Q392 637 386 627Q384 623 313 339T242 52Q242 48 253 48T330 
47Q335 47 349 47T373 46Q499 46 581 128Q617 164 640 212T683 339T703 469Z"></path><path id="MJX-498-TEX-N-28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path><path id="MJX-498-TEX-N-29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path><path id="MJX-498-TEX-N-2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path></defs><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="mi"><use xlink:href="#MJX-498-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(746.8, 0)"><use xlink:href="#MJX-498-TEX-N-2190"></use></g><g data-mml-node="mi" transform="translate(2024.6, 0)"><use xlink:href="#MJX-498-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(2715.8, 0)"><use xlink:href="#MJX-498-TEX-N-22C5"></use></g><g data-mml-node="mfrac" transform="translate(3216, 0)"><g data-mml-node="mrow" transform="translate(220, 792)"><g data-mml-node="msubsup"><g data-mml-node="mo"><use xlink:href="#MJX-498-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 411.6) scale(0.707)"><use xlink:href="#MJX-498-TEX-N-2212"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -324.9) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-498-TEX-I-1D703"></use></g></g></g><g data-mml-node="mi" transform="translate(1489.1, 0)"><use xlink:href="#MJX-498-TEX-I-1D437"></use></g><g data-mml-node="mo" 
transform="translate(2317.1, 0)"><use xlink:href="#MJX-498-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(2706.1, 0)"><use xlink:href="#MJX-498-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(3175.1, 0)"><use xlink:href="#MJX-498-TEX-N-29"></use></g></g><g data-mml-node="mrow" transform="translate(220, -783.8)"><g data-mml-node="msubsup"><g data-mml-node="mo"><use xlink:href="#MJX-498-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 411.6) scale(0.707)"><use xlink:href="#MJX-498-TEX-N-2B"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -324.9) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-498-TEX-I-1D703"></use></g></g></g><g data-mml-node="mi" transform="translate(1489.1, 0)"><use xlink:href="#MJX-498-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(2317.1, 0)"><use xlink:href="#MJX-498-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(2706.1, 0)"><use xlink:href="#MJX-498-TEX-I-1D703"></use></g><g data-mml-node="mo" transform="translate(3175.1, 0)"><use xlink:href="#MJX-498-TEX-N-29"></use></g></g><rect width="3764.1" height="60" x="120" y="220"></rect></g></g></g></svg>
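<p>In code, this generic update is a one-liner (a minimal sketch of the idea, not from any particular library; the function name is mine):</p>

```python
import torch

def multiplicative_step(theta, grad_plus, grad_minus, eps=1e-12):
    # theta <- theta * (grad_minus / grad_plus).
    # If theta, grad_plus, and grad_minus are all non-negative, the
    # updated parameter stays non-negative -- no projection needed --
    # and the update is a fixed point exactly where the two gradient
    # parts cancel, i.e. where the full gradient is zero.
    return theta * grad_minus / (grad_plus + eps)
```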

<p>Following the chain rule, we can also decouple the derivative with respect to each parameter (taking <code class="language-plaintext highlighter-rouge">H</code> as an example):</p>

<svg xmlns="http://www.w3.org/2000/svg" width="32ex" height="2.262ex" viewBox="0 -750 14143.8 1000" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" style=""><defs><path id="MJX-1304-TEX-N-25BD" d="M59 480Q59 485 61 489T66 495T72 498L75 500H814Q828 493 828 480V474L644 132Q458 -210 455 -212Q451 -215 444 -215T433 -212Q429 -210 342 -49T164 282T64 466Q59 478 59 480ZM775 460H113Q113 459 278 153T444 -153T610 153T775 460Z"></path><path id="MJX-1304-TEX-I-1D43B" d="M228 637Q194 637 192 641Q191 643 191 649Q191 673 202 682Q204 683 219 683Q260 681 355 681Q389 681 418 681T463 682T483 682Q499 682 499 672Q499 670 497 658Q492 641 487 638H485Q483 638 480 638T473 638T464 637T455 637Q416 636 405 634T387 623Q384 619 355 500Q348 474 340 442T328 395L324 380Q324 378 469 378H614L615 381Q615 384 646 504Q674 619 674 627T617 637Q594 637 587 639T580 648Q580 650 582 660Q586 677 588 679T604 682Q609 682 646 681T740 680Q802 680 835 681T871 682Q888 682 888 672Q888 645 876 638H874Q872 638 869 638T862 638T853 637T844 637Q805 636 794 634T776 623Q773 618 704 340T634 58Q634 51 638 51Q646 48 692 46H723Q729 38 729 37T726 19Q722 6 716 0H701Q664 2 567 2Q533 2 504 2T458 2T437 1Q420 1 420 10Q420 15 423 24Q428 43 433 45Q437 46 448 46H454Q481 46 514 49Q520 50 522 50T528 55T534 64T540 82T547 110T558 153Q565 181 569 198Q602 330 602 331T457 332H312L279 197Q245 63 245 58Q245 51 253 49T303 46H334Q340 38 340 37T337 19Q333 6 327 0H312Q275 2 178 2Q144 2 115 2T69 2T48 1Q31 1 31 10Q31 12 34 24Q39 43 44 45Q48 46 59 46H65Q92 46 125 49Q139 52 144 61Q147 65 216 339T285 628Q285 635 228 637Z"></path><path id="MJX-1304-TEX-I-1D437" d="M287 628Q287 635 230 637Q207 637 200 638T193 647Q193 655 197 667T204 682Q206 683 403 683Q570 682 590 682T630 676Q702 659 752 597T803 431Q803 275 696 151T444 3L430 1L236 0H125H72Q48 0 41 2T33 11Q33 13 36 25Q40 41 44 43T67 46Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628ZM703 469Q703 507 692 537T666 584T629 613T590 629T555 636Q553 636 541 636T512 636T479 637H436Q392 637 386 
627Q384 623 313 339T242 52Q242 48 253 48T330 47Q335 47 349 47T373 46Q499 46 581 128Q617 164 640 212T683 339T703 469Z"></path><path id="MJX-1304-TEX-N-28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path><path id="MJX-1304-TEX-N-29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path><path id="MJX-1304-TEX-N-3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path><path id="MJX-1304-TEX-I-1D44A" d="M436 683Q450 683 486 682T553 680Q604 680 638 681T677 682Q695 682 695 674Q695 670 692 659Q687 641 683 639T661 637Q636 636 621 632T600 624T597 615Q597 603 613 377T629 138L631 141Q633 144 637 151T649 170T666 200T690 241T720 295T759 362Q863 546 877 572T892 604Q892 619 873 628T831 637Q817 637 817 647Q817 650 819 660Q823 676 825 679T839 682Q842 682 856 682T895 682T949 681Q1015 681 1034 683Q1048 683 1048 672Q1048 666 1045 655T1038 640T1028 637Q1006 637 988 631T958 617T939 600T927 584L923 578L754 282Q586 -14 585 -15Q579 -22 561 -22Q546 -22 542 -17Q539 -14 523 229T506 480L494 462Q472 425 366 239Q222 -13 220 -15T215 -19Q210 -22 197 -22Q178 -22 176 -15Q176 -12 154 304T131 622Q129 631 121 633T82 637H58Q51 644 51 648Q52 671 64 683H76Q118 680 176 680Q301 680 313 683H323Q329 677 329 674T327 656Q322 641 318 637H297Q236 634 232 620Q262 160 266 136L501 550L499 587Q496 629 489 632Q483 636 447 637Q428 637 422 639T416 648Q416 650 418 660Q419 664 420 669T421 676T424 680T428 682T436 683Z"></path></defs><g stroke="currentColor" fill="currentColor" stroke-width="0" 
transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="mtable"><g data-mml-node="mtr"><g data-mml-node="mtd"><g data-mml-node="msub"><g data-mml-node="mo"><use xlink:href="#MJX-1304-TEX-N-25BD"></use></g><g data-mml-node="mi" transform="translate(889, -150) scale(0.707)"><use xlink:href="#MJX-1304-TEX-I-1D43B"></use></g></g><g data-mml-node="mi" transform="translate(1566.9, 0)"><use xlink:href="#MJX-1304-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(2394.9, 0)"><use xlink:href="#MJX-1304-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(2783.9, 0)"><use xlink:href="#MJX-1304-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(3671.9, 0)"><use xlink:href="#MJX-1304-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(4338.7, 0)"><use xlink:href="#MJX-1304-TEX-N-3D"></use></g><g data-mml-node="msub" transform="translate(5394.5, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1304-TEX-N-25BD"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -150) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-1304-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1048, 0)"><use xlink:href="#MJX-1304-TEX-I-1D43B"></use></g></g></g><g data-mml-node="mi" transform="translate(7702.4, 0)"><use xlink:href="#MJX-1304-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(8530.4, 0)"><use xlink:href="#MJX-1304-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(8919.4, 0)"><use xlink:href="#MJX-1304-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(9807.4, 0)"><use xlink:href="#MJX-1304-TEX-N-29"></use></g><g data-mml-node="msub" transform="translate(10418.6, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1304-TEX-N-25BD"></use></g><g data-mml-node="mi" transform="translate(889, -150) scale(0.707)"><use xlink:href="#MJX-1304-TEX-I-1D43B"></use></g></g><g data-mml-node="mi" transform="translate(12207.8, 0)"><use 
xlink:href="#MJX-1304-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(13255.8, 0)"><use xlink:href="#MJX-1304-TEX-I-1D43B"></use></g></g></g></g></g></g></svg>

<p>The derivative of <code class="language-plaintext highlighter-rouge">WH</code> with respect to <code class="language-plaintext highlighter-rouge">H</code> is <code class="language-plaintext highlighter-rouge">W^T</code>, which is always non-negative, so the ability to decouple the gradient into two non-negative functions actually comes from the Beta-Divergence itself.</p>

<svg xmlns="http://www.w3.org/2000/svg" width="36.364ex" height="12.361ex" viewBox="0 -2981.8 16073 5463.6" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" style=""><defs><path id="MJX-1242-TEX-N-25BD" d="M59 480Q59 485 61 489T66 495T72 498L75 500H814Q828 493 828 480V474L644 132Q458 -210 455 -212Q451 -215 444 -215T433 -212Q429 -210 342 -49T164 282T64 466Q59 478 59 480ZM775 460H113Q113 459 278 153T444 -153T610 153T775 460Z"></path><path id="MJX-1242-TEX-I-1D44A" d="M436 683Q450 683 486 682T553 680Q604 680 638 681T677 682Q695 682 695 674Q695 670 692 659Q687 641 683 639T661 637Q636 636 621 632T600 624T597 615Q597 603 613 377T629 138L631 141Q633 144 637 151T649 170T666 200T690 241T720 295T759 362Q863 546 877 572T892 604Q892 619 873 628T831 637Q817 637 817 647Q817 650 819 660Q823 676 825 679T839 682Q842 682 856 682T895 682T949 681Q1015 681 1034 683Q1048 683 1048 672Q1048 666 1045 655T1038 640T1028 637Q1006 637 988 631T958 617T939 600T927 584L923 578L754 282Q586 -14 585 -15Q579 -22 561 -22Q546 -22 542 -17Q539 -14 523 229T506 480L494 462Q472 425 366 239Q222 -13 220 -15T215 -19Q210 -22 197 -22Q178 -22 176 -15Q176 -12 154 304T131 622Q129 631 121 633T82 637H58Q51 644 51 648Q52 671 64 683H76Q118 680 176 680Q301 680 313 683H323Q329 677 329 674T327 656Q322 641 318 637H297Q236 634 232 620Q262 160 266 136L501 550L499 587Q496 629 489 632Q483 636 447 637Q428 637 422 639T416 648Q416 650 418 660Q419 664 420 669T421 676T424 680T428 682T436 683Z"></path><path id="MJX-1242-TEX-I-1D43B" d="M228 637Q194 637 192 641Q191 643 191 649Q191 673 202 682Q204 683 219 683Q260 681 355 681Q389 681 418 681T463 682T483 682Q499 682 499 672Q499 670 497 658Q492 641 487 638H485Q483 638 480 638T473 638T464 637T455 637Q416 636 405 634T387 623Q384 619 355 500Q348 474 340 442T328 395L324 380Q324 378 469 378H614L615 381Q615 384 646 504Q674 619 674 627T617 637Q594 637 587 639T580 648Q580 650 582 660Q586 677 588 679T604 682Q609 682 646 681T740 680Q802 680 835 681T871 682Q888 682 888 672Q888 645 876 
638H874Q872 638 869 638T862 638T853 637T844 637Q805 636 794 634T776 623Q773 618 704 340T634 58Q634 51 638 51Q646 48 692 46H723Q729 38 729 37T726 19Q722 6 716 0H701Q664 2 567 2Q533 2 504 2T458 2T437 1Q420 1 420 10Q420 15 423 24Q428 43 433 45Q437 46 448 46H454Q481 46 514 49Q520 50 522 50T528 55T534 64T540 82T547 110T558 153Q565 181 569 198Q602 330 602 331T457 332H312L279 197Q245 63 245 58Q245 51 253 49T303 46H334Q340 38 340 37T337 19Q333 6 327 0H312Q275 2 178 2Q144 2 115 2T69 2T48 1Q31 1 31 10Q31 12 34 24Q39 43 44 45Q48 46 59 46H65Q92 46 125 49Q139 52 144 61Q147 65 216 339T285 628Q285 635 228 637Z"></path><path id="MJX-1242-TEX-I-1D437" d="M287 628Q287 635 230 637Q207 637 200 638T193 647Q193 655 197 667T204 682Q206 683 403 683Q570 682 590 682T630 676Q702 659 752 597T803 431Q803 275 696 151T444 3L430 1L236 0H125H72Q48 0 41 2T33 11Q33 13 36 25Q40 41 44 43T67 46Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628ZM703 469Q703 507 692 537T666 584T629 613T590 629T555 636Q553 636 541 636T512 636T479 637H436Q392 637 386 627Q384 623 313 339T242 52Q242 48 253 48T330 47Q335 47 349 47T373 46Q499 46 581 128Q617 164 640 212T683 339T703 469Z"></path><path id="MJX-1242-TEX-N-28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path><path id="MJX-1242-TEX-N-29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path><path id="MJX-1242-TEX-N-3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path><path id="MJX-1242-TEX-I-1D447" d="M40 437Q21 437 21 445Q21 450 37 501T71 
602L88 651Q93 669 101 677H569H659Q691 677 697 676T704 667Q704 661 687 553T668 444Q668 437 649 437Q640 437 637 437T631 442L629 445Q629 451 635 490T641 551Q641 586 628 604T573 629Q568 630 515 631Q469 631 457 630T439 622Q438 621 368 343T298 60Q298 48 386 46Q418 46 427 45T436 36Q436 31 433 22Q429 4 424 1L422 0Q419 0 415 0Q410 0 363 1T228 2Q99 2 64 0H49Q43 6 43 9T45 27Q49 40 55 46H83H94Q174 46 189 55Q190 56 191 56Q196 59 201 76T241 233Q258 301 269 344Q339 619 339 625Q339 630 310 630H279Q212 630 191 624Q146 614 121 583T67 467Q60 445 57 441T43 437H40Z"></path><path id="MJX-1242-TEX-N-2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path><path id="MJX-1242-TEX-N-2212" d="M84 237T84 250T98 270H679Q694 262 694 250T679 230H98Q84 237 84 250Z"></path></defs><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="mtable"><g data-mml-node="mtr" transform="translate(0, 2231.8)"><g data-mml-node="mtd" transform="translate(7323.7, 0)"><g data-mml-node="msub"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -150) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1048, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g></g><g data-mml-node="mi" transform="translate(2308, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(3136, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(3525, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(4413, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g><g data-mml-node="msub" transform="translate(5024.2, 0)"><g 
data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mi" transform="translate(889, -150) scale(0.707)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g><g data-mml-node="mi" transform="translate(6813.3, 0)"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(7861.3, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g><g data-mml-node="mtd" transform="translate(16073, 0)"></g></g><g data-mml-node="mtr" transform="translate(0, 790.1)"><g data-mml-node="mtd" transform="translate(1761.9, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-3D"></use></g><g data-mml-node="msup" transform="translate(1055.8, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1103.2, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-I-1D447"></use></g></g><g data-mml-node="mo" transform="translate(2706.8, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="msubsup" transform="translate(3095.8, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-N-2B"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -307.9) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1048, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g></g><g data-mml-node="mi" transform="translate(5403.7, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(6231.7, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(6620.7, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(7508.7, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(8120, 
0)"><use xlink:href="#MJX-1242-TEX-N-2212"></use></g><g data-mml-node="msubsup" transform="translate(9120.2, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-N-2212"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -307.9) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1048, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g></g><g data-mml-node="mi" transform="translate(11428.1, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(12256.1, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(12645.1, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(13533.1, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(13922.1, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g></g><g data-mml-node="mtd" transform="translate(16073, 0)"></g></g><g data-mml-node="mtr" transform="translate(0, -725.1)"><g data-mml-node="mtd"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-3D"></use></g><g data-mml-node="msup" transform="translate(1055.8, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1103.2, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-I-1D447"></use></g></g><g data-mml-node="msubsup" transform="translate(2929, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-N-2B"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -307.9) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g 
data-mml-node="mi" transform="translate(1048, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g></g><g data-mml-node="mi" transform="translate(5459.2, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(6287.2, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(6676.2, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(7564.2, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(8175.4, 0)"><use xlink:href="#MJX-1242-TEX-N-2212"></use></g><g data-mml-node="msup" transform="translate(9175.6, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1103.2, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-I-1D447"></use></g></g><g data-mml-node="msubsup" transform="translate(11048.9, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-N-2212"></use></g><g data-mml-node="TeXAtom" transform="translate(889, -307.9) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-1242-TEX-I-1D44A"></use></g><g data-mml-node="mi" transform="translate(1048, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g></g><g data-mml-node="mi" transform="translate(13579, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(14407, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(14796, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(15684, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g></g><g data-mml-node="mtd" transform="translate(16073, 0)"></g></g><g data-mml-node="mtr" transform="translate(0, -2173.9)"><g data-mml-node="mtd" transform="translate(5673, 0)"><g 
data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-3D"></use></g><g data-mml-node="msubsup" transform="translate(1055.8, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-N-2B"></use></g><g data-mml-node="mi" transform="translate(889, -307.9) scale(0.707)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g><g data-mml-node="mi" transform="translate(2622.7, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(3450.7, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(3839.7, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(4727.7, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(5338.9, 0)"><use xlink:href="#MJX-1242-TEX-N-2212"></use></g><g data-mml-node="msubsup" transform="translate(6339.1, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-1242-TEX-N-25BD"></use></g><g data-mml-node="mo" transform="translate(889, 413) scale(0.707)"><use xlink:href="#MJX-1242-TEX-N-2212"></use></g><g data-mml-node="mi" transform="translate(889, -307.9) scale(0.707)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g></g><g data-mml-node="mi" transform="translate(7906, 0)"><use xlink:href="#MJX-1242-TEX-I-1D437"></use></g><g data-mml-node="mo" transform="translate(8734, 0)"><use xlink:href="#MJX-1242-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(9123, 0)"><use xlink:href="#MJX-1242-TEX-I-1D43B"></use></g><g data-mml-node="mo" transform="translate(10011, 0)"><use xlink:href="#MJX-1242-TEX-N-29"></use></g></g></g></g></g></g></svg>

<p>The same steps can be applied to <code class="language-plaintext highlighter-rouge">W</code> as well.</p>
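<p>Putting the pieces together, one multiplicative update of <code class="language-plaintext highlighter-rouge">H</code> can be sketched in PyTorch as follows (a minimal illustration using the standard Beta-NMF gradient split; the function name and <code class="language-plaintext highlighter-rouge">eps</code> clamping are my own additions for numerical safety):</p>

```python
import torch

def mu_update_h(V, W, H, beta=2.0, eps=1e-12):
    # One multiplicative update of H for d_beta(V | WH).
    # The gradient w.r.t. H splits into two non-negative parts:
    #   grad_minus = W^T ((WH)^(beta - 2) * V)
    #   grad_plus  = W^T  (WH)^(beta - 1)
    # so H * grad_minus / grad_plus stays non-negative.
    WH = (W @ H).clamp_min(eps)
    grad_minus = W.T @ (WH ** (beta - 2) * V)
    grad_plus = W.T @ (WH ** (beta - 1))
    return H * grad_minus / grad_plus.clamp_min(eps)
```

<p>For <code class="language-plaintext highlighter-rouge">beta=2</code> this reduces to the classic Lee–Seung rule for the Frobenius norm, <code class="language-plaintext highlighter-rouge">H * (W^T V) / (W^T W H)</code>.</p>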

<h2 id="derivative-of-beta-divergence">Derivative of Beta-Divergence</h2>

<p>The Beta-Divergence is defined as:</p>

<svg xmlns="http://www.w3.org/2000/svg" width="48.439ex" height="5.208ex" viewBox="0 -1342 21410.1 2302" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" style=""><defs><path id="MJX-2294-TEX-I-1D451" d="M366 683Q367 683 438 688T511 694Q523 694 523 686Q523 679 450 384T375 83T374 68Q374 26 402 26Q411 27 422 35Q443 55 463 131Q469 151 473 152Q475 153 483 153H487H491Q506 153 506 145Q506 140 503 129Q490 79 473 48T445 8T417 -8Q409 -10 393 -10Q359 -10 336 5T306 36L300 51Q299 52 296 50Q294 48 292 46Q233 -10 172 -10Q117 -10 75 30T33 157Q33 205 53 255T101 341Q148 398 195 420T280 442Q336 442 364 400Q369 394 369 396Q370 400 396 505T424 616Q424 629 417 632T378 637H357Q351 643 351 645T353 664Q358 683 366 683ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path><path id="MJX-2294-TEX-I-1D6FD" d="M29 -194Q23 -188 23 -186Q23 -183 102 134T186 465Q208 533 243 584T309 658Q365 705 429 705H431Q493 705 533 667T573 570Q573 465 469 396L482 383Q533 332 533 252Q533 139 448 65T257 -10Q227 -10 203 -2T165 17T143 40T131 59T126 65L62 -188Q60 -194 42 -194H29ZM353 431Q392 431 427 419L432 422Q436 426 439 429T449 439T461 453T472 471T484 495T493 524T501 560Q503 569 503 593Q503 611 502 616Q487 667 426 667Q384 667 347 643T286 582T247 514T224 455Q219 439 186 308T152 168Q151 163 151 147Q151 99 173 68Q204 26 260 26Q302 26 349 51T425 137Q441 171 449 214T457 279Q457 337 422 372Q380 358 347 358H337Q258 358 258 389Q258 396 261 403Q275 431 353 431Z"></path><path id="MJX-2294-TEX-N-28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path><path id="MJX-2294-TEX-I-1D449" d="M52 648Q52 670 65 683H76Q118 680 181 680Q299 680 320 683H330Q336 677 336 674T334 656Q329 641 325 637H304Q282 635 274 635Q245 630 242 
620Q242 618 271 369T301 118L374 235Q447 352 520 471T595 594Q599 601 599 609Q599 633 555 637Q537 637 537 648Q537 649 539 661Q542 675 545 679T558 683Q560 683 570 683T604 682T668 681Q737 681 755 683H762Q769 676 769 672Q769 655 760 640Q757 637 743 637Q730 636 719 635T698 630T682 623T670 615T660 608T652 599T645 592L452 282Q272 -9 266 -16Q263 -18 259 -21L241 -22H234Q216 -22 216 -15Q213 -9 177 305Q139 623 138 626Q133 637 76 637H59Q52 642 52 648Z"></path><path id="MJX-2294-TEX-N-7C" d="M139 -249H137Q125 -249 119 -235V251L120 737Q130 750 139 750Q152 750 159 735V-235Q151 -249 141 -249H139Z"></path><path id="MJX-2294-TEX-I-1D443" d="M287 628Q287 635 230 637Q206 637 199 638T192 648Q192 649 194 659Q200 679 203 681T397 683Q587 682 600 680Q664 669 707 631T751 530Q751 453 685 389Q616 321 507 303Q500 302 402 301H307L277 182Q247 66 247 59Q247 55 248 54T255 50T272 48T305 46H336Q342 37 342 35Q342 19 335 5Q330 0 319 0Q316 0 282 1T182 2Q120 2 87 2T51 1Q33 1 33 11Q33 13 36 25Q40 41 44 43T67 46Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628ZM645 554Q645 567 643 575T634 597T609 619T560 635Q553 636 480 637Q463 637 445 637T416 636T404 636Q391 635 386 627Q384 621 367 550T332 412T314 344Q314 342 395 342H407H430Q542 342 590 392Q617 419 631 471T645 554Z"></path><path id="MJX-2294-TEX-N-29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path><path id="MJX-2294-TEX-N-3D" d="M56 347Q56 360 70 367H707Q722 359 722 347Q722 336 708 328L390 327H72Q56 332 56 347ZM56 153Q56 168 72 173H708Q722 163 722 153Q722 140 707 133H70Q56 140 56 153Z"></path><path id="MJX-2294-TEX-N-31" d="M213 578L200 573Q186 568 160 563T102 556H83V602H102Q149 604 189 617T245 641T273 663Q275 666 285 666Q294 666 302 660V361L303 61Q310 54 315 52T339 48T401 46H427V0H416Q395 3 257 3Q121 3 100 0H88V46H114Q136 46 152 46T177 47T193 
50T201 52T207 57T213 61V578Z"></path><path id="MJX-2294-TEX-N-2212" d="M84 237T84 250T98 270H679Q694 262 694 250T679 230H98Q84 237 84 250Z"></path><path id="MJX-2294-TEX-SO-28" d="M152 251Q152 646 388 850H416Q422 844 422 841Q422 837 403 816T357 753T302 649T255 482T236 250Q236 124 255 19T301 -147T356 -251T403 -315T422 -340Q422 -343 416 -349H388Q359 -325 332 -296T271 -213T212 -97T170 56T152 251Z"></path><path id="MJX-2294-TEX-N-2B" d="M56 237T56 250T70 270H369V420L370 570Q380 583 389 583Q402 583 409 568V270H707Q722 262 722 250T707 230H409V-68Q401 -82 391 -82H389H387Q375 -82 369 -68V230H70Q56 237 56 250Z"></path><path id="MJX-2294-TEX-SO-29" d="M305 251Q305 -145 69 -349H56Q43 -349 39 -347T35 -338Q37 -333 60 -307T108 -239T160 -136T204 27T221 250T204 473T160 636T108 740T60 807T35 839Q35 850 50 850H56H69Q197 743 256 566Q305 425 305 251Z"></path></defs><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="msub"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D451"></use></g><g data-mml-node="TeXAtom" transform="translate(520, -150) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g></g></g><g data-mml-node="mo" transform="translate(970.2, 0)"><use xlink:href="#MJX-2294-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(1359.2, 0)"><use xlink:href="#MJX-2294-TEX-I-1D449"></use></g><g data-mml-node="TeXAtom" data-mjx-texclass="ORD" transform="translate(2128.2, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-2294-TEX-N-7C"></use></g></g><g data-mml-node="mi" transform="translate(2406.2, 0)"><use xlink:href="#MJX-2294-TEX-I-1D443"></use></g><g data-mml-node="mo" transform="translate(3157.2, 0)"><use xlink:href="#MJX-2294-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(3824, 0)"><use xlink:href="#MJX-2294-TEX-N-3D"></use></g><g data-mml-node="mfrac" transform="translate(4879.8, 0)"><g data-mml-node="mn" 
transform="translate(1786.2, 676)"><use xlink:href="#MJX-2294-TEX-N-31"></use></g><g data-mml-node="mrow" transform="translate(220, -710)"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g><g data-mml-node="mo" transform="translate(566, 0)"><use xlink:href="#MJX-2294-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(955, 0)"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g><g data-mml-node="mo" transform="translate(1743.2, 0)"><use xlink:href="#MJX-2294-TEX-N-2212"></use></g><g data-mml-node="mn" transform="translate(2743.4, 0)"><use xlink:href="#MJX-2294-TEX-N-31"></use></g><g data-mml-node="mo" transform="translate(3243.4, 0)"><use xlink:href="#MJX-2294-TEX-N-29"></use></g></g><rect width="3832.4" height="60" x="120" y="220"></rect></g><g data-mml-node="mrow" transform="translate(8952.2, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-2294-TEX-SO-28"></use></g><g data-mml-node="msup" transform="translate(458, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D449"></use></g><g data-mml-node="TeXAtom" transform="translate(828.3, 413) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g></g></g><g data-mml-node="mo" transform="translate(1958.7, 0)"><use xlink:href="#MJX-2294-TEX-N-2B"></use></g><g data-mml-node="mrow" transform="translate(2959, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-2294-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(389, 0)"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g><g data-mml-node="mo" transform="translate(1177.2, 0)"><use xlink:href="#MJX-2294-TEX-N-2212"></use></g><g data-mml-node="mn" transform="translate(2177.4, 0)"><use xlink:href="#MJX-2294-TEX-N-31"></use></g><g data-mml-node="mo" transform="translate(2677.4, 0)"><use xlink:href="#MJX-2294-TEX-N-29"></use></g></g><g data-mml-node="msup" transform="translate(6025.4, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D443"></use></g><g 
data-mml-node="TeXAtom" transform="translate(806.5, 413) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g></g></g><g data-mml-node="mo" transform="translate(7504.3, 0)"><use xlink:href="#MJX-2294-TEX-N-2212"></use></g><g data-mml-node="mi" transform="translate(8504.5, 0)"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g><g data-mml-node="mi" transform="translate(9070.5, 0)"><use xlink:href="#MJX-2294-TEX-I-1D449"></use></g><g data-mml-node="msup" transform="translate(9839.5, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D443"></use></g><g data-mml-node="TeXAtom" transform="translate(806.5, 413) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2294-TEX-I-1D6FD"></use></g><g data-mml-node="mo" transform="translate(566, 0)"><use xlink:href="#MJX-2294-TEX-N-2212"></use></g><g data-mml-node="mn" transform="translate(1344, 0)"><use xlink:href="#MJX-2294-TEX-N-31"></use></g></g></g><g data-mml-node="mo" transform="translate(11999.9, 0)"><use xlink:href="#MJX-2294-TEX-SO-29"></use></g></g></g></g></svg>

<p>where <code class="language-plaintext highlighter-rouge">P = WH</code>, and its derivative with respect to <code class="language-plaintext highlighter-rouge">P</code> is:</p>

<svg xmlns="http://www.w3.org/2000/svg" width="28.647ex" height="2.712ex" viewBox="0 -911.5 12662 1198.7" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" style=""><defs><path id="MJX-2388-TEX-N-25BD" d="M59 480Q59 485 61 489T66 495T72 498L75 500H814Q828 493 828 480V474L644 132Q458 -210 455 -212Q451 -215 444 -215T433 -212Q429 -210 342 -49T164 282T64 466Q59 478 59 480ZM775 460H113Q113 459 278 153T444 -153T610 153T775 460Z"></path><path id="MJX-2388-TEX-I-1D443" d="M287 628Q287 635 230 637Q206 637 199 638T192 648Q192 649 194 659Q200 679 203 681T397 683Q587 682 600 680Q664 669 707 631T751 530Q751 453 685 389Q616 321 507 303Q500 302 402 301H307L277 182Q247 66 247 59Q247 55 248 54T255 50T272 48T305 46H336Q342 37 342 35Q342 19 335 5Q330 0 319 0Q316 0 282 1T182 2Q120 2 87 2T51 1Q33 1 33 11Q33 13 36 25Q40 41 44 43T67 46Q94 46 127 49Q141 52 146 61Q149 65 218 339T287 628ZM645 554Q645 567 643 575T634 597T609 619T560 635Q553 636 480 637Q463 637 445 637T416 636T404 636Q391 635 386 627Q384 621 367 550T332 412T314 344Q314 342 395 342H407H430Q542 342 590 392Q617 419 631 471T645 554Z"></path><path id="MJX-2388-TEX-I-1D451" d="M366 683Q367 683 438 688T511 694Q523 694 523 686Q523 679 450 384T375 83T374 68Q374 26 402 26Q411 27 422 35Q443 55 463 131Q469 151 473 152Q475 153 483 153H487H491Q506 153 506 145Q506 140 503 129Q490 79 473 48T445 8T417 -8Q409 -10 393 -10Q359 -10 336 5T306 36L300 51Q299 52 296 50Q294 48 292 46Q233 -10 172 -10Q117 -10 75 30T33 157Q33 205 53 255T101 341Q148 398 195 420T280 442Q336 442 364 400Q369 394 369 396Q370 400 396 505T424 616Q424 629 417 632T378 637H357Q351 643 351 645T353 664Q358 683 366 683ZM352 326Q329 405 277 405Q242 405 210 374T160 293Q131 214 119 129Q119 126 119 118T118 106Q118 61 136 44T179 26Q233 26 290 98L298 109L352 326Z"></path><path id="MJX-2388-TEX-I-1D6FD" d="M29 -194Q23 -188 23 -186Q23 -183 102 134T186 465Q208 533 243 584T309 658Q365 705 429 705H431Q493 705 533 667T573 570Q573 465 469 396L482 383Q533 332 533 252Q533 139 448 65T257 
-10Q227 -10 203 -2T165 17T143 40T131 59T126 65L62 -188Q60 -194 42 -194H29ZM353 431Q392 431 427 419L432 422Q436 426 439 429T449 439T461 453T472 471T484 495T493 524T501 560Q503 569 503 593Q503 611 502 616Q487 667 426 667Q384 667 347 643T286 582T247 514T224 455Q219 439 186 308T152 168Q151 163 151 147Q151 99 173 68Q204 26 260 26Q302 26 349 51T425 137Q441 171 449 214T457 279Q457 337 422 372Q380 358 347 358H337Q258 358 258 389Q258 396 261 403Q275 431 353 431Z"></path><path id="MJX-2388-TEX-N-28" d="M94 250Q94 319 104 381T127 488T164 576T202 643T244 695T277 729T302 750H315H319Q333 750 333 741Q333 738 316 720T275 667T226 581T184 443T167 250T184 58T225 -81T274 -167T316 -220T333 -241Q333 -250 318 -250H315H302L274 -226Q180 -141 137 -14T94 250Z"></path><path id="MJX-2388-TEX-I-1D449" d="M52 648Q52 670 65 683H76Q118 680 181 680Q299 680 320 683H330Q336 677 336 674T334 656Q329 641 325 637H304Q282 635 274 635Q245 630 242 620Q242 618 271 369T301 118L374 235Q447 352 520 471T595 594Q599 601 599 609Q599 633 555 637Q537 637 537 648Q537 649 539 661Q542 675 545 679T558 683Q560 683 570 683T604 682T668 681Q737 681 755 683H762Q769 676 769 672Q769 655 760 640Q757 637 743 637Q730 636 719 635T698 630T682 623T670 615T660 608T652 599T645 592L452 282Q272 -9 266 -16Q263 -18 259 -21L241 -22H234Q216 -22 216 -15Q213 -9 177 305Q139 623 138 626Q133 637 76 637H59Q52 642 52 648Z"></path><path id="MJX-2388-TEX-N-7C" d="M139 -249H137Q125 -249 119 -235V251L120 737Q130 750 139 750Q152 750 159 735V-235Q151 -249 141 -249H139Z"></path><path id="MJX-2388-TEX-N-29" d="M60 749L64 750Q69 750 74 750H86L114 726Q208 641 251 514T294 250Q294 182 284 119T261 12T224 -76T186 -143T145 -194T113 -227T90 -246Q87 -249 86 -250H74Q66 -250 63 -250T58 -247T55 -238Q56 -237 66 -225Q221 -64 221 250T66 725Q56 737 55 738Q55 746 60 749Z"></path><path id="MJX-2388-TEX-N-221D" d="M56 124T56 216T107 375T238 442Q260 442 280 438T319 425T352 407T382 385T406 361T427 336T442 315T455 297T462 285L469 297Q555 442 679 442Q687 442 722 437V398H718Q710 
400 694 400Q657 400 623 383T567 343T527 294T503 253T495 235Q495 231 520 192T554 143Q625 44 696 44Q717 44 719 46H722V-5Q695 -11 678 -11Q552 -11 457 141Q455 145 454 146L447 134Q362 -11 235 -11Q157 -11 107 56ZM93 213Q93 143 126 87T220 31Q258 31 292 48T349 88T389 137T413 178T421 196Q421 200 396 239T362 288Q322 345 288 366T213 387Q163 387 128 337T93 213Z"></path><path id="MJX-2388-TEX-N-2212" d="M84 237T84 250T98 270H679Q694 262 694 250T679 230H98Q84 237 84 250Z"></path><path id="MJX-2388-TEX-N-31" d="M213 578L200 573Q186 568 160 563T102 556H83V602H102Q149 604 189 617T245 641T273 663Q275 666 285 666Q294 666 302 660V361L303 61Q310 54 315 52T339 48T401 46H427V0H416Q395 3 257 3Q121 3 100 0H88V46H114Q136 46 152 46T177 47T193 50T201 52T207 57T213 61V578Z"></path><path id="MJX-2388-TEX-N-32" d="M109 429Q82 429 66 447T50 491Q50 562 103 614T235 666Q326 666 387 610T449 465Q449 422 429 383T381 315T301 241Q265 210 201 149L142 93L218 92Q375 92 385 97Q392 99 409 186V189H449V186Q448 183 436 95T421 3V0H50V19V31Q50 38 56 46T86 81Q115 113 136 137Q145 147 170 174T204 211T233 244T261 278T284 308T305 340T320 369T333 401T340 431T343 464Q343 527 309 573T212 619Q179 619 154 602T119 569T109 550Q109 549 114 549Q132 549 151 535T170 489Q170 464 154 447T109 429Z"></path></defs><g stroke="currentColor" fill="currentColor" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><g data-mml-node="math"><g data-mml-node="msub"><g data-mml-node="mo"><use xlink:href="#MJX-2388-TEX-N-25BD"></use></g><g data-mml-node="mi" transform="translate(889, -150) scale(0.707)"><use xlink:href="#MJX-2388-TEX-I-1D443"></use></g></g><g data-mml-node="msub" transform="translate(1470, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-2388-TEX-I-1D451"></use></g><g data-mml-node="TeXAtom" transform="translate(520, -150) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2388-TEX-I-1D6FD"></use></g></g></g><g data-mml-node="mo" transform="translate(2440.3, 0)"><use 
xlink:href="#MJX-2388-TEX-N-28"></use></g><g data-mml-node="mi" transform="translate(2829.3, 0)"><use xlink:href="#MJX-2388-TEX-I-1D449"></use></g><g data-mml-node="TeXAtom" data-mjx-texclass="ORD" transform="translate(3598.3, 0)"><g data-mml-node="mo"><use xlink:href="#MJX-2388-TEX-N-7C"></use></g></g><g data-mml-node="mi" transform="translate(3876.3, 0)"><use xlink:href="#MJX-2388-TEX-I-1D443"></use></g><g data-mml-node="mo" transform="translate(4627.3, 0)"><use xlink:href="#MJX-2388-TEX-N-29"></use></g><g data-mml-node="mo" transform="translate(5294, 0)"><use xlink:href="#MJX-2388-TEX-N-221D"></use></g><g data-mml-node="msup" transform="translate(6349.8, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-2388-TEX-I-1D443"></use></g><g data-mml-node="TeXAtom" transform="translate(806.5, 413) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2388-TEX-I-1D6FD"></use></g><g data-mml-node="mo" transform="translate(566, 0)"><use xlink:href="#MJX-2388-TEX-N-2212"></use></g><g data-mml-node="mn" transform="translate(1344, 0)"><use xlink:href="#MJX-2388-TEX-N-31"></use></g></g></g><g data-mml-node="mo" transform="translate(8732.4, 0)"><use xlink:href="#MJX-2388-TEX-N-2212"></use></g><g data-mml-node="mi" transform="translate(9732.6, 0)"><use xlink:href="#MJX-2388-TEX-I-1D449"></use></g><g data-mml-node="msup" transform="translate(10501.6, 0)"><g data-mml-node="mi"><use xlink:href="#MJX-2388-TEX-I-1D443"></use></g><g data-mml-node="TeXAtom" transform="translate(806.5, 413) scale(0.707)" data-mjx-texclass="ORD"><g data-mml-node="mi"><use xlink:href="#MJX-2388-TEX-I-1D6FD"></use></g><g data-mml-node="mo" transform="translate(566, 0)"><use xlink:href="#MJX-2388-TEX-N-2212"></use></g><g data-mml-node="mn" transform="translate(1344, 0)"><use xlink:href="#MJX-2388-TEX-N-32"></use></g></g></g></g></g></svg>

<p>It is indeed composed of two non-negative functions.</p>
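<p>As a quick sanity check, this split can be verified numerically against PyTorch's autograd. The following is a minimal sketch; the tensor shapes and the <code class="language-plaintext highlighter-rouge">beta</code> value are arbitrary choices for illustration, not anything prescribed by PyTorch-NMF:</p>

```python
import torch

beta = 1.5
V = torch.rand(4, 5) + 0.1                    # non-negative target
P = (torch.rand(4, 5) + 0.1).requires_grad_()

# Beta-divergence from the formula above (valid for beta not in {0, 1}).
d = ((V.pow(beta) + (beta - 1) * P.pow(beta)
      - beta * V * P.pow(beta - 1)) / (beta * (beta - 1))).sum()
d.backward()

# The gradient w.r.t. P splits into two non-negative parts: pos - neg.
pos = P.detach().pow(beta - 1)
neg = V * P.detach().pow(beta - 2)

assert torch.allclose(P.grad, pos - neg, atol=1e-6)
assert (pos >= 0).all() and (neg >= 0).all()
```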

<h1 id="derive-weights-via-back-propagation">Derive Weights via Back-propagation</h1>

<h2 id="2-backward-pass-algorithm">2 Backward-Pass Algorithm</h2>

<p>By the chain rule, the two non-negative parts of the derivative with respect to a parameter can be written as the two non-negative parts of the derivative with respect to the NMF output, each multiplied by the derivative of the NMF output with respect to that parameter. The latter can be evaluated by PyTorch's automatic differentiation, so we only need to calculate the former. Once we have it, back-propagating through the computational graph twice gives us the multiplicative update weights.</p>

<h3 id="steps">Steps</h3>

<ol>
  <li>Calculate the NMF output <code class="language-plaintext highlighter-rouge">P</code>.</li>
  <li>Given <code class="language-plaintext highlighter-rouge">P</code> and the target <code class="language-plaintext highlighter-rouge">V</code>, derive the two non-negative components (<code class="language-plaintext highlighter-rouge">pos</code> and <code class="language-plaintext highlighter-rouge">neg</code>) of the derivative with respect to <code class="language-plaintext highlighter-rouge">P</code>.</li>
  <li>Derive one non-negative component of the derivative with respect to the parameter to be updated by back-propagation (in PyTorch, <code class="language-plaintext highlighter-rouge">P.backward(pos, retain_graph=True)</code>).</li>
  <li>Derive the remaining non-negative component of the derivative by back-propagation (in PyTorch, <code class="language-plaintext highlighter-rouge">P.backward(neg)</code>).</li>
  <li>Derive the multiplicative update weights by dividing the result of step 4 by the result of step 3.</li>
</ol>
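<p>The steps above can be put together in a minimal runnable sketch for a single plain NMF layer with β = 2 (the Euclidean case). The names <code class="language-plaintext highlighter-rouge">W_pos</code> and <code class="language-plaintext highlighter-rouge">W_neg</code> are my own; this is an illustration of the recipe, not the actual PyTorch-NMF implementation:</p>

```python
import torch

torch.manual_seed(0)
beta = 2.0
V = torch.rand(8, 6) + 0.1                 # non-negative target
W = torch.rand(8, 3, requires_grad=True)
H = torch.rand(3, 6, requires_grad=True)

# Step 1: forward pass of the NMF model.
P = W @ H
loss_before = 0.5 * (V - P).pow(2).sum().item()   # beta = 2 divergence

# Step 2: the two non-negative parts of the gradient w.r.t. P.
pos = P.pow(beta - 1)
neg = V * P.pow(beta - 2)

# Step 3: back-propagate `pos` (keep the graph for the second pass).
P.backward(pos, retain_graph=True)
W_pos, W.grad = W.grad.clone(), None

# Step 4: back-propagate `neg`.
P.backward(neg)
W_neg, W.grad = W.grad.clone(), None

# Step 5: multiplicative update W <- W * (neg part / pos part).
with torch.no_grad():
    W *= W_neg / W_pos

loss_after = 0.5 * (V - W @ H).pow(2).sum().item()
```

<p>For β = 2 this reduces exactly to the classic Lee–Seung update <code class="language-plaintext highlighter-rouge">W *= (V @ H.T) / (W @ H @ H.T)</code>, so the divergence is non-increasing and <code class="language-plaintext highlighter-rouge">W</code> stays non-negative.</p>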

<h1 id="whats-the-benefit-of-this-approach">What’s the Benefit of this Approach?</h1>

<p>Well, because most of the computation of the update weights can now be done by automatic differentiation, we can support the following features more easily, without writing closed-form solutions by hand:</p>

<ul>
  <li><strong>Advanced matrix/tensor operations</strong>: Some NMF variants (like de-convolutional NMF) use convolution instead of plain matrix multiplication to compute the output; in PyTorch, convolution is supported natively and is fully differentiable.</li>
  <li><strong>Deeper NMF structures</strong>: Recently, some research has tried to learn much higher-level features by stacking multiple NMF layers, probably inspired by the rapid progress of deep learning in the last decade. However, due to the non-negativity constraints, deriving a closed-form update rule is non-trivial. With PyTorch-NMF, as long as the gradients are all non-negative along the back-propagation path in the computational graph, we can put an arbitrary number of NMF layers in our model, or even more complex structures of operations, and train them jointly.</li>
</ul>
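<p>To illustrate the second point: the same double-backward recipe updates an inner layer of a stacked (two-layer) NMF model, with only the forward pass changing. Again a hypothetical sketch with β = 2; the layer names and shapes are my own choices:</p>

```python
import torch

torch.manual_seed(1)
beta = 2.0
V  = torch.rand(8, 6) + 0.1
W1 = torch.rand(8, 4, requires_grad=True)   # outer layer
W2 = torch.rand(4, 3, requires_grad=True)   # inner layer (to update)
H  = torch.rand(3, 6)                       # fixed activations

# Forward pass through two stacked NMF layers.
P = W1 @ (W2 @ H)
pos, neg = P.pow(beta - 1), V * P.pow(beta - 2)

P.backward(pos, retain_graph=True)
W2_pos, W2.grad = W2.grad.clone(), None
P.backward(neg)
W2_neg, W2.grad = W2.grad.clone(), None

# Because W1 and H are non-negative, both back-propagated weight tensors
# are non-negative, and the multiplicative update keeps W2 non-negative.
with torch.no_grad():
    W2 *= W2_neg / W2_pos
```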

<h1 id="conclusion">Conclusion</h1>

<p>In this post I showed how PyTorch-NMF applies multiplicative update rules to much more advanced (or deeper) NMF models, and I hope this project can benefit researchers from various fields.</p>

<p>(This project is still in early development; if you are interested in supporting it, please contact me.)</p>

<h1 id="reference">Reference</h1>

<ul>
  <li>Févotte, Cédric, and Jérôme Idier. “Algorithms for Nonnegative Matrix Factorization with the β-Divergence.” Neural Computation 23.9 (2011): 2421–2456.</li>
  <li>PyTorch-NMF, <a href="https://github.com/yoyolicoris/pytorch-NMF">source code</a></li>
  <li>PyTorch-NMF, <a href="https://pytorch-nmf.readthedocs.io/">documentation</a></li>
</ul>]]></content><author><name>Chin-Yun Yu</name><email>chin-yun.yu@qmul.ac.uk</email></author><category term="side project" /><category term="scientific computing" /><category term="pytorch" /><category term="nmf" /><summary type="html"><![CDATA[Recently I updated the implementation of PyTorch-NMF to make it be able to scale on large and complex NMF models. In this blog post I will briefly explain how this was done thanks to the automatic differentiation of PyTorch.]]></summary></entry></feed>