src/recursion/models/recursive_reasoning
Columns: original file line number | new file line number | diff line change
@@ -282,18 +282,11 @@ def forward(
282 282
283 283           # Input encoding
284 284           input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
    285 +         # shape: [batch_size, 97, 512] (97 tokens of dim 512)
285 286
286 287           # Forward iterations
287 288           z_H, z_L = carry.z_H, carry.z_L
288     -         # H_cycles-1 without grad
289     -         with torch.no_grad():
290     -             for _H_step in range(self.config.H_cycles - 1):
291     -                 for _L_step in range(self.config.L_cycles):
292     -                     z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
293     -                 z_H = self.L_level(z_H, z_L, **seq_info)
294     -         # 1 with grad
295     -         for _L_step in range(self.config.L_cycles):
296     -             z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
    289 +         z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
297 290           z_H = self.L_level(z_H, z_L, **seq_info)
298 291
299 292           # LM Outputs
You can’t perform that action at this time.
0 commit comments