Skip to content

Commit 89d0d82

Browse files
Merge pull request #401 from egraphs-good/fix-cost-function
Fix lookup of value in cost model
2 parents a19aa86 + d38a170 commit 89d0d82

3 files changed

Lines changed: 37 additions & 0 deletions

File tree

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ _This project uses semantic versioning_
1818
- Improve doctest support, teaching expressions about their `__module__`, `__dir__`, and special methods.
1919
- Surface original Python exceptions from the runtime and tighten pretty-printing of values that cannot be re-parsed to make debugging e-graph executions easier.
2020
- Update the bundled Egglog crate, visualizer, and related dev dependencies (including `ipykernel`) to pick up the latest backend fixes.
21+
- Fix lookup of cost model based on value (see [zulip for issue](https://egraphs.zulipchat.com/#narrow/channel/375765-egg.2Fegglog/topic/Cost.20function.3A.20using.20function.20values.20of.20subtrees/near/577062352))
2122

2223
## 11.4.0 (2025-10-02)
2324

python/egglog/egraph_state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,8 @@ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
683683
assert_never(ref)
684684

685685
def typed_expr_to_value(self, typed_expr: TypedExprDecl) -> bindings.Value:
686+
if isinstance(typed_expr.expr, ValueDecl):
687+
return typed_expr.expr.value
686688
egg_expr = self.typed_expr_to_egg(typed_expr, False)
687689
return self.egraph.eval_expr(egg_expr)[1]
688690

python/tests/test_high_level.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,3 +1568,37 @@ def __radd__(self, other: object) -> tuple[X, X]: ...
15681568

15691569
assert X(1) + 10 == (X(1), 10)
15701570
assert 10 + X(1) == (X(10), X(1))
1571+
1572+
1573+
def test_custom_cost_model_size():
1574+
"""
1575+
https://egraphs.zulipchat.com/#narrow/channel/375765-egg.2Fegglog/topic/Cost.20function.3A.20using.20function.20values.20of.20subtrees/near/577062352
1576+
"""
1577+
1578+
class KAT(Expr):
1579+
@classmethod
1580+
def eps(cls) -> KAT: ...
1581+
1582+
@classmethod
1583+
def emp(cls) -> KAT: ...
1584+
1585+
def func(self, other: KAT) -> KAT: ...
1586+
1587+
def size(self) -> i64: ...
1588+
1589+
eps, emp = KAT.eps(), KAT.emp()
1590+
1591+
eg = EGraph()
1592+
q0 = eg.let("q0", KAT.func(eps, emp))
1593+
1594+
eg.register(set_(eps.size()).to(i64(1)))
1595+
eg.register(set_(emp.size()).to(i64(0)))
1596+
1597+
def conv_cost(eg, expr, child_costs):
1598+
if isinstance(expr, KAT):
1599+
args = get_callable_args(expr)
1600+
return sum(int(eg.lookup_function_value(cast("KAT", a).size())) for a in args)
1601+
1602+
return 2
1603+
1604+
assert eg.extract(q0, include_cost=True, cost_model=conv_cost) == (KAT.eps().func(KAT.emp()), 1)

0 commit comments

Comments
 (0)