Skip to content

Commit 06ac9a5

Browse files
committed
dynamic conversational-train loading in progress
1 parent 219f7b4 commit 06ac9a5

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

src/train/SeleKT/selekt.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,10 @@ def train(args):
300300
print(f'Resuming from checkpoint: {last_checkpoint}')
301301

302302

303-
# response_template = "#RESPONSE\n"
304-
# collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
303+
collator = None
304+
if args.is_conversational_training:
305+
response_template = "#RESPONSE\n"
306+
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
305307

306308
callback = Callback(base_model_path=args.base_model_path, flush_steps=1, alpha=args.alpha)
307309
trainer = SFTTrainer(
@@ -310,7 +312,7 @@ def train(args):
310312
train_dataset=dataset,
311313
args=training_config,
312314
callbacks=[callback],
313-
# data_collator=collator,
315+
data_collator=collator,
314316
)
315317
callback.set_trainer(trainer)
316318
print(f"Starting training for epoch {args.num_train_epochs}")

src/train/sft/run.sh

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/bin/bash
2+
3+
4+
export MODEL_NAME=""
5+
export DESC=""
6+
7+
OUTPUT_DIR=""
8+
TRAIN_DATA=""
9+
MODEL_PATH=""
10+
11+
mkdir -p $OUTPUT_DIR
12+
13+
accelerate launch \
14+
--config_file=../configs/general_acc.yaml \
15+
sft.py \
16+
--model_name_or_path "$MODEL_PATH" \
17+
--train_data_path "$TRAIN_DATA" \
18+
--output_dir ${OUTPUT_DIR} \
19+
--num_train_epochs 3 \
20+
--model_max_length 16384 \
21+
--per_device_train_batch_size 1 \
22+
--gradient_accumulation_steps 4 \
23+
--save_strategy "epoch" \
24+
--save_steps 760 \
25+
--save_total_limit 25 \
26+
--learning_rate 1e-5 \
27+
--warmup_ratio 0.1 \
28+
--weight_decay 0.1 \
29+
--logging_steps 5 \
30+
--lr_scheduler_type "cosine" \
31+
--report_to "wandb" \
32+
--gradient_checkpointing True \
33+
--deepspeed ../configs/ds_config.json \
34+
--bf16 True \
35+
--run_name "" \
36+
37+
38+
39+
accelerate launch \
40+
--config_file=../configs/general_acc.yaml \
41+
sft.py \
42+
--model_name_or_path "${MODEL_PATH}" \
43+
--train_data_path "$TRAIN_DATA" \
44+
--output_dir ${OUTPUT_DIR} \
45+
--num_train_epochs 3 \
46+
--model_max_length 16384 \
47+
--per_device_train_batch_size 1 \
48+
--gradient_accumulation_steps 4 \
49+
--save_strategy "epoch" \
50+
--save_steps 760 \
51+
--save_total_limit 25 \
52+
--learning_rate 1e-5 \
53+
--warmup_ratio 0.1 \
54+
--weight_decay 0.1 \
55+
--logging_steps 5 \
56+
--lr_scheduler_type "cosine" \
57+
--report_to "wandb" \
58+
--gradient_checkpointing True \
59+
--deepspeed ../configs/ds_config.json \
60+
--bf16 True \
61+
--run_name "" \
62+
--is_conversational_training \
63+

src/train/sft/sft.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def parse_args():
6565
parser.add_argument("--debug", type=bool, default=False)
6666
parser.add_argument("--packing", type=bool, default=True,
6767
help="Whether to use packing for training")
68+
parser.add_argument("--is_conversational_training", action='store_true',
69+
help="Whether to use conversational training format")
6870

6971
args, _ = parser.parse_known_args()
7072
return args
@@ -108,7 +110,6 @@ def __init__(self, flush_steps=None):
108110
self.flush_steps = flush_steps
109111

110112
def on_step_end(self, args, state, control, model, processing_class , **kwargs):
111-
# import sys; sys.exit(0)
112113
if state.global_step % self.flush_steps == 0:
113114
get_accelerator().empty_cache()
114115
if dist.is_initialized():
@@ -172,8 +173,10 @@ def main():
172173
if last_checkpoint:
173174
print(f'Resuming from checkpoint: {last_checkpoint}')
174175

175-
# response_template = "#RESPONSE\n"
176-
# collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
176+
collator = None
177+
if args.is_conversational_training:
178+
response_template = "#RESPONSE\n"
179+
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
177180

178181
# Initialize trainer
179182
trainer = SFTTrainer(
@@ -182,7 +185,7 @@ def main():
182185
train_dataset=dataset,
183186
args=training_config,
184187
callbacks=[Callback(flush_steps=1)],
185-
# data_collator=collator,
188+
data_collator=collator,
186189
)
187190

188191
# Start training

0 commit comments

Comments
 (0)