Commit f5d993b

Add freeze!
1 parent 3a9b8d9 commit f5d993b

14 files changed

Lines changed: 563 additions & 136 deletions

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 2 additions & 1 deletion
@@ -10,9 +10,10 @@ name = "egglog"
1010
crate-type = ["cdylib"]
1111

1212
[dependencies]
13-
pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational"] }
13+
pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational", "indexmap"] }
1414
num-bigint = "*"
1515
num-rational = "*"
16+
indexmap = "2.12"
1617
# egglog = { path = "../egg-smol", default-features = false }
1718
# egglog-bridge = { path = "../egg-smol/egglog-bridge" }
1819
# egglog-core-relations = { path = "../egg-smol/core-relations" }

docs/how-to-guides.md

Lines changed: 74 additions & 1 deletion
@@ -6,7 +6,9 @@ file_format: mystnb
66

77
## Parsing and running program strings
88

9-
You can provide your program in a special DSL language. You can parse this with {meth}`egglog.bindings.EGraph.parse_program` and then run the result with You can parse this with {meth}`egglog.bindings.EGraph.run_program`::
9+
You can provide your program in a special DSL language. Parse it with
10+
{meth}`egglog.bindings.EGraph.parse_program` and run the resulting commands with
11+
{meth}`egglog.bindings.EGraph.run_program`:
1012

1113
```{code-cell}
1214
from egglog.bindings import EGraph
@@ -19,3 +21,74 @@ commands
1921
```{code-cell}
2022
egraph.run_program(*commands)
2123
```
24+
25+
## Debugging a high-level e-graph
26+
27+
When a rule does not fire or an equality appears unexpectedly, the fastest tools
28+
to reach for are:
29+
30+
- {meth}`egglog.egraph.EGraph.run` for per-run match counts and timings.
31+
- {meth}`egglog.egraph.EGraph.function_values` for inspecting function tables.
32+
- {meth}`egglog.egraph.EGraph.display` for visualizing the current e-graph.
33+
- {meth}`egglog.egraph.EGraph.saturate` for stepping to a fixpoint while
34+
printing extracted expressions.
35+
36+
```{code-cell}
37+
from __future__ import annotations
38+
39+
from egglog import *
40+
41+
42+
class Math(Expr):
43+
def __init__(self, value: i64Like) -> None: ...
44+
45+
def __add__(self, other: Math) -> Math: ...
46+
47+
48+
@function
49+
def score(x: Math) -> i64: ...
50+
51+
52+
debug_rules = ruleset()
53+
54+
55+
@debug_rules.register
56+
def _(i: i64, j: i64):
57+
yield rewrite(Math(i) + Math(j)).to(Math(i + j))
58+
```
59+
60+
Start with a normal run report when you want to see which rules matched:
61+
62+
```{code-cell}
63+
egraph = EGraph()
64+
expr = egraph.let("expr", Math(2) + Math(3))
65+
egraph.register(set_(score(expr)).to(5))
66+
67+
report = egraph.run(debug_rules)
68+
report.num_matches_per_rule
69+
```
70+
71+
Use `function_values(...)` to inspect the rows currently stored for a function:
72+
73+
```{code-cell}
74+
egraph.function_values(score)
75+
```
76+
77+
Use `display(...)` to look at the current e-graph structure:
78+
79+
```{code-cell}
80+
egraph.display(graphviz=True)
81+
```
82+
83+
Use `saturate(...)` when you want to keep running until nothing changes while
84+
printing the extracted form of an expression after each iteration:
85+
86+
```{code-cell}
87+
egraph = EGraph()
88+
expr = egraph.let("expr", Math(2) + Math(3))
89+
egraph.saturate(debug_rules, expr=expr, max=2, visualize=False)
90+
```
91+
92+
For a structural snapshot that can be turned back into replayable high-level
93+
actions, see the `freeze()` section in
94+
{doc}`reference/python-integration`.

docs/reference/python-integration.md

Lines changed: 30 additions & 3 deletions
@@ -660,14 +660,16 @@ The default renderer for the e-graph in a Jupyter Notebook [an interactive Javas
660660
egraph
661661
```
662662

663-
You can also customize the visualization through using the <inv:egglog.EGraph.display> method:
663+
You can also customize the visualization with
664+
{meth}`egglog.egraph.EGraph.display`:
664665

665666
```{code-cell} python
666667
egraph.display()
667668
```
668669

669-
If you would like to visualize the progression of the e-graph over time, you can use the <inv:egglog.EGraph.saturate> method to
670-
run a number of iterations and then visualize the e-graph at each step:
670+
If you would like to visualize the progression of the e-graph over time, you can
671+
use {meth}`egglog.egraph.EGraph.saturate` to run a number of iterations and then
672+
visualize the e-graph at each step:
671673

672674
```{code-cell} python
673675
egraph = EGraph()
@@ -710,6 +712,31 @@ stats = egraph.stats()
710712
stats.num_matches_per_rule
711713
```
712714

715+
### Freeze the e-graph
716+
717+
For a replayable, high-level snapshot of the current e-graph, use
718+
{meth}`egglog.egraph.EGraph.freeze`.
719+
720+
Unlike the lower-level serializer, `freeze()` reconstructs the current e-graph as
721+
high-level Python actions, so it is convenient for debugging and for writing
722+
regression tests around unexpected unions, sets, costs, or subsumptions.
723+
724+
```{code-cell} python
725+
class DebugMath(Expr):
726+
def __init__(self, value: i64Like) -> None: ...
727+
728+
729+
@function
730+
def debug_score(x: DebugMath) -> i64: ...
731+
732+
733+
egraph = EGraph()
734+
expr = DebugMath(1)
735+
egraph.register(expr, set_(debug_score(expr)).to(3))
736+
737+
str(egraph.freeze())
738+
```
739+
713740
### Serialize the e-graph
714741

715742
If you create the e-graph with `save_egglog_string=True`, you can dump the program

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -206,6 +206,8 @@ ignore = [
206206
"PLW1641",
207207
# allow non module file for docs
208208
"INP001",
209+
# don't replace lambdas with functions because we need them to defer resolution
210+
"PLW0108",
209211
]
210212
select = ["ALL"]
211213

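The new ignore entry can be illustrated with a minimal sketch, assuming `PLW0108` is Ruff's unnecessary-lambda check, which suggests replacing `lambda x: f(x)` with `f`. The names `build_deferred` and `resolve_later` are hypothetical and are not part of egglog; the point is only that the lambda looks up its target at call time, while a direct reference looks it up immediately.

```python
# Hypothetical example; none of these names are egglog APIs.
def build_deferred():
    # The lambda body runs only when called, so `resolve_later` is looked up then.
    return lambda value: resolve_later(value)


deferred = build_deferred()

# A direct reference is resolved right away and fails, because the
# function has not been defined yet at this point.
try:
    eager = resolve_later
except NameError:
    eager = None


def resolve_later(value: int) -> int:
    return value * 2


# By the time the lambda is called, the name exists.
result = deferred(21)
```

Replacing the lambda with the bare function name, as the lint rule suggests, would turn the deferred lookup into the eager one.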
python/egglog/declarations.py

Lines changed: 173 additions & 1 deletion
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88

99
from dataclasses import dataclass, field
10-
from functools import cached_property
10+
from functools import cache, cached_property
1111
from itertools import chain, repeat
1212
from typing import (
1313
TYPE_CHECKING,
@@ -25,6 +25,8 @@
2525
from uuid import UUID
2626
from weakref import WeakValueDictionary
2727

28+
from egglog import bindings
29+
2830
from .bindings import Value
2931

3032
if TYPE_CHECKING:
@@ -54,6 +56,7 @@
5456
"DefaultRewriteDecl",
5557
"DelayedDeclarations",
5658
"DummyDecl",
59+
"EGraphDecl",
5760
"EqDecl",
5861
"ExprActionDecl",
5962
"ExprDecl",
@@ -349,6 +352,175 @@ class CombinedRulesetDecl:
349352
rulesets: tuple[Ident, ...]
350353

351354

355+
T_expr_decl = TypeVar("T_expr_decl", bound="ExprDecl")
356+
357+
358+
@dataclass(frozen=True)
359+
class EGraphDecl:
360+
"""
361+
State of an e-graph, which when re-added to a new e-graph will reconstruct the same e-graph, given the same Declarations.
362+
363+
All the expressions in here may reference values which appear in the `e_classes` mapping.
364+
"""
365+
366+
# Mapping from top level let binding names to their types and expressions
367+
let_bindings: dict[str, TypedExprDecl] = field(default_factory=dict)
368+
# Mapping from egglog values representing e-classes to all the expressions in that e-class
369+
e_classes: dict[bindings.Value, tuple[JustTypeRef, tuple[CallDecl, ...]]] = field(default_factory=dict)
370+
# Mapping from function calls to the values they are set to
371+
sets: dict[CallDecl, TypedExprDecl] = field(default_factory=dict)
372+
# Top-level expr actions such as relation facts.
373+
expr_actions: tuple[TypedExprDecl, ...] = field(default=())
374+
# Mapping from function calls to the set costs.
375+
costs: dict[CallDecl, tuple[JustTypeRef, int]] = field(default_factory=dict)
376+
# Set of values which are subsumed
377+
subsumed: tuple[tuple[JustTypeRef, CallDecl], ...] = field(default=())
378+
379+
def __hash__(self) -> int:
380+
return hash((
381+
type(self),
382+
tuple(self.let_bindings.items()),
383+
tuple((value, tp, exprs) for value, (tp, exprs) in self.e_classes.items()),
384+
tuple(self.sets.items()),
385+
self.expr_actions,
386+
tuple(self.costs.items()),
387+
self.subsumed,
388+
))
389+
390+
@cached_property
391+
def to_actions(self) -> list[ActionDecl]: # noqa: C901
392+
"""
393+
Converts this egraph decl to a list of actions that can be executed to reconstruct the egraph.
394+
395+
Converts all e-classes to grounded terms + unions.
396+
397+
Currently does not support cycles or empty e-classes.
398+
"""
399+
# First fill up the e_class_grounded_term for all e_classes
400+
# by iteratively adding grounded terms for e-classes which have a grounded term until no more progress can be made.
401+
402+
# mapping from e-class to a grounded term in that e-class
403+
e_class_grounded_term: dict[Value, CallDecl] = {}
404+
405+
def is_grounded(expr: ExprDecl) -> bool:
406+
"""
407+
Checks if the given expression is grounded, meaning any values recursively in it have grounded terms in their e-classes.
408+
"""
409+
match expr:
410+
case LetRefDecl(name):
411+
raise ValueError(f"Cannot have unexpanded let bindings in egraph decl: {name}")
412+
case UnboundVarDecl(_):
413+
msg = "Cannot have unbound variables in egraph decl"
414+
raise ValueError(msg)
415+
case CallDecl(_, args, _):
416+
return all(is_grounded(a.expr) for a in args)
417+
case LitDecl(_) | PyObjectDecl(_):
418+
return True
419+
case PartialCallDecl(call):
420+
return is_grounded(call)
421+
case DummyDecl():
422+
msg = "Cannot have dummy decls in egraph decl"
423+
raise ValueError(msg)
424+
case ValueDecl(value):
425+
return value in e_class_grounded_term
426+
case GetCostDecl():
427+
msg = "Cannot have GetCostDecl in egraph decl"
428+
raise ValueError(msg)
429+
case _:
430+
assert_never(expr)
431+
432+
made_progress = True
433+
while made_progress:
434+
made_progress = False
435+
for e_class, (_, exprs) in self.e_classes.items():
436+
if e_class in e_class_grounded_term:
437+
continue
438+
for expr in exprs:
439+
if is_grounded(expr):
440+
e_class_grounded_term[e_class] = expr
441+
made_progress = True
442+
break
443+
444+
# call declarations already emitted as part of other actions.
445+
emitted_call_decls = set[CallDecl]()
446+
447+
@cache
448+
def to_grounded(expr: ExprDecl) -> ExprDecl:
449+
"""
450+
Converts the given expression to a grounded term, by replacing any values in it with their grounded terms.
451+
"""
452+
match expr:
453+
case LetRefDecl(name):
454+
raise ValueError(f"Cannot have unexpanded let bindings in egraph decl: {name}")
455+
case UnboundVarDecl(_):
456+
msg = "Cannot have unbound variables in egraph decl"
457+
raise ValueError(msg)
458+
case CallDecl(callable, args, bound_tp_params):
459+
emitted_call_decls.add(expr)
460+
new_args = tuple(TypedExprDecl(a.tp, to_grounded(a.expr)) for a in args)
461+
return CallDecl(callable, new_args, bound_tp_params)
462+
case LitDecl(_) | PyObjectDecl(_):
463+
return expr
464+
case PartialCallDecl(call):
465+
return PartialCallDecl(cast("CallDecl", to_grounded(call)))
466+
case DummyDecl():
467+
msg = "Cannot have dummy decls in egraph decl"
468+
raise ValueError(msg)
469+
case ValueDecl(value):
470+
if value not in e_class_grounded_term:
471+
raise ValueError(f"Value {value} does not have a grounded term in egraph decl")
472+
return to_grounded(e_class_grounded_term[value])
473+
case GetCostDecl():
474+
msg = "Cannot have GetCostDecl in egraph decl"
475+
raise ValueError(msg)
476+
case _:
477+
assert_never(expr)
478+
479+
# calls that are in e-classes with only one value, so wouldn't be added as a union and might need
480+
# to be added as a single expr action if they don't show up anywhere else
481+
single_e_class_calls: list[tuple[JustTypeRef, CallDecl]] = []
482+
483+
# Now add all e-classes as actions.
484+
actions: list[ActionDecl] = []
485+
for e_class, (tp, exprs) in self.e_classes.items():
486+
chosen_term = e_class_grounded_term[e_class]
487+
if len(exprs) == 1:
488+
single_e_class_calls.append((tp, chosen_term))
489+
continue
490+
491+
grounded_chosen_term = to_grounded(chosen_term)
492+
for expr in exprs:
493+
if expr == chosen_term:
494+
continue
495+
actions.append(UnionDecl(tp, grounded_chosen_term, to_grounded(expr)))
496+
actions.extend(
497+
LetDecl(name, TypedExprDecl(typed_expr.tp, to_grounded(typed_expr.expr)))
498+
for name, typed_expr in self.let_bindings.items()
499+
)
500+
actions.extend(
501+
SetDecl(set_expr.tp, cast("CallDecl", to_grounded(call)), to_grounded(set_expr.expr))
502+
for call, set_expr in self.sets.items()
503+
)
504+
actions.extend(
505+
ExprActionDecl(TypedExprDecl(typed_expr.tp, to_grounded(typed_expr.expr)))
506+
for typed_expr in self.expr_actions
507+
)
508+
actions.extend(
509+
SetCostDecl(tp, cast("CallDecl", to_grounded(call)), LitDecl(cost))
510+
for call, (tp, cost) in self.costs.items()
511+
)
512+
actions.extend(ChangeDecl(tp, cast("CallDecl", to_grounded(call)), "subsume") for tp, call in self.subsumed)
513+
514+
# Now add any remaining calls that weren't part of any other actions
515+
actions.extend(
516+
ExprActionDecl(TypedExprDecl(tp, to_grounded(expr)))
517+
for (tp, expr) in single_e_class_calls
518+
if expr not in emitted_call_decls
519+
)
520+
521+
return actions
522+
523+
352524
# Have two different types of type refs, one that can include vars recursively and one that cannot.
353525
# We only use the one with vars for classmethods and methods, and the other one for egg references as
354526
# well as runtime values.
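The grounding step in `EGraphDecl.to_actions` can be sketched on a toy model. This is a simplified illustration, not the code from the commit: e-classes are plain integers, a term is a hypothetical `(name, child_e_class_ids)` pair, and `pick_grounded_terms` and `expand` are made-up helper names. It shows the same fixpoint idea: keep choosing, for each e-class, a term whose children already have a chosen term, until no more progress is made.

```python
from __future__ import annotations

# Hypothetical toy model: a term is (function name, child e-class ids).
Term = tuple[str, tuple[int, ...]]


def pick_grounded_terms(e_classes: dict[int, list[Term]]) -> dict[int, Term]:
    """Choose one term per e-class whose children all have a chosen term."""
    chosen: dict[int, Term] = {}
    made_progress = True
    while made_progress:
        made_progress = False
        for e_class, terms in e_classes.items():
            if e_class in chosen:
                continue
            for term in terms:
                _, children = term
                if all(child in chosen for child in children):
                    chosen[e_class] = term
                    made_progress = True
                    break
    return chosen


def expand(term: Term, chosen: dict[int, Term]) -> str:
    """Replace child e-class ids with their chosen terms, recursively."""
    name, children = term
    if not children:
        return name
    return f"{name}({', '.join(expand(chosen[c], chosen) for c in children)})"


e_classes = {
    0: [("add", (1, 2)), ("five", ())],
    1: [("two", ())],
    2: [("three", ())],
    3: [("loop", (3,))],  # only refers to itself, so it never becomes grounded
}
chosen = pick_grounded_terms(e_classes)

# Emit one union per non-chosen term in each multi-term e-class.
unions = [
    (expand(chosen[e_class], chosen), expand(term, chosen))
    for e_class, terms in e_classes.items()
    if e_class in chosen and len(terms) > 1
    for term in terms
    if term != chosen[e_class]
]
```

In this sketch the self-referential e-class `3` is simply left out of `chosen`, which mirrors the docstring's note that cycles are not currently supported; the code in the commit indexes `e_class_grounded_term` directly instead of skipping such classes.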
