Conversation
…d loop to prevent recursion
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | -0.07 % | 4.183e+03 | 4.180e+03 | -2.84 | 32.74 | 38.35 |
test_proximal_jac_w7x_with_eq_update | -1.33 % | 6.653e+03 | 6.564e+03 | -88.46 | 145.76 | 156.51 |
test_proximal_freeb_jac | -0.46 % | 1.341e+04 | 1.335e+04 | -61.15 | 88.44 | 94.44 |
test_proximal_freeb_jac_blocked | 0.03 % | 7.694e+03 | 7.696e+03 | 2.44 | 76.49 | 83.01 |
test_proximal_freeb_jac_batched | 0.16 % | 7.688e+03 | 7.700e+03 | 12.33 | 75.89 | 83.32 |
test_proximal_jac_ripple | -3.83 % | 3.685e+03 | 3.544e+03 | -141.26 | 53.14 | 60.26 |
test_proximal_jac_ripple_bounce1d | -1.18 % | 3.769e+03 | 3.725e+03 | -44.35 | 66.59 | 73.81 |
test_eq_solve | -2.47 % | 2.245e+03 | 2.189e+03 | -55.51 | 82.68 | 89.83 |For the memory plots, go to the summary of |
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_lowres | -1.40 +/- 2.90 | -1.23e-02 +/- 2.54e-02 | 8.64e-01 +/- 1.6e-02 | 8.76e-01 +/- 2.0e-02 |
test_equilibrium_init_medres | -2.07 +/- 3.30 | -1.45e-01 +/- 2.32e-01 | 6.87e+00 +/- 1.9e-01 | 7.01e+00 +/- 1.3e-01 |
test_equilibrium_init_highres | -0.02 +/- 2.91 | -1.59e-03 +/- 2.23e-01 | 7.67e+00 +/- 1.2e-01 | 7.68e+00 +/- 1.9e-01 |
test_objective_compile_dshape_current | -0.88 +/- 1.95 | -3.86e-02 +/- 8.55e-02 | 4.34e+00 +/- 4.7e-02 | 4.38e+00 +/- 7.1e-02 |
test_objective_compute_dshape_current | -0.41 +/- 11.36 | -2.96e-06 +/- 8.15e-05 | 7.14e-04 +/- 5.1e-05 | 7.17e-04 +/- 6.3e-05 |
test_objective_jac_dshape_current | +4.72 +/- 27.27 | +1.15e-03 +/- 6.65e-03 | 2.56e-02 +/- 4.9e-03 | 2.44e-02 +/- 4.5e-03 |
test_perturb_2 | -0.10 +/- 1.90 | -2.09e-02 +/- 3.90e-01 | 2.05e+01 +/- 3.8e-01 | 2.06e+01 +/- 8.4e-02 |
test_proximal_jac_atf_with_eq_update | -0.50 +/- 1.69 | -6.06e-02 +/- 2.05e-01 | 1.21e+01 +/- 7.9e-02 | 1.21e+01 +/- 1.9e-01 |
test_proximal_freeb_jac | +0.51 +/- 1.60 | +2.42e-02 +/- 7.55e-02 | 4.75e+00 +/- 2.4e-02 | 4.73e+00 +/- 7.2e-02 |
test_solve_fixed_iter_compiled | +0.39 +/- 2.44 | +3.15e-02 +/- 1.99e-01 | 8.17e+00 +/- 1.1e-01 | 8.14e+00 +/- 1.6e-01 |
test_LinearConstraintProjection_build | +1.70 +/- 3.51 | +1.59e-01 +/- 3.28e-01 | 9.49e+00 +/- 2.8e-01 | 9.33e+00 +/- 1.7e-01 |
test_objective_compute_ripple_bounce1d | +0.34 +/- 4.88 | +1.01e-03 +/- 1.44e-02 | 2.97e-01 +/- 1.3e-02 | 2.96e-01 +/- 5.5e-03 |
test_objective_grad_ripple_bounce1d | -0.23 +/- 1.91 | -2.18e-03 +/- 1.77e-02 | 9.25e-01 +/- 1.2e-02 | 9.27e-01 +/- 1.3e-02 |
test_build_transform_fft_midres | -0.70 +/- 1.74 | -6.25e-03 +/- 1.55e-02 | 8.85e-01 +/- 8.8e-03 | 8.91e-01 +/- 1.3e-02 |
test_build_transform_fft_highres | +0.10 +/- 3.20 | +1.19e-03 +/- 3.78e-02 | 1.18e+00 +/- 3.3e-02 | 1.18e+00 +/- 1.8e-02 |
test_equilibrium_init_lowres | -0.76 +/- 3.36 | -4.96e-02 +/- 2.19e-01 | 6.46e+00 +/- 1.6e-01 | 6.50e+00 +/- 1.5e-01 |
test_objective_compile_atf | +0.96 +/- 3.82 | +6.04e-02 +/- 2.40e-01 | 6.35e+00 +/- 2.0e-01 | 6.29e+00 +/- 1.4e-01 |
test_objective_compute_atf | +4.86 +/- 13.44 | +1.00e-04 +/- 2.76e-04 | 2.16e-03 +/- 2.1e-04 | 2.06e-03 +/- 1.8e-04 |
test_objective_jac_atf | -0.56 +/- 3.75 | -8.83e-03 +/- 5.89e-02 | 1.56e+00 +/- 3.9e-02 | 1.57e+00 +/- 4.4e-02 |
test_perturb_1 | -0.47 +/- 1.55 | -7.71e-02 +/- 2.56e-01 | 1.64e+01 +/- 1.9e-01 | 1.65e+01 +/- 1.7e-01 |
test_proximal_jac_atf | -0.76 +/- 1.53 | -3.98e-02 +/- 8.06e-02 | 5.22e+00 +/- 5.7e-02 | 5.26e+00 +/- 5.7e-02 |
test_proximal_freeb_compute | +2.09 +/- 2.23 | +3.45e-03 +/- 3.68e-03 | 1.69e-01 +/- 2.5e-03 | 1.65e-01 +/- 2.7e-03 |
test_solve_fixed_iter | -0.12 +/- 1.98 | -3.67e-02 +/- 5.82e-01 | 2.94e+01 +/- 4.0e-01 | 2.95e+01 +/- 4.2e-01 |
test_objective_compute_ripple | +1.23 +/- 4.51 | +2.81e-03 +/- 1.03e-02 | 2.31e-01 +/- 4.3e-03 | 2.29e-01 +/- 9.4e-03 |
test_objective_grad_ripple | +0.90 +/- 2.17 | +7.69e-03 +/- 1.86e-02 | 8.64e-01 +/- 1.7e-02 | 8.57e-01 +/- 8.2e-03 |Github CI performance can be noisy. When evaluating the benchmarks, developers should take this into account. |
…mplify desc.backend jax tests
unalmis
left a comment
There was a problem hiding this comment.
my suggestion on the issue avoids an eigenvalue solve on import. current changes force an eigenvalue solve every time an object is made, which is worse for optimization and debugging with jit in general. i didn't check the claude stuff
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2156 +/- ##
==========================================
- Coverage 94.45% 94.39% -0.06%
==========================================
Files 101 101
Lines 28604 28666 +62
==========================================
+ Hits 27018 27060 +42
- Misses 1586 1606 +20
🚀 New features to boost your workflow:
|
| # --- Step 1: Topological sort via iterative Depth-First Search --- | ||
| # We need to process quantities with no dependencies before the | ||
| # quantities that depend on them. This way, when we process key K, | ||
| # all of K's dependencies already have their full_dependencies and | ||
| # full_with_axis_dependencies cached, and we can build K's full | ||
| # dependency set with a simple set union instead of deep recursion. |
There was a problem hiding this comment.
This sounds like it's basically doing what set_tier is doing below, so ideally we could do them both at the same time.
There was a problem hiding this comment.
Yeah, I was actually gonna suggest that in looped compute PR. The order obtained at the end can be used directly. That is why I have special sort instead of normal sort.
_build_data_indexCo-authored-by: Kaya Unalmis <[email protected]>
|
ididn't check calude stuff |
|
I will merge this after the dev meeting, in case @f0uriest has some additional comments |
|
As a final check, I used this. Save both import pickle
data_index_new = data_index.copy()
for p in data_index_new:
for key in data_index_new[p]:
data_index_new[p][key].pop("fun")
output = open("data_index_pr.pkl", "wb")
# Pickle dictionary using protocol 0.
pickle.dump(data_index, output)
output.close()Then compare them. pkl1 = open("data_index_master.pkl", "rb")
pkl2 = open("data_index_pr.pkl", "rb")
data1 = pickle.load(pkl1)
data2 = pickle.load(pkl2)
def norm(full):
# sort alphabetically (since PR sorts in topo order)
return {
"data": sorted(full["data"]),
"params": sorted(full["params"]),
"profiles": sorted(full["profiles"]),
"transforms": {k: sorted(v) for k, v in full["transforms"].items()},
}
for p in data1:
if p not in data2:
print(f"{p} not in pr data_index")
continue
for k in data1[p]:
if k not in data2[p]:
print(f"{p}-{k} not in pr data_index")
continue
fd1 = norm(data1[p][k]["full_dependencies"])
fd2 = norm(data2[p][k]["full_dependencies"])
assert fd1 == fd2, f"full_dependencies differ for {p}-{k}"
fa1 = norm(data1[p][k]["full_with_axis_dependencies"])
fa2 = norm(data2[p][k]["full_with_axis_dependencies"])
assert fa1 == fa2, f"full_with_axis_dependencies differ for {p}-{k}"
for p in data2:
if p not in data1:
print(f"{p} not in master data_index")
continue
for k in data2[p]:
if k not in data1[p]:
print(f"{p}-{k} not in master data_index")
continue
fd1 = norm(data1[p][k]["full_dependencies"])
fd2 = norm(data2[p][k]["full_dependencies"])
assert fd1 == fd2, f"full_dependencies differ for {p}-{k}"
fa1 = norm(data1[p][k]["full_with_axis_dependencies"])
fa2 = norm(data2[p][k]["full_with_axis_dependencies"])
assert fa1 == fa2, f"full_with_axis_dependencies differ for {p}-{k}"
pkl1.close()
pkl2.close()And this passes. |
Rory said he is fine, so I am merging now. |
desc.computedefault_quadcomputation at import time.Importing
desc.computetakes a very long time, and on my system, only_build_data_indexitself takes 3.5 seconds. This change makes it run in 0.15 seconds!The previous function was slow because of recursion, but it also had an extreme amount of redundant calls. For example,
get_params,get_profiles, andget_derivscallget_depsagain internally, and since the result of the first call is not stored yet, it rebuilds the whole tree again, almost 8 more times.New
_build_data_indexwas originally written by Claude Code (Opus 4.6), but I wrote most of the comments and I understand the algorithm behind.Resolves #2154