Block-based Fast Differentiable IIR in PyTorch
2025-06-28, by Chin-Yun Yu (https://iamycy.github.io/posts/2025/06/28/unroll-ssm)

I recently came across a presentation by Andres Ezequiel Viso from GPU Audio at ADC 2022, in which he talked about how they accelerate IIR filters on the GPU. The approach they use is to formulate the IIR filter as a state-space model (SSM) and augment the transition matrix so that each step processes multiple samples at once. The primary speedup stems from the fact that GPUs are very good at performing large matrix multiplications, and the SSM formulation enables us to leverage this capability.


Speeding up IIR filters while maintaining differentiability has always been my interest. The most recent method I worked on is from my recent submission to DAFx 25, where my co-author Ben proposed using parallel associative scan to speed up the recursion on the GPU. Nevertheless, since PyTorch does not have a built-in associative scan operator (in contrast to JAX), we must implement custom kernels for it, which is non-trivial. It also requires that the filter has distinct poles so that the state-space transition matrix is diagonalisable. The method that GPU Audio presented appears to be feasible solely using the PyTorch Python API and doesn’t have the restrictions I mentioned; thus, I decided to benchmark it and see how it performs.

Since it’s just a proof of concept, the filter I’m going to test is a time-invariant all-pole IIR filter, which is the minimal case of a recursive filter. This allows us to leverage some special optimisations that won’t work with time-varying general IIR filters, but that won’t affect the main idea I’m going to present here.

Naive implementation of an all-pole IIR filter

The difference equation of an \(M\)-th order all-pole IIR filter is given by:

\[y[n] = x[n] -\sum_{m=1}^{M} a_m y[n-m].\]

Let’s implement this in PyTorch:

import torch
from torch import Tensor

@torch.jit.script
def naive_allpole(x: Tensor, a: Tensor) -> Tensor:
    """
    Naive all-pole filter implementation.
    
    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.
        
    Returns:
        Tensor: Filtered output signal.
    """
    assert x.dim() == 2, "Input signal must be a 2D tensor (batch_size, signal_length)"
    assert a.dim() == 1, "All-pole coefficients must be a 1D tensor"

    # list to store output at each time step
    output = []
    # assume initial condition is zero
    zi = x.new_zeros(x.size(0), a.size(0))

    for xt in x.unbind(1):
        # use addmv for efficient matrix-vector multiplication
        yt = torch.addmv(xt, zi, a, alpha=-1.0)
        output.append(yt)

        # update the state for the next time step
        zi = torch.cat([yt.unsqueeze(1), zi[:, :-1]], dim=1)

    return torch.stack(output, dim=1)

In this implementation, I didn’t use any in-place operations for speedup since it would break the differentiability of the function. This naive implementation is not very efficient, as torch.addmv and torch.cat are called at each time step. Typically, the audio signal is hundreds of thousands of samples long, resulting in a significant amount of function call overhead. For details, please take a look at my tutorial on differentiable IIR filters at ISMIR 2023.

Notice that I used torch.jit.script to compile the function for a slight speedup. I tried the newer compilation feature torch.compile, but it didn't work: the compilation hangs forever, and I don't know why… I have never found torch.compile to be useful in my research projects, and torch.jit.* has proven to be far more reliable.

Let’s benchmark its speed on my Ubuntu machine with an Intel i7-7700K. We’ll use a batch size of 8, a signal length of 16384, and \(M=2\), which is a reasonable setting for audio processing.

from torch.utils.benchmark import Timer

batch_size = 8
signal_length = 16384
order = 2

def order2a(order: int) -> Tensor:
    a = torch.randn(order)
    # simple way to ensure stability
    a = a / a.abs().sum()
    return a

a = order2a(order)
x = torch.randn(batch_size, signal_length)

naive_allpole_t = Timer(
    stmt="naive_allpole(x, a)",
    globals={"naive_allpole": naive_allpole, "x": x, "a": a},
    label="naive_allpole",
    description="Naive All-Pole Filter",
    num_threads=4,
)
naive_allpole_t.blocked_autorange(min_run_time=1.0)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5b4423b260>
naive_allpole
Naive All-Pole Filter
  Median: 168.93 ms
  IQR:    0.54 ms (168.57 to 169.11)
  6 measurements, 1 runs per measurement, 4 threads

168.93 ms is relatively slow, but it is expected.

State-space model formulation

Before we proceed to the sample-unrolling trick, let’s first introduce the state-space model (SSM) formulation of the all-pole IIR filter. The model is similar to the one used in my previous blog post on the TDF-II filter:

\[\begin{align} \mathbf{h}[n] &= \begin{bmatrix} -a_1 & -a_2 & \cdots & -a_{M-1} & -a_M \\ 1 & 0 &\cdots & 0 & 0 \\ 0 & 1 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & 1 & 0 \\ \end{bmatrix} \mathbf{h}[n-1] + \begin{bmatrix} 1 \\ 0 \\ 0 \\ \vdots \\ 0 \\ \end{bmatrix} x[n] \\ &= \mathbf{A} \mathbf{h}[n-1] + \mathbf{B} x[n] \\ y[n] &= \mathbf{B}^\top \mathbf{h}[n]. \end{align}\]

Here, I simplified the original SSM by omitting the direct path, as it can be derived from the state vector (for the all-pole filter only). Below is the PyTorch implementation of it:

@torch.jit.script
def a2companion(a: Tensor) -> Tensor:
    """
    Convert all-pole coefficients to a companion matrix.

    Args:
        a (Tensor): All-pole coefficients.

    Returns:
        Tensor: Companion matrix.
    """
    assert a.dim() == 1, "All-pole coefficients must be a 1D tensor"
    order = a.size(0)
    c = torch.diag(a.new_ones(order - 1), -1)
    c[0, :] = -a
    return c

@torch.jit.script
def state_space_allpole(x: Tensor, a: Tensor) -> Tensor:
    """
    State-space implementation of all-pole filtering.

    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.

    Returns:
        Tensor: Filtered output signal.
    """
    assert x.dim() == 2, "Input signal must be a 2D tensor (batch_size, signal_length)"
    assert a.dim() == 1, "All-pole coefficients must be a 1D tensor"

    c = a2companion(a).T

    output = []
    # assume initial condition is zero
    h = x.new_zeros(x.size(0), c.size(0))

    # B * x
    x = torch.cat(
        [x.unsqueeze(-1), x.new_zeros(x.size(0), x.size(1), c.size(0) - 1)], dim=2
    )

    for xt in x.unbind(1):
        h = torch.addmm(xt, h, c)
        # B^T @ h
        output.append(h[:, 0])
    return torch.stack(output, dim=1)

a2companion converts the all-pole coefficients to a companion matrix, which is \(\mathbf{A}\) in the SSM formulation.
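As a quick sanity check, the sketch below (with arbitrary hypothetical coefficients \(a_1 = 0.4\), \(a_2 = -0.3\)) verifies that the companion-matrix recursion reproduces the direct difference equation:

```python
import torch

# Sanity check (a sketch): the SSM recursion h[n] = A h[n-1] + B x[n] with
# y[n] = h_1[n] should reproduce y[n] = x[n] - a1 y[n-1] - a2 y[n-2].
torch.manual_seed(0)
a = torch.tensor([0.4, -0.3])  # hypothetical [a1, a2]
x = torch.randn(32)

# companion matrix (same layout as a2companion): ones on the subdiagonal,
# -a on the first row
A = torch.diag(torch.ones(1), -1)
A[0] = -a
B = torch.tensor([1.0, 0.0])

# SSM recursion
h = torch.zeros(2)
y_ssm = []
for xt in x:
    h = A @ h + B * xt
    y_ssm.append(h[0])
y_ssm = torch.stack(y_ssm)

# direct difference-equation recursion
y_dir = torch.zeros_like(x)
for n in range(len(x)):
    y_dir[n] = x[n] - a[0] * (y_dir[n - 1] if n >= 1 else 0.0) \
                    - a[1] * (y_dir[n - 2] if n >= 2 else 0.0)

assert torch.allclose(y_ssm, y_dir, atol=1e-5)
```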

Before we benchmark this implementation, let’s predict how fast it will be. Intuitively, since the complexity of a vector dot product is \(O(M)\) while that of a matrix-vector multiplication is \(O(M^2)\), the SSM implementation does more work per step, so it should be slower than the naive implementation. Let’s benchmark its speed:

state_space_allpole_t = Timer(
    stmt="state_space_allpole(x, a)",
    globals={"state_space_allpole": state_space_allpole, "x": x, "a": a},
    label="state_space_allpole",
    description="State-Space All-Pole Filter",
    num_threads=4,
)
state_space_allpole_t.blocked_autorange(min_run_time=1.0)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5a02eaf4a0>
state_space_allpole
State-Space All-Pole Filter
  Median: 118.41 ms
  IQR:    1.17 ms (117.79 to 118.96)
  9 measurements, 1 runs per measurement, 4 threads

Interestingly, the SSM implementation is approximately 50 ms faster.

By using torch.profiler.profile, I found that, in the naive implementation, torch.cat for updating the last M outputs accounts for a significant portion of the total time (~20%). The actual computation, torch.addmv, takes only about 10% of the time. Regarding memory usage, the most memory-intensive operation is torch.addmv, which consumes approximately 512 KB of memory. In contrast, the SSM implementation uses more memory (> 1 MB) due to matrix multiplication, but roughly 38% of the time is spent on filtering since it no longer has to call torch.cat at each time step. The state vector (a.k.a. the last M outputs) is automatically updated during the matrix multiplication.

Conclusion: Tensor concatenation (including torch.cat and torch.stack) is computationally expensive, and it is advisable to avoid it whenever possible.

Unrolling the SSM

Now we can apply the unrolling trick to the SSM implementation. The idea is to divide the input signal into blocks of size \(T\) and perform the recursion on these blocks instead of processing them sample-by-sample. Each recursion takes the last vector state \(\mathbf{h}[n-1]\) and predicts the next \(T\) states \([\mathbf{h}[n], \mathbf{h}[n+1], \ldots, \mathbf{h}[n+T-1]]^\top\) at once. To see how to calculate these states, let’s unroll the SSM recursion for \(T\) steps:

\[\begin{align} \mathbf{h}[n] &= \mathbf{A} \mathbf{h}[n-1] + \mathbf{B} x[n] \\ \mathbf{h}[n+1] &= \mathbf{A} \mathbf{h}[n] + \mathbf{B} x[n+1] \\ &= \mathbf{A} (\mathbf{A} \mathbf{h}[n-1] + \mathbf{B} x[n]) + \mathbf{B} x[n+1] \\ &= \mathbf{A}^2 \mathbf{h}[n-1] + \mathbf{A} \mathbf{B} x[n] + \mathbf{B} x[n+1] \\ \mathbf{h}[n+2] &= \mathbf{A} \mathbf{h}[n+1] + \mathbf{B} x[n+2] \\ &= \mathbf{A} (\mathbf{A}^2 \mathbf{h}[n-1] + \mathbf{A} \mathbf{B} x[n] + \mathbf{B} x[n+1]) + \mathbf{B} x[n+2] \\ &= \mathbf{A}^3 \mathbf{h}[n-1] + \mathbf{A}^2 \mathbf{B} x[n] + \mathbf{A} \mathbf{B} x[n+1] + \mathbf{B} x[n+2] \\ & \vdots \\ \mathbf{h}[n+T-1] &= \mathbf{A}^{T} \mathbf{h}[n-1] + \sum_{t=0}^{T-1} \mathbf{A}^{T - t -1} \mathbf{B} x[n+t] \\ \end{align}\]

We can rewrite the above equation in matrix form as follows:

\[\begin{align} \begin{bmatrix} \mathbf{h}[n] \\ \mathbf{h}[n+1] \\ \vdots \\ \mathbf{h}[n+T-1] \end{bmatrix} &= \begin{bmatrix} \mathbf{A} \\ \mathbf{A}^2 \\ \vdots \\ \mathbf{A}^T \\ \end{bmatrix} \mathbf{h}[n-1] + \begin{bmatrix} \mathbf{I} & 0 & \cdots & 0 \\ \mathbf{A} & \mathbf{I} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{A}^{T-1} & \mathbf{A}^{T-2} & \cdots & \mathbf{I} \end{bmatrix} \begin{bmatrix} \mathbf{B}x[n] \\ \mathbf{B}x[n+1] \\ \vdots \\ \mathbf{B}x[n+T-1] \end{bmatrix} \\ & = \begin{bmatrix} \mathbf{A} \\ \mathbf{A}^2 \\ \vdots \\ \mathbf{A}^T \\ \end{bmatrix} \mathbf{h}[n-1] + \begin{bmatrix} \mathbf{I}_{.1} & 0 & \cdots & 0 \\ \mathbf{A}_{.1} & \mathbf{I}_{.1} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{A}_{.1}^{T-1} & \mathbf{A}_{.1}^{T-2} & \cdots & \mathbf{I}_{.1} \end{bmatrix} \begin{bmatrix} x[n] \\ x[n+1] \\ \vdots \\ x[n+T-1] \end{bmatrix} \\ & = \mathbf{M} \mathbf{h}[n-1] + \mathbf{V} \begin{bmatrix} x[n] \\ x[n+1] \\ \vdots \\ x[n+T-1] \end{bmatrix} \\ \end{align}\]

Notice that in the second line, I utilised the fact that \(\mathbf{B}\) has only one non-zero entry to simplify the matrix. (This is not possible if the filter is not strictly all-pole.) \(\mathbf{I}_{.1}\) denotes the first column of the identity matrix and so on.

Now, the number of autoregressive steps is reduced from \(N\) to \(\frac{N}{T}\), and the matrix multiplication is done in parallel for every \(T\) samples. There are added costs for pre-computing the transition matrix \(\mathbf{M}\) and the input matrix \(\mathbf{V}\). However, as long as this extra cost is small compared to the cost of the \(N - \frac{N}{T}\) autoregressive steps we save, we should observe a speedup.
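To make the block update concrete, here is a tiny numerical check (a sketch with \(M=2\), \(T=3\), and arbitrary coefficients) that the stacked states indeed equal \(\mathbf{M}\mathbf{h}[n-1] + \mathbf{V}\mathbf{x}\):

```python
import torch

# A tiny check of the unrolled update (a sketch, M=2, T=3): stacking
# [h[n]; h[n+1]; h[n+2]] should equal M h[n-1] + V [x[n]; x[n+1]; x[n+2]].
torch.manual_seed(0)
A = torch.tensor([[-0.4, 0.3], [1.0, 0.0]])  # companion matrix for a1=0.4, a2=-0.3
B = torch.tensor([1.0, 0.0])
T, order = 3, 2
x = torch.randn(T)
h_init = torch.randn(order)

# sample-by-sample reference
h, states = h_init, []
for xt in x:
    h = A @ h + B * xt
    states.append(h)
ref = torch.cat(states)  # shape (T * order,)

# block form: M stacks A, A^2, ..., A^T; block (i, j) of V is A^(i-j) B
powers = [torch.linalg.matrix_power(A, k) for k in range(T + 1)]
M = torch.cat(powers[1:], dim=0)  # (T * order, order)
V = torch.zeros(T * order, T)
for i in range(T):          # block row (state index within the block)
    for j in range(i + 1):  # column (input index within the block)
        V[i * order:(i + 1) * order, j] = powers[i - j] @ B

assert torch.allclose(M @ h_init + V @ x, ref, atol=1e-5)
```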

Here’s the PyTorch implementation of the unrolled SSM:

@torch.jit.script
def state_space_allpole_unrolled(
    x: Tensor, a: Tensor, unroll_factor: int = 1
) -> Tensor:
    """
    Unrolled state-space implementation of all-pole filtering.

    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.
        unroll_factor (int): Factor by which to unroll the loop.

    Returns:
        Tensor: Filtered output signal.
    """
    if unroll_factor == 1:
        return state_space_allpole(x, a)
    elif unroll_factor < 1:
        raise ValueError("Unroll factor must be >= 1")

    assert x.dim() == 2, "Input signal must be a 2D tensor (batch_size, signal_length)"
    assert a.dim() == 1, "All-pole coefficients must be a 1D tensor"
    assert (
        x.size(1) % unroll_factor == 0
    ), "Signal length must be divisible by unroll factor"

    c = a2companion(a)

    # create an initial identity matrix
    initial = torch.eye(c.size(0), device=c.device, dtype=c.dtype)
    c_list = [initial]
    # TODO: use parallel scan to improve speed
    for _ in range(unroll_factor):
        c_list.append(c_list[-1] @ c)

    # c_list = [I c c^2 ... c^unroll_factor]
    M = torch.cat(c_list[1:], dim=0).T
    flatten_c_list = torch.cat(
        [c.new_zeros(c.size(0) * (unroll_factor - 1))]
        + [xx[:, 0] for xx in c_list[:-1]],
        dim=0,
    )
    V = flatten_c_list.unfold(0, c.size(0) * unroll_factor, c.size(0)).flip(0)

    # divide the input signal into blocks of size unroll_factor
    unrolled_x = x.unflatten(1, (-1, unroll_factor)) @ V

    output = []
    # assume initial condition is zero
    h = x.new_zeros(x.size(0), c.size(0))
    for xt in unrolled_x.unbind(1):
        h = torch.addmm(xt, h, M)
        # B^T @ h
        output.append(h[:, :: c.size(0)])
        h = h[
            :, -c.size(0) :
        ]  # take the last state vector as the initial condition for the next step
    return torch.cat(output, dim=1)

The unroll_factor parameter controls the number of samples to process in parallel. If it is set to 1, the function is the original SSM implementation.

Now let’s benchmark the speed of the unrolled SSM implementation. We’ll use unroll_factor=128 since I already tested that it is the optimal value :)

state_space_allpole_unrolled_t = Timer(
    stmt="state_space_allpole_unrolled(x, a, unroll_factor=unroll_factor)",
    globals={
        "state_space_allpole_unrolled": state_space_allpole_unrolled,
        "x": x,
        "a": a,
        "unroll_factor": 128,
    },
    label="state_space_allpole_unrolled",
    description="State-Space All-Pole Filter Unrolled",
    num_threads=4,
)
state_space_allpole_unrolled_t.blocked_autorange(min_run_time=1.0)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f5a01d75160>
state_space_allpole_unrolled
State-Space All-Pole Filter Unrolled
  Median: 1.89 ms
  IQR:    0.08 ms (1.88 to 1.96)
  6 measurements, 100 runs per measurement, 4 threads

1.89 ms! What sorcery is this? That’s a whopping 60x speedup compared to the standard SSM implementation!

A closer look at the profiling results shows that, in total, 38% of the time is spent on matrix multiplication and addition. The speedup comes at the cost of increased memory usage, requiring more than 2 MB for filtering; not a significant cost for modern hardware.

For convenience, I ran the above benchmarks using the CPU, which has very limited parallelism compared to the GPU. Thus, the significant speedup we observe indicates that function call overhead is the major bottleneck for running recursions.

More comparison

Since \(T\) is an essential parameter for the unrolled SSM, I did some benchmarks to see how it affects the speed.

Varying sequence length

In this benchmark, I fixed the batch size to 8 and the order to 2, and varied the sequence length from 4096 to 262144. The results suggest that the best unroll factor increases as the sequence length increases, and it’s very likely to be \(\sqrt{N}\). Additionally, the longer the sequence length, the greater the speedup we achieve from the unrolled SSM.

Varying filter order

To examine the impact of filter order on speed, I set the batch size to 8 and the sequence length to 16384, and then varied the filter order from 2 to 16. It appears that my hypothesis that the best factor is \(\sqrt{N}\) still holds, but the peak gradually shifts to the left as the order increases. Moreover, the speedup is less significant for higher orders, which is expected as the \(\mathbf{V}\) matrix becomes larger.

Varying batch size

The speedup is less as the batch size increases, which is expected. However, the peak of the best unroll factor also shifts slightly to the left as the batch size increases.

Memory usage

To observe how memory usage changes in a differentiable training context, I ran the unrolled SSM on a 5060 Ti, which allows me to use torch.cuda.max_memory_allocated() to measure memory usage. When the batch size is 1, the memory usage grows quadratically with the unroll factor, as expected, due to the creation of the \(\mathbf{V}\) matrix.

When using a larger batch size (32 in this case), this cost becomes less significant compared to the more memory used for the input signal.

Discussion

So far, we have seen that the unrolled SSM can achieve a significant speedup for IIR filtering in PyTorch. However, how to determine the best unrolling factor automatically is still unclear. From the benchmarks I did on an i7 CPU, it seems that the optimal factor is \(T^* = \sqrt{N}\alpha\), where \(0 < \alpha \leq 1\) is a function of the filter order and batch size. Since I observe similar behaviour on the GPU, this hypothesis likely holds for other hardware as well.

One thing I didn’t mention is numerical accuracy. If the magnitude of \(\mathbf{A}\)'s eigenvalues is very small, the precomputed powers \(\mathbf{A}^T \to \mathbf{0}\), which may not be accurately represented in floating point, especially since deep learning applications often use single precision. This is less of a problem for the standard SSM, since at each time step the input is mixed into the state vector, which could help cancel out the numerical errors.
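To illustrate, the short sketch below (with a hypothetical, strongly damped transition matrix) shows high matrix powers underflowing to exactly zero in single precision while remaining representable in double precision:

```python
import torch

# A sketch of the precision concern: A is a hypothetical, strongly damped
# companion-style matrix (pole at 0.05). Its 128th power underflows to
# exactly zero in float32, but is still a tiny nonzero value in float64.
A = torch.tensor([[0.05, 0.0], [1.0, 0.0]])
P32 = torch.linalg.matrix_power(A.float(), 128)
P64 = torch.linalg.matrix_power(A.double(), 128)

# compare the largest magnitude in each precision
print(P32.abs().max().item(), P64.abs().max().item())
```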

The idea should still apply when there are zeros in the filter: \(\mathbf{B}\) will no longer be a simple one-hot vector, so \(\mathbf{V}\) has to be a full \(MT\times MT\) square matrix. Time-varying filters will benefit less from the unrolling trick, since \(\mathbf{V}\) will also be time-varying, and computing \(\frac{N}{T}\) such matrices in advance increases the cost.

Conclusion & Thoughts

In this post, I demonstrate that the unrolling trick can significantly accelerate differentiable IIR filtering in PyTorch. The extra memory cost is less of a problem for large batch sizes. Although the filter I tested is a simple all-pole filter, it’s trivial to extend the idea to general IIR filters.

This method might help address one of the issues facing TorchAudio's future, after the Meta developers announced their plans for it. In the next major release, all the specialised kernels written in C++, including the lfilter I contributed years ago, will be removed from TorchAudio. The filter I presented here is written entirely in Python and can serve as a straightforward drop-in replacement for the current compiled lfilter implementation.

Notes

The complete code is available in the Jupyter notebook version of this post on Gist.

Update (29.6.2025)

I realised that the state_space_allpole_unrolled function I made is very close to a two-level parallel scan, and with some modifications, we can squeeze a bit more performance out of it. Instead of computing all the \(T\) states at once per block, we can just compute the last state, which is the only one we need for the next block. Thus, the matrix size for the multiplication is reduced from \(\mathbf{M} \in \mathbb{R}^{MT\times M}\) to \(\mathbf{A}^T \in \mathbb{R}^{M\times M}\). The first \(M-1\) states for all the blocks can be computed later in parallel. The algorithm (parallel scan) is as follows:

Firstly, compute the input to the last state in the block:

\[\mathbf{z}[n+T-1] = \begin{bmatrix} \mathbf{A}_{.1}^{T-1} & \mathbf{A}_{.1}^{T-2} & \cdots & \mathbf{I}_{.1} \end{bmatrix} \begin{bmatrix} x[n] \\ x[n+1] \\ \vdots \\ x[n+T-1] \end{bmatrix}.\]

Then, compute the last state in each block recursively as follows:

\[\mathbf{h}[n+T-1] = \mathbf{A}^{T} \mathbf{h}[n-1] + \mathbf{z}[n+T-1].\]

Lastly, compute the remaining states in parallel:

\[\begin{bmatrix} \mathbf{h}[n] \\ \mathbf{h}[n+1] \\ \vdots \\ \mathbf{h}[n+T-2] \end{bmatrix} = \begin{bmatrix} \mathbf{A} & \mathbf{I}_{.1} & 0 & \cdots & 0 \\ \mathbf{A}^2 & \mathbf{A}_{.1} & \mathbf{I}_{.1} & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \mathbf{A}^{T-1} & \mathbf{A}_{.1}^{T-2} & \mathbf{A}_{.1}^{T-3} & \cdots & \mathbf{I}_{.1} \end{bmatrix} \begin{bmatrix} \mathbf{h}[n-1] \\ x[n] \\ x[n+1] \\ \vdots \\ x[n+T-2] \end{bmatrix}.\]

The following code implements this algorithm, modified from the previous state_space_allpole_unrolled function.

@torch.jit.script
def state_space_allpole_unrolled_v2(
    x: Tensor, a: Tensor, unroll_factor: int = 1
) -> Tensor:
    """
    Unrolled state-space implementation of all-pole filtering.

    Args:
        x (Tensor): Input signal.
        a (Tensor): All-pole coefficients.
        unroll_factor (int): Factor by which to unroll the loop.

    Returns:
        Tensor: Filtered output signal.
    """
    if unroll_factor == 1:
        return state_space_allpole(x, a)
    elif unroll_factor < 1:
        raise ValueError("Unroll factor must be >= 1")

    assert x.dim() == 2, "Input signal must be a 2D tensor (batch_size, signal_length)"
    assert a.dim() == 1, "All-pole coefficients must be a 1D tensor"
    assert (
        x.size(1) % unroll_factor == 0
    ), "Signal length must be divisible by unroll factor"

    c = a2companion(a)

    # create an initial identity matrix
    I = torch.eye(c.size(0), device=c.device, dtype=c.dtype)
    c_list = [I]
    # TODO: use parallel scan to improve speed
    for _ in range(unroll_factor):
        c_list.append(c_list[-1] @ c)

    # c_list = [I c c^2 ... c^unroll_factor]
    flatten_c_list = torch.cat(
        [c.new_zeros(c.size(0) * (unroll_factor - 1))]
        + [xx[:, 0] for xx in c_list[:-1]],
        dim=0,
    )
    V = flatten_c_list.unfold(0, c.size(0) * unroll_factor, c.size(0)).flip(0)

    # divide the input signal into blocks of size unroll_factor
    unrolled_x = x.unflatten(1, (-1, unroll_factor))

    # get the last row of Vx
    last_x = unrolled_x @ V[:, -c.size(0) :]

    # initial condition
    zi = x.new_zeros(x.size(0), c.size(0))

    # transition matrix on the block level
    AT = c_list[-1].T
    block_output = []
    h = zi
    # block level recursion
    for xt in last_x.unbind(1):
        h = torch.addmm(xt, h, AT)
        block_output.append(h)

    # stack the accumulated last outputs of the blocks as initial conditions for the intermediate steps
    initials = torch.stack([zi] + block_output, dim=1)

    # prepare the augmented matrix and input for all the remaining steps
    aug_x = torch.cat([initials[:, :-1], unrolled_x[..., :-1]], dim=2)
    aug_A = torch.cat(
        [
            torch.stack([c[0] for c in c_list[1:-1]], dim=1),
            V[:-1, : -c.size(0) : c.size(0)],
        ],
        dim=0,
    )
    output = aug_x @ aug_A

    # concat the first M - 1 outputs with the last one
    output = torch.cat([output, initials[:, 1:, :1]], dim=2)
    return output.flatten(1, 2)

Let’s benchmark it!

<torch.utils.benchmark.utils.common.Measurement object at 0x78d297b8b290>
state_space_allpole_unrolled_v2
State-Space All-Pole Filter Unrolled
  Median: 1.40 ms
  IQR:    0.01 ms (1.40 to 1.41)
  7 measurements, 100 runs per measurement, 4 threads

1.40 ms! That’s approximately 1.35 times faster than the previous version. It might be worth redoing the benchmarks, but I’m too lazy to do it now :D The results should be similar to the previous ones. I’ll upload the benchmark results to Gist soon.

Notes on Differentiable TDF-II Filter
2025-04-26, by Chin-Yun Yu (https://iamycy.github.io/posts/2025/04/differentiable-tdf-ii)

This post is a continuation of some of my early calculations for propagating gradients through general IIR filters, including direct-form and transposed-direct-form.

Back story

In early 2021, I implemented a differentiable lfilter function for torchaudio (a few core details were published two years later here). The basic idea is to implement the backpropagation of gradients in C++ for optimal performance. The implementation was based on Direct-Form-I (DF-I). This differs from the popular implementation of SciPy’s lfilter, which is based on Transposed-Direct-Form-II (TDF-II) and is more numerically stable1.

Implementing it in this form would be better, but… at the time, my knowledge base was insufficient to generalise the idea to TDF-II. In DF-I/II, the gradients of FIR and all-pole filters can be treated independently, so I worked only on the recursive part of the filter (the all-pole).

(Figures: block diagrams of TDF-II and DF-I.)

However, in TDF-II, the two parts are combined and the registers are shared, so my previous approach does not work. I left this as a TODO for the future2.

(Figure: block diagram of DF-I.)

Many things have changed since then. I started my PhD in 2022 and have more time to think thoroughly about the problem. My understanding of filters has improved after exploring the idea in several publications. It’s time to revisit the problem: a differentiable TDF-II filter.

TL;DR: the backpropagation of a TDF-II filter is a DF-II filter, and vice versa.

The following calculation considers the general case when the filter parameters are time-varying. Time-invariant systems are a special case and are trivial once we have the time-varying results.

(Transposed-)Direct-Form-II

Given time-varying coefficients \(\{b_0[n], b_1[n],\dots,b_M[n]\}\) and \(\{a_1[n],\dots,a_M[n]\}\), the TDF-II filter can be expressed as:

\[y[n] = s_1[n] + b_0[n] x[n]\] \[s_1[n+1] = s_2[n] + b_1[n] x[n] - a_1[n] y[n]\] \[s_2[n+1] = s_3[n] + b_2[n] x[n] - a_2[n] y[n]\] \[\vdots\] \[s_M[n+1] = b_M[n] x[n] - a_M[n] y[n].\]

We can also write it in observable canonical form:

\[\mathbf{s}[n+1] = \mathbf{A}[n] \mathbf{s}[n] + \mathbf{B}[n] x[n]\] \[y[n] = \mathbf{C}\mathbf{s}[n] + b_0[n] x[n]\] \[\mathbf{A}[n] = \begin{bmatrix} -a_1[n] & 1 & 0 & \cdots & 0 \\ -a_2[n] & 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ -a_{M-1}[n] & 0 & 0 & \cdots & 1 \\ -a_M[n] & 0 & 0 & \cdots & 0 \end{bmatrix}\] \[\mathbf{C} = \begin{bmatrix} 1 & 0 & 0 & \cdots & 0 \\ \end{bmatrix}.\]

The values of \(\mathbf{B}[n] \) can be found in Julius’ blog3.

Regarding DF-II, its difference equations are:

\[v[n] = x[n] - \sum_{i=1}^{M} a_i[n] v[n-i]\] \[y[n] = \sum_{i=0}^{M} b_i[n] v[n-i].\]

Similarly, it can be expressed as a state-space model using the controller canonical form:

\[\mathbf{v}[n+1] = \begin{bmatrix} -a_1[n] & -a_2[n] & \cdots & -a_{M-1}[n] & -a_M[n] \\ 1 & 0 & \cdots & 0 & 0 \\ 0 & 1 & \cdots & 0 & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & \cdots & 1 & 0 \end{bmatrix} \mathbf{v}[n] + \begin{bmatrix} 1 \\ 0 \\ \vdots \\ 0 \end{bmatrix} x[n] = \mathbf{A}^\top[n]\mathbf{v}[n] + \mathbf{C}^\top x[n]\] \[y[n] = \mathbf{B}^\top[n] \mathbf{v}[n] + b_0[n] x[n].\]

As I have shown above, the forms are very similar. The transition matrix of TDF-II is the transpose of the DF-II, and the vectors B and C are swapped. (This is the reason why we call it transposed-DF-II.) Note that the resulting transfer function is not the same due to the difference in computation order in the time-varying case. (They are the same if the coefficients are time-invariant!) I will use the state-space form for simplicity in the following sections.
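As a quick check of the time-invariant claim, the sketch below (with arbitrary stable coefficients chosen for illustration) runs both recursions and confirms their outputs coincide:

```python
import torch

# A small numerical check (a sketch): with time-invariant coefficients, the
# TDF-II and DF-II recursions realise the same transfer function, so their
# outputs coincide. The b and a values are arbitrary stable choices.
torch.manual_seed(0)
b = torch.tensor([0.5, -0.2, 0.1])  # b0, b1, b2
a = torch.tensor([0.4, -0.3])       # a1, a2
x = torch.randn(64)

# TDF-II: shared registers s1, s2
s1 = s2 = torch.tensor(0.0)
y_tdf2 = []
for xn in x:
    yn = s1 + b[0] * xn
    s1, s2 = s2 + b[1] * xn - a[0] * yn, b[2] * xn - a[1] * yn
    y_tdf2.append(yn)

# DF-II: intermediate signal v[n] with delayed copies v1, v2
v1 = v2 = torch.tensor(0.0)
y_df2 = []
for xn in x:
    vn = xn - a[0] * v1 - a[1] * v2
    y_df2.append(b[0] * vn + b[1] * v1 + b[2] * v2)
    v1, v2 = vn, v1

assert torch.allclose(torch.stack(y_tdf2), torch.stack(y_df2), atol=1e-4)
```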

Backpropagation through TDF-II

Suppose we have evaluated some loss function \(\mathcal{L}\) on the output of the filter \(y[n]\) and have the instantaneous gradients \(\frac{\partial \mathcal{L}}{\partial \mathbf{s}[n]}\). We want to backpropagate the gradients through the filter to get the gradients of the input \(\frac{\partial \mathcal{L}}{\partial x[n]}\) and the filter coefficients \(\frac{\partial \mathcal{L}}{\partial a_i[n]}\) and \(\frac{\partial \mathcal{L}}{\partial b_i[n]}\). Let’s first denote \(\mathbf{z}[n] = \mathbf{B}[n] x[n]\); once we have the gradients of \(\mathbf{z}[n]\), it’s easy to get the gradients of the other two using the chain rule. The recursion in TDF-II state-space form becomes:

\[\mathbf{s}[n+1] = \mathbf{A}[n] \mathbf{s}[n] + \mathbf{z}[n].\]

If we unroll the recursion so that no \(\mathbf{s}\) appears on the right-hand side, we get:

\[\mathbf{s}[n+1] = \sum_{i=1}^{\infty} \left(\prod_{j=1}^{i} \mathbf{A}[n-j+1]\right) \mathbf{z}[n-i] + \mathbf{z}[n].\]

The gradients for z can be computed as:

\[\frac{\partial \mathbf{s}[n]}{\partial \mathbf{z}[i]} = \begin{cases} \prod_{j=1}^{n-i-1} \mathbf{A}[n-j] & i < n - 1 \\ \mathbf{I} & i = n -1 \\ 0 & i \geq n \end{cases}\] \[\frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} = \sum_{i=n+1}^{\infty} \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]} \frac{\partial \mathbf{s}[i]}{\partial \mathbf{z}[n]}\] \[= \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]} + \sum_{i=n+2}^{\infty} \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]} \prod_{j=1}^{i-n-1} \mathbf{A}[i-j]\] \[= \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]} + \sum_{i=n+2}^{\infty} \left( \prod_{j=i-n-1}^{1} \mathbf{A}^\top[i-j] \right) \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]}\] \[= \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]} + \sum_{i=n+2}^{\infty} \left( \prod_{j=1}^{i-n-1} \mathbf{A}^\top[n+j] \right) \frac{\partial \mathcal{L}}{\partial \mathbf{s}[i]}.\] \[= \mathbf{A}^\top[n+1] \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n+1]} + \frac{\partial \mathcal{L}}{\partial \mathbf{s}[n+1]}.\]

For simplicity, I omitted the transpose sign for the vector. The last recursion involves \(\mathbf{A}^\top\), which implies that, to backpropagate the gradients through the recursion of TDF-II, we need to use the recursion of DF-II but in the opposite direction! Their roles will be swapped if we compute the gradients of DF-II using the same procedure, but I’ll leave it as an exercise for the reader :D

For completeness, the following are the procedures to compute the gradients of the input and filter coefficients.

Gradients of the input

\[\frac{\partial \mathcal{L}}{\partial \mathbf{s}[n]} = \mathbf{C}^\top \frac{\partial \mathcal{L}}{\partial y[n]}\] \[\frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} = \mathbf{A}^\top[n+1] \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n+1]} + \mathbf{C}^\top \frac{\partial \mathcal{L}}{\partial y[n+1]}\]

(Note that the above line is the same as in DF-II! Just the input and output variables are changed.)

\[\frac{\partial \mathcal{L}}{\partial x[n]} = \mathbf{B}^\top[n] \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} + b_0[n] \frac{\partial \mathcal{L}}{\partial y[n]}\]

Gradients of the b coefficients

\[\frac{\partial \mathcal{L}}{\partial \mathbf{B}[n]} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} x[n]\] \[\frac{\partial \mathcal{L}}{\partial b_0[n]} = \frac{\partial \mathcal{L}}{\partial y[n]} x[n]\]

Gradients of the a coefficients

\[\frac{\partial \mathcal{L}}{\partial \mathbf{A}[n]} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]} \mathbf{s}^\top[n] \to \frac{\partial \mathcal{L}}{\partial a_i[n]} = -\frac{\partial \mathcal{L}}{\partial z_i[n]} s_1[n]\]

Time-invariant case

In the time-invariant case, the parameters are constant.

\[a_i[n] = a_i[m] \quad \forall n, m, \quad i = 1, \dots, M\] \[b_i[n] = b_i[m] \quad \forall n, m, \quad i = 0, \dots, M\]

In this case, we can just sum the gradients over time:

\[\frac{\partial \mathcal{L}}{\partial a_i} = \sum_{n} \frac{\partial \mathcal{L}}{\partial a_i[n]}, \quad \frac{\partial \mathcal{L}}{\partial b_i} = \sum_{n} \frac{\partial \mathcal{L}}{\partial b_i[n]}.\]

Summary

The above findings suggest a way to compute the TDF-II filter’s gradients efficiently. To do this, the following steps are needed:

  1. Implement the recursions of TDF-II and DF-II filters in C++/CUDA/Metal/etc.
  2. After doing the forward pass of TDF-II, store \(s_1[n]\), \(\mathbf{a}[n]\), \(\mathbf{b}[n]\), and \(x[n]\).
  3. When doing backpropagation, filter the output gradients \(\frac{\partial \mathcal{L}}{\partial y[n]}\) through the DF-II filter’s recursions in the opposite direction using the same a coefficients.
  4. Compute the gradients of the input and filter coefficients using the equations above. Note that although \(\frac{\partial \mathcal{L}}{\partial \mathbf{z}[n]}\) is a sequence of vectors, since the higher-order states in DF-II are just time-delayed versions of the first state (\(v_M[n] = v_{M-1}[n-1] = \cdots = v_1[n-M+1]\)), we can just store \(\frac{\partial \mathcal{L}}{\partial z_1[n]}\) for gradient computation, reducing the memory usage by a factor of \(M\).
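As a proof of concept, the recipe can be sketched for the simpler time-invariant all-pole case with a pure-Python `torch.autograd.Function`. This is my own minimal illustration, not the compiled-kernel implementation the steps above describe: the backward pass runs the same recursion over the time-reversed output gradients, and the coefficient gradients are correlations between the input gradients and the stored outputs.

```python
import torch

def _allpole(x, a):
    # Direct recursion y[n] = x[n] - sum_m a[m] * y[n - m], zero initial state.
    M = a.numel()
    y = torch.zeros_like(x)
    for n in range(x.numel()):
        acc = x[n]
        for m in range(1, min(M, n) + 1):
            acc = acc - a[m - 1] * y[n - m]
        y[n] = acc
    return y

class AllPole(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a):
        y = _allpole(x, a)
        ctx.save_for_backward(a, y)  # store the states needed for backprop
        return y

    @staticmethod
    def backward(ctx, grad_y):
        a, y = ctx.saved_tensors
        # dL/dx: the same all-pole recursion, run backwards in time.
        grad_x = _allpole(grad_y.flip(0), a).flip(0)
        # dL/da_m = -sum_n dL/dx[n] * y[n - m], using the stored outputs.
        M, N = a.numel(), y.numel()
        grad_a = torch.stack(
            [-(grad_x[m:] * y[: N - m]).sum() for m in range(1, M + 1)]
        )
        return grad_x, grad_a
```

`torch.autograd.gradcheck` agrees with this backward on small double-precision inputs; a practical implementation would replace the per-sample Python loops with a compiled kernel, as in step 1.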

Final thoughts

The procedure above can be applied to derive the gradients of the DF-II filter as well. The resulting algorithm is identical, but the roles of TDF-II and DF-II are swapped. Personally, I found the state-space formulation much easier, more straightforward, and more elegant than the derivation I did in 2024 to calculate the gradients of time-varying all-pole filters, which is basically the same problem. (Man, I was basically brute-forcing it…) Applying the method to TDF-I is straightforward: just set \(\mathbf{B}[n] = 0\).

Interestingly, since the backpropagation of TDF-II is a DF-II filter, it’s less numerically stable than TDF-II; in contrast, the backpropagation of DF-II is a TDF-II filter and is more stable. We’ll always have this trade-off, so is TDF-II necessary if we want differentiability? Probably yes, since besides backpropagation, the gradients can also be computed using forward-mode automatic differentiation, which computes the Jacobian in the opposite direction. In this way, the forwarded gradients are computed in the same way as the filter’s forward pass, and the math is much easier to show than the backpropagation I wrote above. (I should have realised this earlier…) Also, in the time-varying case with \(M > 1\), neither of the two forms guarantees BIBO stability. This is another interesting topic, but let’s just leave it for now. I hope this post is helpful for those who are interested in differentiable IIR filters.

Notes

The figures are from Julius O. Smith III, and the notations are adapted from his blog [3]. The algorithm is based on the following papers:

  1. Singing Voice Synthesis Using Differentiable LPC and Glottal-Flow-Inspired Wavetables (doi: 10.5281/zenodo.13916489)
  2. Differentiable Time-Varying Linear Prediction in the Context of End-to-End Analysis-by-Synthesis (doi: 10.21437/Interspeech.2024-1187)
  3. Differentiable All-pole Filters for Time-varying Audio Systems
  4. GOLF: A Singing Voice Synthesiser with Glottal Flow Wavetables and LPC Filters (doi: 10.5334/tismir.210)

References:

  1. https://ccrma.stanford.edu/~jos/filters/Numerical_Robustness_TDF_II.html 

  2. https://github.com/pytorch/audio/pull/1310#issuecomment-790408467 

  3. https://ccrma.stanford.edu/~jos/fp/Converting_State_Space_Form_Hand.html

How to Train Deep NMF Model in PyTorch

2021-02-09

Recently I updated the implementation of PyTorch-NMF to make it scale to large and complex NMF models. In this blog post, I will briefly explain how this was done, thanks to the automatic differentiation of PyTorch.

Multiplicative Update Rules with Beta Divergence

Multiplicative Update is a classic update method that has been widely used in many NMF applications. Its form is easy to derive, guarantees a monotonic decrease of the loss value, and ensures nonnegativity of the parameter updates.

Below are the multiplicative update forms when using Beta-Divergence as our criterion:
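Assuming the standard forms from Févotte and Idier (referenced at the end of this post), with \(\odot\), division, and powers taken element-wise, they read:

\[H \leftarrow H \odot \frac{W^\top\left((WH)^{\beta-2} \odot V\right)}{W^\top (WH)^{\beta-1}}, \quad W \leftarrow W \odot \frac{\left((WH)^{\beta-2} \odot V\right) H^\top}{(WH)^{\beta-1} H^\top}.\]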


Decoupling the Derivative

The update weights are actually derived from the derivative of the chosen criterion with respect to the parameters (H and W). Due to a property of the Beta-Divergence, the derivative can be expressed as the difference of two non-negative functions, such that:
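In symbols (using \(\nabla^{+}\) and \(\nabla^{-}\) as my labels for the two non-negative parts):

\[\nabla_H D_\beta(V \| WH) = \nabla_H^{+} - \nabla_H^{-}, \quad \nabla_H^{+} \ge 0, \; \nabla_H^{-} \ge 0.\]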

Then, we can simply write:
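With \(\nabla_H^{+}\) and \(\nabla_H^{-}\) denoting the non-negative parts of the derivative (my notation), the update is the element-wise ratio:

\[H \leftarrow H \odot \frac{\nabla_H^{-}}{\nabla_H^{+}}.\]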

Following the chain rule, we can also decouple the derivative with respect to the parameter (take H for example):
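With \(P = WH\), and \(\nabla_P^{+}\), \(\nabla_P^{-}\) denoting the non-negative parts of the derivative with respect to the output (my notation):

\[\frac{\partial D_\beta}{\partial H} = \left(\frac{\partial P}{\partial H}\right)^\top \frac{\partial D_\beta}{\partial P} = W^\top \nabla_P^{+} - W^\top \nabla_P^{-}.\]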

The derivative of WH with respect to H is \(W^\top\), which is always non-negative, so the ability to decouple into two non-negative functions actually comes from the Beta-Divergence itself.

The above steps can be applied on W as well.

Derivative of Beta-Divergence

The form of Beta-Divergence is:
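Assuming the definition used by Févotte and Idier (for \(\beta \neq 0, 1\)):

\[D_\beta(V \| P) = \sum_{ij} \frac{V_{ij}^\beta + (\beta - 1) P_{ij}^\beta - \beta V_{ij} P_{ij}^{\beta-1}}{\beta(\beta - 1)},\]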

where P = WH, and its derivative with respect to P is:
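Element-wise, this should be:

\[\frac{\partial D_\beta}{\partial P_{ij}} = P_{ij}^{\beta-1} - V_{ij} P_{ij}^{\beta-2}.\]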

It is indeed composed of two non-negative functions.

Derive Weights via Back-propagation

Two-Backward-Pass Algorithm

Now we can see that the two non-negative functions with respect to the parameter can be viewed as two non-negative functions with respect to the NMF output, each multiplied by the derivative of the NMF output with respect to the parameter. The latter can be evaluated by PyTorch’s automatic differentiation, so we only need to calculate the former. After that, we just need to back-propagate through the computational graph twice to get the multiplicative update weights.

Steps

  1. Calculate the NMF output P.
  2. Given P and the target V, derive the two non-negative components (pos and neg) of the derivative with respect to P.
  3. Derive one non-negative component of the derivative with respect to the parameter being updated by back-propagation (in PyTorch, P.backward(pos, retain_graph=True)).
  4. Derive the remaining non-negative component of the derivative by back-propagation (in PyTorch, P.backward(neg)).
  5. Derive the multiplicative update weights by dividing the result of step 4 by that of step 3.
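The steps above can be sketched in a few lines of PyTorch. This is a simplified illustration rather than the actual PyTorch-NMF code: `beta_mu_step` and the clamping constant `eps` are my own names, and only the update of H is shown.

```python
import torch

def beta_mu_step(W, H, V, beta=2.0, eps=1e-12):
    # One multiplicative update of H for the Beta-Divergence,
    # computed with two backward passes (steps 1-5 above).
    H = H.clone().requires_grad_(True)
    P = W @ H                               # 1. NMF output
    Pc = P.detach().clamp_min(eps)
    pos = Pc ** (beta - 1)                  # 2. dD/dP = pos - neg,
    neg = Pc ** (beta - 2) * V              #    both non-negative
    P.backward(pos, retain_graph=True)      # 3. back-propagate pos part
    grad_pos = H.grad.clone()
    H.grad = None
    P.backward(neg)                         # 4. back-propagate neg part
    grad_neg = H.grad
    with torch.no_grad():                   # 5. multiplicative update
        return H * grad_neg / grad_pos.clamp_min(eps)
```

For `beta=2.0` this reduces to the classic Euclidean update \(H \leftarrow H \odot \frac{W^\top V}{W^\top W H}\), so each call should never increase the squared error.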

What’s the Benefit of this Approach?

Well, because most of the update-weight computation can now be done by automatic differentiation, we can support the following features more easily without deriving closed-form solutions:

  • Advanced matrix/tensor operations: Some NMF variants (like De-convolutional NMF) use convolution instead of simple matrix multiplication to calculate the output; in PyTorch, convolution is supported natively and is fully differentiable.
  • Deeper NMF structures: Recently, some research has tried to learn much higher-level features by stacking multiple NMF layers, probably inspired by the rapid progress of Deep Learning in the last decade. But due to the non-negative constraints, deriving a closed-form update solution is non-trivial. With PyTorch-NMF, as long as the gradients are all non-negative along the back-propagation path in the computational graph, we can put an arbitrary number of NMF layers in our model, or even more complex structures of operations, and train them jointly.

Conclusion

In this post, I showed how PyTorch-NMF applies multiplicative update rules to much more advanced (or deeper) NMF models, and I hope this project can benefit researchers from various fields.

(This project is still in early development; if you are interested in supporting it, please contact me.)

Reference

  • Févotte, Cédric, and Jérôme Idier. “Algorithms for nonnegative matrix factorization with the β-divergence.” Neural computation 23.9 (2011): 2421-2456.
  • PyTorch-NMF, source code
  • PyTorch-NMF, documentation