This repository includes source code for the following paper:
```bibtex
@article{saijo2026input,
  title={Input-Adaptive Spectral Feature Compression by Sequence Modeling for Source Separation},
  author={Saijo, Kohei and Bando, Yoshiaki},
  journal={IEEE Transactions on Audio, Speech and Language Processing},
  year={2026},
  publisher={IEEE}
}
```
```shell
# install this repo
git clone https://github.com/b-sigpro/spectral-feature-compression.git
cd spectral-feature-compression
pip install -e .
pip install -r requirements.txt

# install Mamba, see https://github.com/state-spaces/mamba for more details
pip install --no-build-isolation mamba-ssm[causal-conv1d]

# install aiaccel
git clone https://github.com/aistairc/aiaccel.git
cd aiaccel
git checkout 117d8d5d335540b6d331993ffc02d4b64f5e02a1  # the commit where we tested our code
pip install -e .
cd ../
```

The repository supports three modules:
- `CrossAttnEncoder`, `CrossAttnDecoder`: Spectral Feature Compression by Cross-Attention (SFC-CA; Section III-B in the paper)
- `MambaEncoder`, `MambaDecoder`: Spectral Feature Compression by Mamba (SFC-Mamba; Section III-C in the paper)
- `BanditEncoder`, `BanditDecoder`: the band-split module from the BandIt paper
Once the installation is done, one can use these modules simply by importing them. The code below shows an example of using SFC-CA:
```python
import torch

from spectral_feature_compression import CrossAttnEncoder, CrossAttnDecoder

sample_rate = 44100
n_fft = 2048
n_batch, n_chan, n_frames, n_freqs = 4, 2, 100, n_fft // 2 + 1
n_src = 4

encoder = CrossAttnEncoder(d_inner=64, d_model=128, n_chan=n_chan, sample_rate=sample_rate, n_fft=n_fft, n_bands=64).to("cuda")
decoder = CrossAttnDecoder(d_inner=64, d_model=128, n_src=n_src, n_chan=n_chan, sample_rate=sample_rate, n_fft=n_fft, n_bands=64).to("cuda")

# the modules assume a complex input of (n_batch, n_chan, n_frames, n_freqs)
# or a float input of (n_batch, 2*n_chan, n_frames, n_freqs)
input = torch.randn((n_batch, n_chan, n_frames, n_freqs), dtype=torch.complex64, device="cuda")
enc_output, dec_query = encoder(input)  # enc_output: (n_batch, d_model, n_frames, n_bands)
dec_output, _ = decoder(enc_output, query=dec_query)  # dec_output: (n_batch, 2*n_chan*n_src, n_frames, n_freqs)
```

Note that the SFC encoder returns two tensors: the first is the compressed output and the second is the non-compressed tensor. The second tensor is used as the query in the decoder when using the adaptive query, and is simply ignored when using the learnable query (please refer to the paper for more details).
The repository also supports the TF-Locoformer separator, which can be used as follows:
```python
from spectral_feature_compression import BSLocoformer

# assuming the encoder and decoder are initialized as shown above;
# the default parameters are for the small-sized model
separator = BSLocoformer(encoder=encoder, decoder=decoder, n_src=4, n_chan=2).to("cuda")

# the module assumes a complex input of (n_batch, n_chan, n_freqs, n_frames)
input = torch.randn((n_batch, n_chan, n_freqs, n_frames), dtype=torch.complex64, device="cuda")
output = separator(input)  # (n_batch, n_src, n_chan, n_freqs, n_frames)
```

We provide pre-trained weights of the TF-Locoformer model trained on the MUSDB18-HQ or DnR dataset in a Hugging Face repository.
```shell
python model_weights/download_pretrained_weights.py --dst_dir ./model_weights
```

By default, the script creates the model_weights directory and downloads the models under it.
Once the models are downloaded, one can use them to separate sources with separate_sample.py. An example of running it is shown below:
```shell
python separate_sample.py model_weights/musdb18hq/locoformer-small.enc-crossattn64dim.dec-crossattn64dim.musical64.learnable-query /path/to/audio-file /path/to/output-directory
```

The following assumes you are at ./spectral-feature-compression.
We provide a shell script, data.sh, to easily prepare the data. data.sh performs the following steps:
- Download the MUSDB18-HQ or DnR dataset and uncompress it
- MUSDB18-HQ only: split the data into training and validation sets following the common split
- Apply unsupervised source activity detection (introduced in the BSRNN paper) to the training data and save the segmented audio files as an HDF5 file
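For intuition, the energy-thresholding idea behind source activity detection can be sketched as below. This is a hedged, simplified stand-in for the BSRNN-style method that data.sh actually uses: it only keeps fixed-length segments whose energy is within a dB threshold of the loudest segment, and the names `active_segments` and `threshold_db` are illustrative, not part of the repository.

```python
import math


def active_segments(samples, seg_len, threshold_db=-30.0):
    """Return (start, end) sample-index pairs of segments considered active.

    A segment is kept if its mean energy is within `threshold_db` dB of
    the most energetic segment in the file.
    """
    # mean energy of each non-overlapping segment
    energies = []
    for start in range(0, len(samples) - seg_len + 1, seg_len):
        seg = samples[start:start + seg_len]
        energies.append(sum(x * x for x in seg) / seg_len)

    peak = max(energies) if energies else 0.0
    if peak == 0.0:
        return []  # silent (or empty) input: nothing is active

    kept = []
    for i, e in enumerate(energies):
        db = 10.0 * math.log10(e / peak) if e > 0.0 else float("-inf")
        if db >= threshold_db:
            kept.append((i * seg_len, (i + 1) * seg_len))
    return kept
```

In the actual pipeline, the surviving segments are what get written to the HDF5 file.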
The script can be run as:
```shell
dataset_name=musdb18hq  # or dnr
./recipes/${dataset_name}/scripts/data.sh
```

Training can be run by invoking train.sh in each model directory:
```shell
./recipes/musdb18hq/models/locoformer-small.enc-crossattn64dim.dec-crossattn64dim.musical64.learnable-query/train.sh
```

The directory structure after training is as follows. The Lightning checkpoints are saved under checkpoints, and the training progress can be monitored with TensorBoard.
```
recipes/musdb18hq/models/locoformer-small.enc-crossattn64dim.dec-crossattn64dim.musical64.learnable-query
├── checkpoints
│   ├── epoch=xxxx.ckpt
│   ├── epoch=xxxx.ckpt
│   ├── epoch=xxxx.ckpt
│   ├── epoch=xxxx.ckpt
│   ├── epoch=xxxx.ckpt
│   └── last.ckpt
├── config.yaml
├── events.out.tfevents.xxx.xxx.xxx.x
├── hparams.yaml
├── log.txt
├── merged_config.yaml
└── train.sh
```

Once training is finished, running separate.sh performs inference and scoring:
```shell
./recipes/musdb18hq/scripts/separate.sh /path/to/model_directory
```

Here, /path/to/model_directory can be either a path to a directory containing the checkpoints directory or a direct path to a .ckpt file.
For instance, when evaluating recipes/musdb18hq/models/locoformer-small.enc-crossattn64dim.dec-crossattn64dim.musical64.learnable-query, you can give either recipes/musdb18hq/models/locoformer-small.enc-crossattn64dim.dec-crossattn64dim.musical64.learnable-query or recipes/musdb18hq/models/locoformer-small.enc-crossattn64dim.dec-crossattn64dim.musical64.learnable-query/checkpoints/xxx.ckpt.
In the former case, all the checkpoints under that directory, except for last.ckpt, are averaged, and the averaged parameters are used for evaluation.
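The averaging step amounts to an element-wise mean of the parameters across checkpoints. The sketch below is a simplified stand-in for the repository's average_model_params.py (adapted from ESPnet): real checkpoints hold torch tensors, whereas here parameters are plain Python lists and `average_state_dicts` is an illustrative name.

```python
def average_state_dicts(state_dicts):
    """Element-wise average of several state dicts with identical keys.

    Each state dict maps a parameter name to a flat list of floats,
    a simplified stand-in for torch tensors in real checkpoints.
    """
    n = len(state_dicts)
    averaged = {}
    for key in state_dicts[0]:
        # pair up the corresponding elements from every checkpoint
        stacked = zip(*(sd[key] for sd in state_dicts))
        averaged[key] = [sum(vals) / n for vals in stacked]
    return averaged
```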
The segment and shift sizes in inference are set to 12 and 6 seconds by default, respectively.
These can be changed by passing them as the second and third arguments when running separate.sh.
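To make the segment/shift configuration concrete, the sketch below computes the chunk boundaries implied by a given segment and shift length; the actual chunking and overlap-add logic lives in separate.sh and may differ, and `chunk_boundaries` is an illustrative name.

```python
def chunk_boundaries(n_samples, sample_rate, segment_s=12.0, shift_s=6.0):
    """Return (start, end) sample-index pairs for overlapping inference chunks.

    Consecutive chunks overlap by (segment_s - shift_s) seconds; the
    last chunk is clipped to the end of the signal.
    """
    seg = int(segment_s * sample_rate)
    hop = int(shift_s * sample_rate)
    bounds = []
    start = 0
    while start < n_samples:
        bounds.append((start, min(start + seg, n_samples)))
        if start + seg >= n_samples:
            break  # this chunk already reaches the end of the signal
        start += hop
    return bounds
```

With the defaults, each 12-second chunk overlaps its neighbor by 6 seconds, and the separated chunks are then merged back together.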
Released under the MIT license, as found in the LICENSE.md file.
All files, except as noted below:
Copyright (c) 2026 National Institute of Advanced Industrial Science and Technology (AIST), Japan
SPDX-License-Identifier: MIT
The following file:
spectral_feature_compression/core/model/bandit_split.py
was adapted from https://github.com/kwatcharasupat/bandit (license included in LICENSES/Apache-2.0.md)
Copyright (c) 2026 National Institute of Advanced Industrial Science and Technology (AIST), Japan
Copyright (c) 2023 Karn Watcharasupat
SPDX-License-Identifier: MIT
SPDX-License-Identifier: Apache-2.0
The following file:
spectral_feature_compression/core/model/bslocoformer.py
was adapted from https://github.com/merlresearch/tf-locoformer (license included in LICENSES/Apache-2.0.md)
Copyright (c) 2026 National Institute of Advanced Industrial Science and Technology (AIST), Japan
Copyright (c) 2024 Mitsubishi Electric Research Laboratories (MERL)
SPDX-License-Identifier: MIT
SPDX-License-Identifier: Apache-2.0
The following file:
spectral_feature_compression/core/model/average_model_params.py
was adapted from https://github.com/espnet/espnet (license included in LICENSES/Apache-2.0.md)
Copyright (c) 2026 National Institute of Advanced Industrial Science and Technology (AIST), Japan
Copyright (c) 2017 ESPnet Developers
SPDX-License-Identifier: MIT
SPDX-License-Identifier: Apache-2.0
The following file:
spectral_feature_compression/core/loss/snr.py
was adapted from https://github.com/kohei0209/self-remixing (license included in LICENSES/MIT.md)
Copyright (c) 2026 National Institute of Advanced Industrial Science and Technology (AIST), Japan
Copyright (c) 2024 Kohei Saijo
SPDX-License-Identifier: MIT
SPDX-License-Identifier: MIT