Skip to content

Commit c502b1b

Browse files
committed
feat: 1-level TRM
1 parent a6df540 commit c502b1b

2 files changed

Lines changed: 37 additions & 4 deletions

File tree

.vscode/launch.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
"arch=trm",
1616
"data_paths=[data/sudoku-extreme-1k-aug-1000]",
1717
"evaluators=[]",
18-
"epochs=5",
19-
"eval_interval=1",
18+
"epochs=50000",
19+
"eval_interval=5000",
2020
"lr=1e-4",
21+
"global_batch_size=768",
2122
"puzzle_emb_lr=1e-4",
2223
"weight_decay=1.0",
2324
"puzzle_emb_weight_decay=1.0",

src/recursion/models/recursive_reasoning/trm.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,7 @@ def forward(
286286

287287
# Forward iterations
288288
z_H, z_L = carry.z_H, carry.z_L
289-
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
290-
z_H = self.L_level(z_H, z_L, **seq_info)
289+
z_H = self.L_level(z_L, z_H + input_embeddings, **seq_info)
291290

292291
# LM Outputs
293292
new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(
@@ -356,6 +355,39 @@ def forward(
356355

357356
halted = is_last_step
358357

358+
# if training, and ACT is enabled
359+
if self.training and (self.config.halt_max_steps > 1):
360+
361+
# Halt signal
362+
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
363+
364+
if self.config.no_ACT_continue:
365+
halted = halted | (q_halt_logits > 0)
366+
else:
367+
halted = halted | (q_halt_logits > q_continue_logits)
368+
369+
# Exploration
370+
min_halt_steps = (
371+
torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
372+
) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
373+
halted = halted & (new_steps >= min_halt_steps)
374+
375+
if not self.config.no_ACT_continue:
376+
# Compute target Q
377+
# NOTE: No replay buffer and target networks for computing target Q-value.
378+
# As batch_size is large, there're many parallel envs.
379+
# Similar concept as PQN https://arxiv.org/abs/2407.04811
380+
_, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(
381+
new_inner_carry, new_current_data
382+
)
383+
outputs["target_q_continue"] = torch.sigmoid(
384+
torch.where(
385+
is_last_step,
386+
next_q_halt_logits,
387+
torch.maximum(next_q_halt_logits, next_q_continue_logits),
388+
)
389+
)
390+
359391
return (
360392
TinyRecursiveReasoningModel_ACTV1Carry(
361393
new_inner_carry, new_steps, halted, new_current_data

0 commit comments

Comments (0)