Skip to content

Commit 97f9071

Browse files
committed
feat: Simplify TRM
1 parent bb38440 commit 97f9071

1 file changed

Lines changed: 2 additions & 9 deletions

File tree

  • src/recursion/models/recursive_reasoning

src/recursion/models/recursive_reasoning/trm.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,11 @@ def forward(
282282

283283
# Input encoding
284284
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
285+
# shape: [batch_size, 97, 512] (97 tokens of dim 512)
285286

286287
# Forward iterations
287288
z_H, z_L = carry.z_H, carry.z_L
288-
# H_cycles-1 without grad
289-
with torch.no_grad():
290-
for _H_step in range(self.config.H_cycles - 1):
291-
for _L_step in range(self.config.L_cycles):
292-
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
293-
z_H = self.L_level(z_H, z_L, **seq_info)
294-
# 1 with grad
295-
for _L_step in range(self.config.L_cycles):
296-
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
289+
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
297290
z_H = self.L_level(z_H, z_L, **seq_info)
298291

299292
# LM Outputs

0 commit comments

Comments
 (0)