Skip to content

Commit 6bd0646

Browse files
author
pseeth
committed
Fixing Github workflow, and setting .yml files up for configuring experiments.
1 parent 369674e commit 6bd0646

File tree

7 files changed

+117
-8
lines changed

7 files changed

+117
-8
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
4141
- name: Build the book
4242
run: |
43-
python common/download.py
43+
python -m common.data --run.cmd='download'
4444
jupyter-book build book/
4545
4646
# Push the book's HTML to github-pages

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ python -m common.data --run.cmd='download'
2020
jb build --all book/
2121
```
2222

23+
## Running an experiment
24+
25+
To run a basic mask estimation experiment with a Chimera network,
26+
do the following:
27+
28+
```
29+
python -m common.data --run.cmd='prepare_musdb'
30+
python -m common.exp.chimera --run.cmd='train'
31+
python -m common.exp.chimera --run.cmd='evaluate'
32+
```
33+
2334
## Questions? Comments? Typos? Bugs? Issues?
2435

2536
Open a github issue [here](https://github.com/source-separation/tutorial/issues/new)

common/data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,9 +287,9 @@ def __init__(
287287
'source_time': ('uniform', 0, MAX_SOURCE_TIME),
288288
'event_time': ('const', 0),
289289
'event_duration': ('const', duration),
290-
'snr': snr,
291-
'pitch_shift': pitch_shift,
292-
'time_stretch': time_stretch,
290+
'snr': tuple(snr),
291+
'pitch_shift': tuple(pitch_shift),
292+
'time_stretch': tuple(time_stretch),
293293
}
294294
self.fg_path = fg_path
295295
self.sample_rate = sample_rate

common/exp/chimera.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
def train(
1414
args,
1515
seed : int = 0,
16-
num_epochs : int = 1,
17-
epoch_length : int = 1,
16+
num_epochs : int = 100,
17+
epoch_length : int = 1000,
1818
lr : float = 1e-3,
1919
batch_size : int = 1,
2020
dpcl_weight : float = .75,
@@ -128,6 +128,7 @@ def evaluate(
128128
output_folder : str = './results',
129129
num_workers : int = 1,
130130
):
131+
output_folder = Path(output_folder)
131132
stft_params, sample_rate = data.signal()
132133
# Output of net is always in alphabetical order
133134
labels = ['bass', 'drums', 'other', 'vocals']
@@ -140,6 +141,8 @@ def evaluate(
140141

141142
_device = utils.device()
142143
separator = models.deep_mask_estimation(_device)
144+
145+
utils.plot_metrics(separator, 'l1_loss', output_folder / 'metrics.png')
143146

144147
pbar = tqdm.tqdm(musdb)
145148
for item in pbar:
@@ -158,7 +161,7 @@ def evaluate(
158161
json.dump(scores, f, indent=4)
159162
break
160163

161-
output_file = Path(output_folder) / 'report_card.txt'
164+
output_file = output_folder / 'report_card.txt'
162165

163166
json_files = glob.glob(f"{output_folder}/*.json")
164167
if not json_files:

common/exp/conf/base.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
autoclip.percentile: 10
2+
3+
evaluate.folder: data/test
4+
evaluate.num_workers: 1
5+
evaluate.output_folder: ./results
6+
7+
logger.level: info
8+
9+
mixer.coherent_prob: 0.25
10+
mixer.master_label: vocals
11+
mixer.n_channels: 1
12+
mixer.num_mixtures: 10
13+
mixer.pitch_shift: [uniform, -2, 2]
14+
mixer.ref_db: [-30, -10]
15+
mixer.sample_rate: 44100
16+
mixer.snr: [uniform, -5, 5]
17+
mixer.source_file: [choose, []]
18+
mixer.time_stretch: [uniform, .9, 1.1]
19+
20+
train/mixer.fg_path: data/train
21+
train/mixer.duration: 4.0
22+
train/transform.excerpt_length: 4.0
23+
24+
val/mixer.duration: 10.0
25+
val/transform.excerpt_length: 10.0
26+
27+
val/mixer.coherent_prob: 1.0
28+
val.mixer.pitch_shift: [const, 0]
29+
val/mixer.time_stretch: [const, 1.0]
30+
val/mixer.fg_path: data/valid
31+
val/mixer.num_mixtures: 100
32+
33+
device.use: cuda
34+
35+
early_stopping.cumulative_delta: 0
36+
early_stopping.epochs: 30
37+
early_stopping.min_delta: 0.0
38+
39+
patience.epochs: 5
40+
patience.factor: 0.5
41+
patience.mode: min
42+
patience.verbose: false
43+
44+
transform.audio_only: false
45+
transform.mask_type: msa

common/exp/conf/chimera.yml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
$include:
2+
- common/conf/exp/base.yml
3+
4+
build.bidirectional: 1
5+
build.dropout: 0.3
6+
build.embedding_activation:
7+
- sigmoid
8+
- unit_norm
9+
build.embedding_size: 20
10+
build.hidden_size: 300
11+
build.mask_activation:
12+
- sigmoid
13+
build.num_audio_channels: 1
14+
build.num_layers: 4
15+
build.num_sources: 4
16+
build.rnn_type: lstm
17+
18+
deep_mask_estimation.mask_type: soft
19+
deep_mask_estimation.model_path: checkpoints/best.model.pth
20+
21+
signal.hop_length: 512
22+
signal.sample_rate: 44100
23+
signal.window_length: 2048
24+
signal.window_type: sqrt_hann
25+
26+
train.batch_size: 32
27+
train.dpcl_weight: 0.75
28+
train.epoch_length: 1000
29+
train.lr: 0.001
30+
train.mi_weight: 0.25
31+
train.num_epochs: 100
32+
train.num_workers: 4
33+
train.output_folder: .
34+
train.seed: 0

common/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import sys
44
import torch
5-
5+
import matplotlib.pyplot as plt
66
from . import argbind
77

88
@argbind.bind_to_parser()
@@ -69,3 +69,19 @@ def pprint(data):
6969
f"Source time : {obs.value['source_time']} \n"
7070
)
7171
logging.info('\n' + desc)
72+
73+
74+
def plot_metrics(separator, key, output_path):
75+
data = separator.metadata['trainer.state.epoch_history']
76+
plt.figure(figsize=(5, 4))
77+
78+
plt.subplot(111)
79+
plt.plot(data[f'validation/{key}'], label='val')
80+
plt.plot(data[f'train/{key}'], label='train')
81+
plt.xlabel('Epoch')
82+
plt.ylabel('Loss')
83+
plt.title('Loss')
84+
plt.legend()
85+
plt.tight_layout()
86+
87+
plt.savefig(output_path)

0 commit comments

Comments
 (0)