@@ -286,8 +286,7 @@ def forward(

         # Forward iterations
         z_H, z_L = carry.z_H, carry.z_L
-        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
-        z_H = self.L_level(z_H, z_L, **seq_info)
+        z_H = self.L_level(z_L, z_H + input_embeddings, **seq_info)

         # LM Outputs
         new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(
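This hunk collapses the final refinement into a single call: rather than updating `z_L` from `(z_L, z_H + input_embeddings)` and then `z_H` from `(z_H, z_L)`, the answer latent `z_H` is now produced directly from `(z_L, z_H + input_embeddings)` through the shared `L_level` network. A minimal sketch of how the single-call update could sit inside the surrounding recursion (illustrative only; `n_latent_steps` and the loop shape are assumptions, not the repo's exact code):

```python
# Sketch, not the actual repo loop: z_L is the reasoning latent, z_H the answer latent.
for _ in range(n_latent_steps):  # hypothetical inner iteration count
    z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)  # refine reasoning latent
z_H = self.L_level(z_L, z_H + input_embeddings, **seq_info)  # single answer update (this commit)
```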
@@ -356,6 +355,39 @@ def forward(

         halted = is_last_step

+        # If training and ACT is enabled
+        if self.training and (self.config.halt_max_steps > 1):
+
+            # Halt signal
+            # NOTE: During evaluation, always use max steps; this guarantees the same number of halting steps inside a batch, for batching purposes.
+
+            if self.config.no_ACT_continue:
+                halted = halted | (q_halt_logits > 0)
+            else:
+                halted = halted | (q_halt_logits > q_continue_logits)
+
+            # Exploration
+            min_halt_steps = (
+                torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
+            ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+            halted = halted & (new_steps >= min_halt_steps)
+
+            if not self.config.no_ACT_continue:
+                # Compute target Q
+                # NOTE: No replay buffer and no target network for computing the target Q-value.
+                # As batch_size is large, there are many parallel envs.
+                # Similar concept as PQN: https://arxiv.org/abs/2407.04811
+                _, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(
+                    new_inner_carry, new_current_data
+                )
+                outputs["target_q_continue"] = torch.sigmoid(
+                    torch.where(
+                        is_last_step,
+                        next_q_halt_logits,
+                        torch.maximum(next_q_halt_logits, next_q_continue_logits),
+                    )
+                )
+
         return (
             TinyRecursiveReasoningModel_ACTV1Carry(
                 new_inner_carry, new_steps, halted, new_current_data
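During training with ACT enabled, a sequence halts as soon as its halt head wins (against 0 in the `no_ACT_continue` variant, otherwise against the continue head), and an exploration term occasionally forces a random minimum number of steps so the Q-heads see longer rollouts. A self-contained sketch of that masking with made-up tensors (shapes and values are illustrative):

```python
import torch

q_halt_logits = torch.tensor([2.0, -1.0, 3.0, 0.5])  # per-sequence halt scores
new_steps = torch.tensor([1, 5, 2, 7])  # steps taken so far
halt_max_steps, halt_exploration_prob = 8, 0.5

halted = q_halt_logits > 0  # halt signal (no_ACT_continue variant)
# With prob. halt_exploration_prob, require a random minimum step count in [2, halt_max_steps];
# the boolean mask zeroes out min_halt_steps for non-exploring sequences.
min_halt_steps = (
    torch.rand_like(q_halt_logits) < halt_exploration_prob
) * torch.randint_like(new_steps, low=2, high=halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
```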
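The bootstrapped target for the continue head follows the PQN idea referenced in the comment: no replay buffer or target network, just a second forward pass to obtain next-step Q-values, taking only the halt value when the step limit is reached and the greedy max otherwise. A worked example with standalone tensors (in the diff, the `next_q_*` logits come from the extra `self.inner` call):

```python
import torch

is_last_step = torch.tensor([False, True, False])
next_q_halt_logits = torch.tensor([0.2, 1.5, -0.3])
next_q_continue_logits = torch.tensor([0.8, -2.0, 0.1])

target_q_continue = torch.sigmoid(
    torch.where(
        is_last_step,
        next_q_halt_logits,  # at the step limit, continuing is impossible
        torch.maximum(next_q_halt_logits, next_q_continue_logits),  # greedy bootstrap
    )
)
print(target_q_continue)  # sigmoid of [0.8, 1.5, 0.1]
```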