Walkthrough

Adds `EspressoMD.add_cross` and CI tests, introduces a new RodTorque task with unit tests, updates rod rotation to use velocity history and pass colloid directors, realigns rewards/values in policy/value losses, extends colloid utils for torque/force calculations, and adds optional best-network saving to episodic training.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Test as CI Test
    participant MD as EspressoMD
    participant Sys as espressomd.System
    Test->>MD: add_cross(center,length,thickness,start_angle,n, ...)
    MD->>MD: validate 2D, odd n, z=0, friction/type defaults
    MD->>Sys: add center real particle
    MD->>Sys: add virtual particles along orthogonal arms
    MD->>MD: register radii/types
    MD-->>Test: return center handle
```

```mermaid
sequenceDiagram
    autonumber
    actor Task as RodTorque
    participant Col as Colloid list
    participant U as ColloidUtils
    Task->>Task: initialize(colloids) // record director, history
    Task->>Col: filter rod & consumer colloids
    Task->>U: compute_torque_on_rod(rod_positions,colloid_directors,colloid_positions)
    U-->>Task: torque partition per consumer colloid
    Task->>Task: update angular-velocity history, compute scaled reward
    Task-->>Col: return per-colloid rewards
```

```mermaid
sequenceDiagram
    autonumber
    actor Trainer as EpisodicTrainer
    participant Env as Environment
    participant Net as Network
    participant FS as Filesystem
    loop episodes
        Trainer->>Env: run_episode()
        Env-->>Trainer: transitions, rewards
        Trainer->>Trainer: update running_reward (mean last 10)
        Trainer->>Net: train with aligned rewards/values
        alt save_best_network set & running_reward >= 0.99*best
            Trainer->>FS: save best_network
        end
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
swarmrl/tasks/object_movement/rod_rotation.py (1)
149-158: Normalize `decomp_fn` output before applying reward

When `partition=True`, divide the raw torque magnitudes by their sum (plus a small ε) so the partitions sum to 1:

```diff
 if self.partition:
-    colloid_partitions = self.decomp_fn(
-        colloid_positions, colloid_directors, rod_positions, rod_directors
-    )
+    weights = self.decomp_fn(
+        colloid_positions, colloid_directors, rod_positions, rod_directors
+    )
+    colloid_partitions = weights / (weights.sum() + 1e-8)
```
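The effect of the suggested normalization can be shown in isolation (a minimal plain-Python sketch; `normalize_partitions` is a hypothetical helper, with `jax.numpy` arrays behaving analogously):

```python
def normalize_partitions(raw_weights, eps=1e-8):
    """Turn raw torque magnitudes into partition weights that sum to ~1."""
    total = sum(raw_weights) + eps  # epsilon guards the all-zero case
    return [w / total for w in raw_weights]

magnitudes = [3.0, 1.0, 0.0]
weights = normalize_partitions(magnitudes)
# weights sum to ~1, so scaling a shared reward by them preserves the total
```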
🧹 Nitpick comments (10)
swarmrl/losses/proximal_policy_loss.py (1)
103-111: Clarify slicing logic with a comment

The two-stage slicing is intentional: dropping `rewards[0]` yields T rewards, and passing the original T+1 `predicted_values` supplies the extra bootstrap value GAE needs; trimming `predicted_values[:-1]` then aligns all tensors back to T. Add a brief comment around lines 103–111 to document why we drop the first reward and then the final value.

CI/espresso_tests/unit_tests/test_cross.py (1)
57-68: Reduce potential flakiness in rotation assertion

Rotation after 100 steps with random colloid placements may occasionally be small. Consider:
- Seeding explicitly (`EspressoMD(seed=...)`) and/or increasing steps.
- Asserting the angle change via `arccos(dot)` instead of a vector-norm threshold.
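An `arccos(dot)`-based check could look like this (a sketch; the helper name and tolerance are illustrative, not the test's actual code):

```python
import math

def rotation_angle(d_initial, d_final):
    """Angle in radians between two 3D unit directors, robust to rounding."""
    dot = sum(a * b for a, b in zip(d_initial, d_final))
    # Clamp to [-1, 1] so floating-point noise never escapes acos' domain.
    return math.acos(max(-1.0, min(1.0, dot)))

d0 = (1.0, 0.0, 0.0)
d1 = (math.cos(0.3), math.sin(0.3), 0.0)
angle = rotation_angle(d0, d1)  # ~0.3 rad, independent of director scale noise
```

Asserting `angle > threshold` tests the actual rotation rather than a vector-difference norm, which conflates angle and length effects.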
CI/unit_tests/tasks/test_rod_torque.py (2)
30-33: Rename test class for clarity

Use `TestRodTorque` to match the class under test.

```diff
-class TestRodRotation:
+class TestRodTorque:
```
42-73: Great coverage; consider adding an end-to-end call test

You thoroughly test internals. Add one test invoking `task(colloids)` end-to-end to assert reward sign/direction and shape.
Also applies to: 78-99, 100-118, 119-158
swarmrl/engine/espresso.py (2)
796-812: Duplicate radius-register update

Type `cross_particle_type` is updated twice; keep a single update.

```diff
     self.colloid_radius_register.update(
         {cross_particle_type: {"radius": partcl_radius, "aspect_ratio": 1.0}}
     )
-
-    director = create_orthonormal_vector(director)
+    director = create_orthonormal_vector(director)
     for k in range(n_particles - 1):
         dist_to_center = (-1) ** k * (k // 2 + 1) * point_dist
         pos_virt = center_pos + dist_to_center * director
         virtual_partcl = self.system.part.add(
             pos=pos_virt, director=director, virtual=True, type=cross_particle_type
         )
         virtual_partcl.vs_auto_relate_to(center_part)
         self.colloids.append(virtual_partcl)
-
-    self.colloid_radius_register.update(
-        {cross_particle_type: {"radius": partcl_radius, "aspect_ratio": 1.0}}
-    )
```
667-678: PEP 484: use `Optional[...]` in type hints

Defaults of None should be annotated `Optional[T]`. Also import `Optional`.

```diff
+from typing import Optional
 @@
-    def add_cross(
-        self,
-        cross_center: pint.Quantity = None,
-        cross_length: pint.Quantity = None,
-        cross_thickness: pint.Quantity = None,
-        cross_start_angle: float = None,
-        n_particles: int = None,
-        friction_trans: pint.Quantity = None,
-        friction_rot: pint.Quantity = None,
-        cross_particle_type: int = None,
-        fixed: bool = True,
-    ):
+    def add_cross(
+        self,
+        cross_center: Optional[pint.Quantity] = None,
+        cross_length: Optional[pint.Quantity] = None,
+        cross_thickness: Optional[pint.Quantity] = None,
+        cross_start_angle: Optional[float] = None,
+        n_particles: Optional[int] = None,
+        friction_trans: Optional[pint.Quantity] = None,
+        friction_rot: Optional[pint.Quantity] = None,
+        cross_particle_type: Optional[int] = None,
+        fixed: bool = True,
+    ):
```

Based on static analysis hints.
swarmrl/tasks/object_movement/rod_torque.py (2)
26-27: Type hint: use float for `angular_velocity_scale`

Docstring says float; the hint is int.

```diff
-    angular_velocity_scale: int = 100,
+    angular_velocity_scale: float = 100.0,
```
234-247: Guard against empty selections

If no rod or no target colloids exist, indexing `rod_directors[0]` or building arrays will error. Validate and early-return zeros or raise a clear ValueError.

```diff
     rod = [colloid for colloid in colloids if colloid.type == self.rod_type]
+    if not rod:
+        raise ValueError("No rod particles (type={}) present.".format(self.rod_type))
 @@
     chosen_colloids = [
         colloid for colloid in colloids if colloid.type == self.particle_type
     ]
+    if not chosen_colloids:
+        return np.zeros(0)
```

swarmrl/utils/colloid_utils.py (2)
58-66: Doc cleanup and stability note
- Referencing “line 61” in the doc is brittle; remove line numbers.
- Good call adding epsilons to avoid NaNs in grad at r≈0 and large r. Consider noting that the returned force is along the colloid director by design (non-physical), as the doc already hints.
No code change required here.
166-205: Clarify output semantics in `compute_torque_on_rod`

This returns a normalized torque vector per colloid (not a probability partition). Consider renaming to `torque_contributions` (or returning scalar weights as in `compute_torque_partition_on_rod`) for consistency.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)

- CI/espresso_tests/CTestTestfile.cmake (1 hunks)
- CI/espresso_tests/unit_tests/test_cross.py (1 hunks)
- CI/unit_tests/tasks/test_rod_torque.py (1 hunks)
- swarmrl/engine/espresso.py (1 hunks)
- swarmrl/losses/policy_gradient_loss.py (1 hunks)
- swarmrl/losses/proximal_policy_loss.py (1 hunks)
- swarmrl/tasks/object_movement/__init__.py (1 hunks)
- swarmrl/tasks/object_movement/rod_rotation.py (10 hunks)
- swarmrl/tasks/object_movement/rod_torque.py (1 hunks)
- swarmrl/trainers/episodic_trainer.py (4 hunks)
- swarmrl/utils/colloid_utils.py (3 hunks)
- swarmrl/value_functions/generalized_advantage_estimate.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
CI/unit_tests/tasks/test_rod_torque.py (2)

- swarmrl/components/colloid.py (1): Colloid (13-49)
- swarmrl/tasks/object_movement/rod_torque.py (3): RodTorque (16-249), _compute_angular_velocity (119-153), _compute_torque_on_rod (91-117)

swarmrl/trainers/episodic_trainer.py (1)

- swarmrl/trainers/trainer.py (1): export_models (102-116)

CI/espresso_tests/unit_tests/test_cross.py (3)

- swarmrl/force_functions/force_fn.py (1): ForceFunction (13-86)
- swarmrl/engine/espresso.py (2): add_colloids (459-544), add_cross (667-813)
- swarmrl/agents/dummy_models.py (1): ConstForce (9-14)

swarmrl/tasks/object_movement/rod_torque.py (4)

- swarmrl/components/colloid.py (1): Colloid (13-49)
- swarmrl/tasks/task.py (1): Task (15-118)
- swarmrl/utils/colloid_utils.py (1): compute_torque_on_rod (166-205)
- swarmrl/tasks/object_movement/rod_rotation.py (2): initialize (63-84), _compute_angular_velocity (86-118)

swarmrl/engine/espresso.py (1)

- swarmrl/utils/utils.py (1): vector_from_angles (24-27)

swarmrl/tasks/object_movement/__init__.py (2)

- swarmrl/tasks/object_movement/rod_torque.py (1): RodTorque (16-249)
- swarmrl/tasks/object_movement/rod_rotation.py (1): RotateRod (16-231)
🪛 Ruff (0.13.3)
swarmrl/trainers/episodic_trainer.py
35-35: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
swarmrl/tasks/object_movement/rod_torque.py
49-49: Avoid specifying long messages outside the exception class
(TRY003)
54-57: Avoid specifying long messages outside the exception class
(TRY003)
swarmrl/engine/espresso.py
672-672: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
673-673: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
676-676: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
725-727: Avoid specifying long messages outside the exception class
(TRY003)
729-729: Avoid specifying long messages outside the exception class
(TRY003)
731-731: Avoid specifying long messages outside the exception class
(TRY003)
734-734: Avoid specifying long messages outside the exception class
(TRY003)
736-738: Avoid specifying long messages outside the exception class
(TRY003)
740-740: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build (3.12)
🔇 Additional comments (10)
swarmrl/trainers/episodic_trainer.py (3)
128-130: Verify the 0.99 multiplier threshold.

The comment explains this helps "rule out lucky encounters" by requiring only 99% of the best reward. However, this might cause excessive saves in the early training phase when improvements are frequent. Consider whether a stricter threshold (e.g., 1.0 or 1.01) or a minimum episode count before saving would be more appropriate.
Please confirm this threshold aligns with the intended training behavior, especially during early episodes when the running reward is rapidly improving.
126-126: LGTM! The running reward calculation correctly handles cases with fewer than 10 episodes by averaging only the available rewards.
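The behavior being approved can be sketched as follows (a hypothetical standalone helper, not the trainer's exact code):

```python
def running_reward(episode_rewards, window=10):
    """Mean of the last `window` episode rewards; with fewer, mean of all."""
    tail = episode_rewards[-window:]  # slicing never over-reads a short list
    return sum(tail) / len(tail)

early = running_reward([4.0, 6.0])       # only 2 episodes -> mean of both
late = running_reward(list(range(20)))   # 20 episodes -> mean of the last 10
```

Python's negative slicing makes the under-10-episodes case fall out for free, which is why no explicit length check is needed.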
142-143: LGTM! The increased precision (4 decimals) and use of the pre-calculated `running_reward` improve the display quality.

CI/espresso_tests/CTestTestfile.cmake (1)
33-33: Test registration looks good. The cross unit test is properly wired into CTest.
swarmrl/tasks/object_movement/__init__.py (1)
6-11: Public API export LGTM. The RodTorque import and `__all__` update are correct.
CI/espresso_tests/unit_tests/test_cross.py (1)
69-74: WCA cutoff check is precise and valuable. Good validation of sigma→cutoff vs radii; this guards against regressions.
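For reference, the relation such a check validates (a sketch; the convention of taking σ as the sum of the two radii is an assumption here, not necessarily what the test uses):

```python
def wca_cutoff(sigma):
    """WCA is the LJ potential truncated at its minimum: r_c = 2**(1/6) * sigma."""
    return 2 ** (1 / 6) * sigma

# Assumed convention: sigma set from the sum of the interacting radii.
r_colloid, r_rod = 1.0, 0.5
cutoff = wca_cutoff(r_colloid + r_rod)  # slightly larger than the contact distance
```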
swarmrl/tasks/object_movement/rod_rotation.py (4)
116-118: Confirm sign/clipping semantics. Clipping negatives to zero is fine for a directional reward, given the CW/CCW sign flip. Please confirm this is intended for both directions.
58-60: LGTM: velocity history init. Zero-init and append index are clear and correct with the rolling update below.
224-225: LGTM: passing colloid_directors. Including directors aligns with the updated decomposition utilities.
105-107: Verify JAX cross and rotations-per-hour conversion

- Confirm that `jax.numpy.cross` supports 2-element vectors (shape (2,)); if not, compute the "z-cross" via the determinant (`v_prev[0]*v_new[1] - v_prev[1]*v_new[0]`).
- Ensure the hard-coded `/10` in the RPH formula matches the actual sampling rate (parameterize or document dt instead of assuming 10 Hz).
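The determinant-based fallback mentioned above can be written as (a minimal sketch; `z_cross` is a hypothetical name):

```python
def z_cross(v_prev, v_new):
    """Scalar z-component of the cross product of two 2D vectors.

    Positive for a CCW rotation from v_prev to v_new, negative for CW,
    so the sign encodes the rotation direction without any 3D padding.
    """
    return v_prev[0] * v_new[1] - v_prev[1] * v_new[0]

ccw = z_cross((1.0, 0.0), (0.0, 1.0))  # quarter turn CCW -> positive
cw = z_cross((0.0, 1.0), (1.0, 0.0))   # quarter turn CW -> negative
```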
```python
if friction_trans is None and not fixed:
    raise ValueError(
        "If you want the cross to move, you must provide a friction coefficient"
    )
if friction_rot is None:
    raise ValueError("You must provide a rotational friction coefficient")
if cross_particle_type is None:
    raise ValueError("You must provide a particle type for the cross")
```
Possible None dereference for friction_trans (fixed=True path)

When `fixed=True` and `friction_trans=None`, `fric_trans = friction_trans.m_as(...)` will raise. Either require `friction_trans` or compute a default.

Apply one of the following:

Option A (require input; minimal, consistent with tests):

```diff
-    if friction_trans is None and not fixed:
-        raise ValueError(
-            "If you want the cross to move, you must provide a friction coefficient"
-        )
+    if friction_trans is None:
+        raise ValueError(
+            "friction_trans must be provided (even if fixed=True)."
+        )
```

Option B (derive a safe default when fixed):

```diff
-    fric_trans = friction_trans.m_as("sim_force/sim_velocity")  # [F / v]
+    if friction_trans is None:
+        # Fixed center: use a tiny gamma to satisfy espresso; not used dynamically.
+        fric_trans = 1e-20
+    else:
+        fric_trans = friction_trans.m_as("sim_force/sim_velocity")  # [F / v]
```

Also applies to: 754-758
🤖 Prompt for AI Agents
In swarmrl/engine/espresso.py around lines 724-731 (and similarly update
754-758), the code permits friction_trans to be None when fixed=True but later
dereferences it causing a possible None dereference; enforce that friction_trans
is provided always: change the validation to raise ValueError when
friction_trans is None regardless of fixed (i.e., remove the exception
only-if-not-fixed branch) and keep the existing message, and ensure any later
use sites assume a non-None friction_trans (or add explicit casts/guards) so
downstream fric_trans = friction_trans.m_as(...) is safe; apply the same
validation/update at the 754-758 block.
```python
# Restructure the data to shift the rewards to after the action
# is taken.
feature_data = jnp.array(episode_data.features)[:-1]
action_data = jnp.array(episode_data.actions)[:-1]
reward_data = jnp.array(episode_data.rewards)[1:]
```
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Reward-action alignment logic looks correct.
The alignment ensures that reward_data[i] corresponds to the reward received after taking action_data[i] from feature_data[i], which is the standard reinforcement learning timing convention. The explanatory comments are helpful.
Verify that the alignment is covered by tests:
🏁 Script executed:

```shell
#!/bin/bash
# Description: Check for tests covering the reward-action alignment logic
# Search for test files that test PolicyGradientLoss
rg -n --type=py -C3 'PolicyGradientLoss' | rg -C3 'test_|Test'
# Search for tests that validate compute_loss with episode data
rg -n --type=py -C5 'compute_loss.*episode_data'
```

Length of output: 3438
🏁 Script executed:

```shell
#!/bin/bash
# Inspect the PolicyGradientLoss tests for compute_loss alignment checks
sed -n '1,200p' CI/unit_tests/losses/_test_policy_gradient_loss.py
```

Length of output: 3689
Cover compute_loss alignment with tests.
No existing tests invoke PolicyGradientLoss.compute_loss, so the trimming and shifting logic (features[:-1], actions[:-1], rewards[1:]) isn’t validated. Add a unit test that supplies a simple episode_data, calls compute_loss, and asserts that features, actions, and rewards are correctly aligned.
🤖 Prompt for AI Agents
In swarmrl/losses/policy_gradient_loss.py around lines 121 to 125, the code
trims features and actions with [:-1] and shifts rewards with [1:] but there are
no tests validating this alignment; add a unit test that constructs a minimal
EpisodeData with known feature, action, and reward sequences (e.g., length N),
calls PolicyGradientLoss.compute_loss, and asserts that inside the method the
arrays are aligned as expected (features == original_features[:-1], actions ==
original_actions[:-1], rewards == original_rewards[1:]); ensure the test checks
lengths and element-wise values so future refactors won't break the
trimming/shift logic.
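A minimal test along these lines could pin the convention down with toy data (a sketch, not tied to the real EpisodeData class):

```python
# Toy episode of length 4: reward r[t+1] is observed after taking a[t] in s[t].
features = [0, 1, 2, 3]
actions = [10, 11, 12, 13]
rewards = [100, 101, 102, 103]

# The trimming/shift under review:
feature_data = features[:-1]
action_data = actions[:-1]
reward_data = rewards[1:]

# Each transition now lines up as (s_t, a_t, r_{t+1}).
assert len(feature_data) == len(action_data) == len(reward_data) == 3
assert reward_data[0] == 101  # the reward that followed the first action
```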
```python
self._velocity_history = np.zeros(self.velocity_history)
self._append_index = int(self.velocity_history - 1)
for item in colloids:
```
Guard against missing rod in initialize()
If no colloid with type == rod_type is present, _historic_rod_director remains unset and later calls will fail. Add a check and raise with a clear message.
Example:
```diff
 for item in colloids:
     if item.type == self.rod_type:
         self._historic_rod_director = onp.copy(item.director)
         break
+else:
+    raise ValueError("RotateRod.initialize: no rod particles found (type "
+                     f"{self.rod_type}).")
```
+ f"{self.rod_type}).")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
self._velocity_history = np.zeros(self.velocity_history)
self._append_index = int(self.velocity_history - 1)
for item in colloids:
    if item.type == self.rod_type:
        self._historic_rod_director = onp.copy(item.director)
        break
else:
    raise ValueError("RotateRod.initialize: no rod particles found (type "
                     f"{self.rod_type}).")
```
```python
if self.angular_velocity_scale > 0.0:
    torques_in_direction = colloid_torques_on_rod.at[
        colloid_torques_on_rod > 0.0
    ].set(0.0)
else:
    torques_in_direction = colloid_torques_on_rod.at[
        colloid_torques_on_rod < 0.0
    ].set(0.0)

return (
    torques_in_direction * -1
)  # Sign of the torques has to be inverted to avoid negative rewards
```
Reward sign flips for CW; make rewards positive in the commanded direction
With direction="CW", torques become negative, velocity is clipped negative, and scale is negative ⇒ product is negative. Ensure positive rewards for correct-direction torque.
Apply this refactor:
```diff
-    if self.angular_velocity_scale > 0.0:
-        torques_in_direction = colloid_torques_on_rod.at[
-            colloid_torques_on_rod > 0.0
-        ].set(0.0)
-    else:
-        torques_in_direction = colloid_torques_on_rod.at[
-            colloid_torques_on_rod < 0.0
-        ].set(0.0)
-
-    return (
-        torques_in_direction * -1
-    )  # Sign of the torques has to be inverted to avoid negativ rewards
+    # Keep only torques in the commanded direction and return positive magnitudes.
+    if self.angular_velocity_scale > 0.0:
+        # CCW: keep negative torques, flip sign to positive magnitude
+        torques_in_direction = np.where(
+            colloid_torques_on_rod < 0.0, -colloid_torques_on_rod, 0.0
+        )
+    else:
+        # CW: keep positive torques, already positive magnitude
+        torques_in_direction = np.where(
+            colloid_torques_on_rod > 0.0, colloid_torques_on_rod, 0.0
+        )
+    return torques_in_direction
 @@
-    return torques * velocity * self.angular_velocity_scale
+    # Use magnitudes for velocity and scale so rewards are positive for the
+    # commanded direction only (torques already filtered to correct direction).
+    return torques * np.abs(velocity) * np.abs(self.angular_velocity_scale)
```

Also applies to: 218-219
```python
reset_frequency: int = 1,
load_bar: bool = True,
save_episodic_data: bool = True,
save_best_network: str = None,
```
Fix type annotation per PEP 484.

The parameter should use explicit `Optional[str]` instead of implicit `str = None`.

Apply this diff to fix the type annotation:

```diff
-    save_best_network: str = None,
+    save_best_network: Optional[str] = None,
```

Also add the import at the top of the file if not already present:

```python
from typing import Optional
```
from typing import Optional🧰 Tools
🪛 Ruff (0.13.3)
35-35: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
🤖 Prompt for AI Agents
In swarmrl/trainers/episodic_trainer.py around line 35, the parameter annotation
"save_best_network: str = None" violates PEP 484; change it to use Optional by
updating the signature to "save_best_network: Optional[str] = None" and add
"from typing import Optional" to the top of the file if not already imported;
ensure imports are sorted/placed with other typing imports and run type-checking
to confirm no other annotations rely on implicit None typings.
```python
best_running_reward = 0.0
running_reward = 0.0
```
Best reward logic fails with negative rewards.

Initializing `best_running_reward` to 0.0 creates problematic edge cases:

- If rewards are always negative (e.g., penalty-based tasks), no network is ever saved, since negative values will never exceed `0.0 * 0.99`.
- Conversely, the first episode with any positive `running_reward` triggers a save, even if it's a poor result (e.g., 0.001 > 0.0).

Additionally, the path concatenation using `str(save_best_network) + "/best_network"` is not cross-platform safe.

Apply this diff to initialize with negative infinity and use proper path handling:

```diff
-    best_running_reward = 0.0
+    best_running_reward = float('-inf')
     running_reward = 0.0
```

```diff
+import os
 @@
         best_running_reward = running_reward
         self.export_models(
-            directory=str(save_best_network) + "/best_network"
+            directory=os.path.join(str(save_best_network), "best_network")
         )
```

Alternatively, use pathlib.Path for modern path handling:

```python
from pathlib import Path
# ...
directory = str(Path(save_best_network) / "best_network")
```

Also applies to: 128-135
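The combined fix behaves as intended even for negative reward streams (a sketch with hypothetical helper and paths, not the trainer's actual code):

```python
import os

def should_save(running_reward, best_so_far, tolerance=0.99):
    """True when the running reward is within `tolerance` of the best seen."""
    return running_reward >= tolerance * best_so_far

best = float("-inf")  # -inf seed: a purely negative stream still triggers a save
if should_save(-5.0, best):
    best = -5.0
    target = os.path.join("models", "best_network")  # portable path join
```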
```diff
 def compute_torque_partition_on_rod(
     colloid_positions, colloid_directors, rod_positions, rod_directions
 ):
     """
-    Compute the torque on a rod using a WCA potential.
+    Compute the torque partition on a rod using a WCA potential.

     Parameters
     ----------
     colloid_positions : jnp.ndarray (n_colloids, 3)
         Positions of the colloids.
     colloid_directors : jnp.ndarray (n_colloids, 3)
         Directors of the colloids.
     rod_positions : jnp.ndarray (rod_particles, 3)
         Positions of the rod particles.
     rod_directions : jnp.ndarray (rod_particles, 3)
         Directors of the rod particles.
     """
     # (n_colloids, rod_particles, 3)
     distance_matrix = compute_distance_matrix(colloid_positions, rod_positions)
-    distance_matrix = distance_matrix[:, :, :2]
+    # distance_matrix = distance_matrix[:, :, :2]

     # Force on the rod
-    rod_map_fn = jax.vmap(compute_forces, in_axes=(0,))  # map over rod particles
-    colloid_map_fn = jax.vmap(rod_map_fn, in_axes=(0,))  # map over colloids
+    rod_map_fn = jax.vmap(compute_forces, in_axes=(0, None))  # map over rod particles
+    colloid_map_fn = jax.vmap(rod_map_fn, in_axes=(0, 0))  # map over colloids

     # (n_colloids, rod_particles, 3)
-    forces = colloid_map_fn(distance_matrix)
+    forces = colloid_map_fn(distance_matrix, colloid_directors)

     # Compute torques
     colloid_rod_map = jax.vmap(compute_torque, in_axes=(0, 0))
     colloid_only_map = jax.vmap(colloid_rod_map, in_axes=(0, None))

     torques = colloid_only_map(forces, rod_directions)
+    net_rod_torque = torques.sum(axis=1)
+    torque_magnitude = jnp.linalg.norm(net_rod_torque, axis=-1) + 1e-8
+
+    return torque_magnitude
```
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Return normalized weights from compute_torque_partition_on_rod
The function returns raw torque magnitudes per colloid; downstream code treats them as partitions. Normalize to keep total reward unchanged.
```diff
-    net_rod_torque = torques.sum(axis=1)
-    torque_magnitude = jnp.linalg.norm(net_rod_torque, axis=-1) + 1e-8
-
-    return torque_magnitude
+    net_rod_torque = torques.sum(axis=1)
+    torque_magnitude = jnp.linalg.norm(net_rod_torque, axis=-1) + 1e-8
+    weights = torque_magnitude / (torque_magnitude.sum() + 1e-8)
+    return weights
```
+ return weightsOptional: For physical torque, use lever arms (r vectors) instead of rod_directions, aligning with compute_torque_on_rod.
🏁 Script executed:

```shell
#!/bin/bash
rg -n 'compute_torque_partition_on_rod' -C3
```

Length of output: 1494

🏁 Script executed:

```shell
#!/bin/bash
rg -n 'decomp_fn\(' -C3 swarmrl/tasks/object_movement/rod_rotation.py
```

Length of output: 279

🏁 Script executed:

```shell
#!/bin/bash
rg -n 'colloid_partitions' -C5 swarmrl/tasks/object_movement/rod_rotation.py
```

Length of output: 699
Normalize torque magnitudes to weights summing to 1
compute_torque_partition_on_rod returns raw torque magnitudes, but downstream code multiplies reward by these as partitions—this changes the total reward. Normalize before returning:
```diff
 net_rod_torque = torques.sum(axis=1)
 torque_magnitude = jnp.linalg.norm(net_rod_torque, axis=-1) + 1e-8
-return torque_magnitude
+weights = torque_magnitude / (torque_magnitude.sum() + 1e-8)
+return weights
```
+ return weights🤖 Prompt for AI Agents
In swarmrl/utils/colloid_utils.py around lines 106 to 143,
compute_torque_partition_on_rod currently returns raw torque magnitudes which
are used as partitions downstream; change it to return normalized weights that
sum to 1. After computing torque_magnitude, compute total =
torque_magnitude.sum(axis=0) (or sum over colloids as appropriate), guard
against zero by adding a small epsilon, then divide torque_magnitude by total to
produce normalized_weights with the same shape; return normalized_weights
instead of raw magnitudes.
```python
def compute_rod_particle_distances(rod_positions):
    """
    Compute the vectors between the middle of the rod to each colloid.

    Parameters
    ----------
    rod_positions : Positions of all the rod colloids.
    """

    def _sub_compute(a, b):
        return b - a

    distance_fn = jax.vmap(_sub_compute, in_axes=(0, None))

    return distance_fn(
        rod_positions, rod_positions[0]
    )  # rod_positions[0] is the middle of the rod
```
Fix lever-arm direction or docstring in compute_rod_particle_distances
The function returns center − position (vectors pointing toward the center), while the doc suggests “middle to each colloid.” Use position − center for a standard lever arm.
```diff
     def _sub_compute(a, b):
-        return b - a
+        return a - b
```

And keep:

```python
return distance_fn(rod_positions, rod_positions[0])
```

This yields vectors from the rod center to each particle.
📝 Committable suggestion (review carefully before committing; ensure it accurately replaces the highlighted code and test thoroughly):

```python
def compute_rod_particle_distances(rod_positions):
    """
    Compute the vectors between the middle of the rod to each colloid.

    Parameters
    ----------
    rod_positions : Positions of all the rod colloids.
    """

    def _sub_compute(a, b):
        return a - b

    distance_fn = jax.vmap(_sub_compute, in_axes=(0, None))

    return distance_fn(
        rod_positions, rod_positions[0]
    )  # rod_positions[0] is the middle of the rod
```
🤖 Prompt for AI Agents
In swarmrl/utils/colloid_utils.py around lines 146 to 163, the lever-arm sign is
reversed: _sub_compute currently returns b - a (center − position) while the
docstring says "middle to each colloid." Change the subtraction to return a - b
so each vector is position − center (i.e., from the rod center to each
particle), and update the docstring line to explicitly state "This yields
vectors from the rod center to each particle." Ensure the jax.vmap in_axes usage
remains (0, None) and no other logic changes.
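Why the subtraction order matters is visible from the torque sign (a standalone sketch, independent of the SwarmRL helpers):

```python
def cross3(a, b):
    """3D cross product, written out explicitly."""
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )

rod_center = (0.0, 0.0, 0.0)
particle_pos = (2.0, 0.0, 0.0)

# Lever arm convention: vector FROM the center TO the particle (position - center).
lever_arm = tuple(p - c for p, c in zip(particle_pos, rod_center))
force = (0.0, 1.0, 0.0)            # push in +y at the rod tip

torque = cross3(lever_arm, force)  # tau = r x F, +z component means CCW
# Flipping to center - position flips the torque sign, i.e. the rotation direction.
flipped = cross3(tuple(-r for r in lever_arm), force)
```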
```diff
     delta = rewards[t] + self.gamma * values[t + 1] - values[t]

     gae = delta + self.gamma * self.lambda_ * gae
     advantages = advantages.at[t].set(gae)

-returns = advantages + values
+returns = advantages + values[:-1]
```
🧩 Analysis chain
OOB risk and docstring mismatch: values must be length T+1 (bootstrap) for delta/returns to work
- values[t+1] at the last step will IndexError if values has length T.
- returns = advantages + values[:-1] also assumes length T+1.
- Docstring still says values shape is (T, N).
Add a shape guard and update docs to require values.shape[0] == rewards.shape[0] + 1. Minimal fix below.
Apply this diff within the function to guard shapes:
```diff
 @@
     gae = 0
     advantages = np.zeros_like(rewards)
+    # Require bootstrap value at t = T (values length = T+1).
+    assert values.shape[0] == rewards.shape[0] + 1 and values.shape[1:] == rewards.shape[1:], (
+        f"GAE expects values shape (T+1, ...). Got rewards {rewards.shape}, values {values.shape}."
+    )
     for t in reversed(range(len(rewards))):
         delta = rewards[t] + self.gamma * values[t + 1] - values[t]
 @@
     returns = advantages + values[:-1]
```
```diff
-    values : np.ndarray (n_time_steps, n_particles)
-        The prediction of the critic for the episode.
+    values : np.ndarray (n_time_steps + 1, n_particles)
+        Critic predictions with a bootstrap value for the step after the last reward.
+        I.e., if rewards has length T, values must have length T+1.
```
- gae = 0
+ gae = np.zeros_like(rewards[0])Verification suggestion:
- Ensure all GAE call sites pass values with length T+1. If not, either pad with the bootstrap value (e.g., v_T = 0 for terminal) or reintroduce last-step handling.
🏁 Script executed:

```shell
#!/bin/bash
FILE=$(fd "generalized_advantage_estimate.py" -t f | head -n 1)
echo "Inspecting: $FILE"
sed -n '1,200p' "$FILE"
```

Length of output: 2095
Add shape check and update docstring for bootstrap value

- In `__call__` (around lines 59–65), add

```python
assert values.shape[0] == rewards.shape[0] + 1, (
    f"GAE expects values.shape[0] == rewards.shape[0] + 1, got {values.shape}"
)
```

to prevent out-of-bounds access at `values[t + 1]`.

- Update the docstring to

```python
values : np.ndarray (n_time_steps + 1, n_particles)
    Critic predictions with a bootstrap value for the step after the last reward.
```
🤖 Prompt for AI Agents
In swarmrl/value_functions/generalized_advantage_estimate.py around lines 59 to
65, add a shape assertion to prevent out-of-bounds access when referencing
values[t+1] by asserting values.shape[0] == rewards.shape[0] + 1 with a helpful
error message, and update the function docstring for the values parameter to
read: "values : np.ndarray (n_time_steps + 1, n_particles) Critic
predictions with a bootstrap value for the step after the last reward." This
ensures callers supply a bootstrap value and provides clear documentation of
expected shapes.
This includes all the code I wrote in my own fork of SwarmRL. Note the branch used: samtov_reward_shift.
Changes to episodic_trainer.py allow saving the network with the highest reward over multiple episodes. This is probably obsolete, because of Jannik's checkpointing.
After the merge, some stuff in rod_rotation.py and colloid_utils.py can probably be deleted, because of similarities to my code.
Summary by CodeRabbit
New Features
Refactor
Tests