forked from thomasthebaud/speechLLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
146 lines (129 loc) · 6.14 KB
/
utils.py
File metadata and controls
146 lines (129 loc) · 6.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import json
import numpy as np
import os
def get_data_config(args):
    """Resolve the dataset configuration for all three splits.

    When ``args.use_config`` is None the built-in default split lists are
    used; otherwise the JSON file ``config/data/<use_config>`` is loaded.
    Returns a ``(datasets, use_summaries)`` tuple, where ``datasets`` maps
    each split ("train"/"dev"/"test") to ``{dataset_name: field_list}`` and
    ``use_summaries`` is True when any dev dataset requests a "summary" field.
    """
    if args.use_config is None:  # no config file given: use the default split lists
        default_splits = {
            "train":['librispeech_train-clean-100', 'iemocap_ses01-03', 'CV-EN_train', 'MSP_Podcast_Train', 'voxceleb2_enriched_dev'],
            "dev":['librispeech_dev-clean', 'iemocap_ses04', 'CV-EN_dev', 'MSP_Podcast_Validation', 'voxceleb2_enriched_test'],
            "test":['librispeech_test-clean', 'iemocap_ses05', 'CV-EN_test', 'MSP_Podcast_Test', 'voxceleb2_enriched_test'],
        }
        # An empty field list means: use every available field for that dataset.
        datasets = {split: {name: [] for name in names} for split, names in default_splits.items()}
    else:
        with open(f'config/data/{args.use_config}', 'r') as file:
            datasets = json.load(file)
    # Summaries are enabled as soon as any dev dataset lists the "summary" field.
    use_summaries = any("summary" in fields for fields in datasets['dev'].values())
    return datasets, use_summaries
def get_connector_config(args):
    """Load the connector architecture config as a dict.

    Looks for ``config/model/<args.connector>.json``; if that file does not
    exist, falls back to ``cnn_str1.2.1.json``.
    """
    connector_config = args.connector + ".json"
    # NOTE(review): an unknown connector name silently falls back to the
    # default config instead of raising — confirm this is intentional.
    if connector_config not in os.listdir('config/model/'):
        connector_config = "cnn_str1.2.1.json"
    with open(f'config/model/{connector_config}', 'r') as file:
        connector = json.load(file)
    return connector
def get_model_config():
    """Parse CLI arguments and assemble the full experiment config dict.

    Combines the dataset config (``get_data_config``), the connector config
    (``get_connector_config``) and training hyper-parameters into a single
    dict, and derives a descriptive ``model_name`` encoding the main choices
    (encoder, connector, LLM, batch size, text/audio mode, fine-tuning, lr).
    """
    # Parse args
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder')
    parser.add_argument('--connector')
    parser.add_argument('--llm')
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--truncate-sec', default=-1, type=int)
    parser.add_argument('--lr', default=1.0)            # 1.0 is a sentinel: pick a per-connector default below
    parser.add_argument('--encoder-lr', default=-1)     # -1 is a sentinel: use lr / 50
    parser.add_argument("--no-lora", action='store_true')
    parser.add_argument("--ft-encoder", action='store_true')
    parser.add_argument("--ft-layers", type=str, default='all')
    parser.add_argument("--use-text", action='store_true')
    parser.add_argument("--prob-text", default=0.5, type=float)
    parser.add_argument("--no-audio", action='store_true')
    parser.add_argument('--epoch-to-test', default=1, type=int)
    parser.add_argument("--meanpool", default=1, type=int)
    parser.add_argument("--total-training-epoch", default=1000, type=int)
    parser.add_argument("--use-config", default=None, type=str)
    parser.add_argument("--group", default='August experiments', type=str)
    parser.add_argument("--nickname", default='_', type=str)
    parser.add_argument("--test-on", default='A', type=str)
    args = parser.parse_args()
    # Datasets config
    datasets, use_summaries = get_data_config(args)
    # Connector config
    connector = get_connector_config(args)
    # Training parameters
    lr = float(args.lr)
    if args.encoder_lr == -1: args.encoder_lr = float(args.lr)/50   # default encoder lr: main lr / 50
    if lr == 1.0: lr = 1e-4 if 'linear' not in connector['name'] else 1e-5  # sentinel -> per-connector default
    batch_size = int(args.batch_size)
    use_lora = not args.no_lora
    # Model naming: encode the main hyper-parameters into the run name.
    model_name = f"{args.encoder.split('/')[-1]}-{connector['name']}-{args.llm.split('-')[0]}-bs{batch_size}"
    if args.no_lora: model_name = model_name+'_nolora'
    if args.use_text:
        model_name = model_name + f"_p{float(args.prob_text)}"
        if args.no_audio: model_name = 'T_'+model_name   # text-only
        else: model_name = 'AT_'+model_name              # audio + text
    else:
        if args.no_audio: exit("not using text nor audio!")
        else: model_name = 'A_'+model_name               # audio-only
    if args.ft_encoder: model_name = model_name+'_ft_encoder'
    if args.ft_layers != 'all':
        # "--ft-layers start-end" selects an inclusive-exclusive layer range.
        start_ft, end_ft = args.ft_layers.split('-')
        ft_layers = (int(start_ft), int(end_ft))
        model_name = model_name+f'_layers{int(start_ft)+1}-{end_ft}'
    else: ft_layers = (0,100)
    if args.meanpool != 1: model_name = model_name+f'_mp{args.meanpool}'
    if len(connector['in_meanpool']) > 0: model_name = model_name+"_inmp"+'.'.join([str(i[1]) for i in connector['in_meanpool']])
    # NOTE(review): this branch is unreachable — encoder_lr was already
    # rewritten above when it equalled -1, so the name never gets a "_lrenc"
    # suffix. The intended condition was probably `!= -1`; kept as-is to
    # avoid changing existing run names — confirm the intent.
    if args.encoder_lr == -1: model_name = f"{model_name}_lrenc{args.encoder_lr}"
    if connector['name'] == 'cnn': model_name = model_name+"_str2"
    model_name = f"{model_name}_lr{lr}"
    if args.nickname != '_': model_name = model_name + str(args.nickname)
    # Wandb params
    log_path = 'logs/'+model_name
    group = args.group
    # Encoder
    audio_encoder_name = args.encoder
    if args.encoder == 'MFCC': audio_enc_dim = 80
    else: audio_enc_dim = connector['input_dim']
    # LLM
    if args.llm == 'TinyLlama-1.1B-Chat-v1.0':
        llm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    else:
        # Fix: any other --llm value previously left llm_name unbound,
        # raising NameError below; fall back to the raw argument.
        llm_name = args.llm
    # Collect everything into the config dict consumed by the trainer.
    model_config = {
        'audio_enc_dim':audio_enc_dim,
        'audio_encoder_name': audio_encoder_name,
        'connector_args': connector,
        'llm_name': llm_name,
        'finetune_encoder': args.ft_encoder,
        'ft_layers': ft_layers,
        'meanpool':int(args.meanpool),
        'use_lora': use_lora,
        'use_text':args.use_text,
        'prob_text':float(args.prob_text),
        'use_audio':not args.no_audio,
        'lora_r': 8,
        'lora_alpha': 16,
        'max_lr': lr,
        'enc_lr': float(args.encoder_lr),
        'batch_size':batch_size,
        'total_training_epoch': int(args.total_training_epoch),
        'warmup_steps': 100,
        # Fix: 64//batch_size is 0 when batch_size > 64; clamp to at least 1.
        'grad_accumulate_steps': max(1, 64//batch_size),
        'max_number_seconds': args.truncate_sec,
        'train_batch_per_epoch': 20_000,
        'train_sets':datasets['train'],
        'dev_sets':datasets['dev'],
        'test_sets':datasets['test'],
        'max_size_per_dev_set':100,
        'log_path':log_path,
        'group':group,
        'model_name':model_name,
        'epoch_to_test':int(args.epoch_to_test),
        'use_summaries':use_summaries,
        'test_on':str(args.test_on)
    }
    return model_config