Skip to content

Commit cfc87e4

Browse files
tmp
1 parent 3aab22f commit cfc87e4

7 files changed

Lines changed: 169 additions & 67 deletions

File tree

python/egglog/bindings.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ class FrozenEGraph:
899899
class FrozenFunction:
900900
input_sorts: list[str]
901901
output_sort: str
902+
is_let_binding: bool
902903
rows: list[FrozenRow]
903904

904905
@final

python/egglog/egraph.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,29 +1341,42 @@ def debug_print(self) -> None:
13411341
"""
13421342
print("=== EGraph Debug Print ===")
13431343
print("Mapping from functions to their e-class")
1344-
for name, fn in self._egraph.freeze().functions.items():
1345-
for row in fn.rows:
1346-
call = self._values_to_expr(row.inputs, name)
1347-
# None for let calls we cant resolve
1348-
if call is None:
1349-
continue
1350-
res = RuntimeExpr.__from_values__(
1351-
self.__egg_decls__,
1352-
TypedExprDecl(
1353-
tp=call.__egg_typed_expr__.tp,
1354-
expr=self._state.value_to_expr(tp=call.__egg_typed_expr__.tp, value=row.output),
1355-
),
1356-
)
1357-
equality = eq(call).to(res)
1358-
debug_str = str(equality)
1344+
# printed_output_sorts = set()
1345+
frozen = self._egraph.freeze()
1346+
for name, fn in sorted(frozen.functions.items(), key=lambda kv: kv[1].output_sort):
1347+
printed_output_values = set()
1348+
for row in sorted(fn.rows, key=lambda r: r.output):
1349+
if fn.is_let_binding:
1350+
call = RuntimeExpr.__from_values__(
1351+
self.__egg_decls__,
1352+
TypedExprDecl(self._state.egg_sort_to_type_ref[fn.output_sort], LetRefDecl(name)),
1353+
)
1354+
else:
1355+
call = self._values_to_expr(row.inputs, name)
1356+
1357+
# if call.__egg_typed_expr__.tp not in printed_output_sorts:
1358+
# printed_output_sorts.add(call.__egg_typed_expr__.tp)
1359+
# print(f"\n# {call.__egg_typed_expr__.tp}\n")
1360+
1361+
if row.output not in printed_output_values:
1362+
printed_output_values.add(row.output)
1363+
1364+
res = RuntimeExpr.__from_values__(
1365+
self.__egg_decls__,
1366+
TypedExprDecl(
1367+
tp=call.__egg_typed_expr__.tp,
1368+
expr=self._state.value_to_expr(tp=call.__egg_typed_expr__.tp, value=row.output),
1369+
),
1370+
)
1371+
print(f"\n## {res}: {call.__egg_typed_expr__.tp}\n")
1372+
13591373
if row.subsumed:
1360-
debug_str += " # subsumed"
1361-
print(debug_str)
1374+
print(subsume(cast("Expr", call)))
1375+
else:
1376+
print(call)
13621377
print("=== End EGraph Debug Print ===")
13631378

1364-
def _values_to_expr(self, args: list[bindings._Value], name: str) -> RuntimeExpr | None:
1365-
if name not in self._state.egg_fn_to_callable_refs:
1366-
return None
1379+
def _values_to_expr(self, args: list[bindings._Value], name: str) -> RuntimeExpr:
13671380
(callable_ref,) = self._state.egg_fn_to_callable_refs[name]
13681381
signature = self.__egg_decls__.get_callable_decl(callable_ref).signature
13691382
assert isinstance(signature, FunctionSignature)

python/egglog/egraph_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class EGraphState:
5454
rulesets: dict[Ident, set[RewriteOrRuleDecl]] = field(default_factory=dict)
5555

5656
# Bidirectional mapping between egg function names and python callable references.
57-
# Note that there are possibly mutliple callable references for a single egg function name, like `+`
57+
# Note that there are possibly multiple callable references for a single egg function name, like `+`
5858
# for both int and rational classes.
5959
egg_fn_to_callable_refs: dict[str, set[CallableRef]] = field(
6060
default_factory=lambda: defaultdict(set, {"!=": {FunctionRef(Ident.builtin("!="))}})

python/egglog/exp/array_api.py

Lines changed: 95 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,23 @@ def append(self, i: IntLike) -> TupleInt:
545545
>>> ti = TupleInt.range(3)
546546
>>> ti2 = ti.append(3)
547547
>>> list(ti2)
548-
[i64(0), i64(1), i64(2), i64(3)]
548+
[Int(0), Int(1), Int(2), Int(3)]
549549
"""
550550
return TupleInt.fn(
551551
self.length() + 1, lambda j: Int.if_(j == self.length(), lambda: cast("Int", i), lambda: self[j])
552552
)
553553

554+
@method(unextractable=True)
555+
def append_start(self, i: IntLike) -> TupleInt:
556+
"""
557+
Prepend an integer to the start of the tuple.
558+
>>> ti = TupleInt.range(3)
559+
>>> ti2 = ti.append_start( -1)
560+
>>> list(ti2)
561+
[Int(-1), Int(0), Int(1), Int(2)]
562+
"""
563+
return TupleInt.fn(self.length() + 1, lambda j: Int.if_(j == 0, lambda: cast("Int", i), lambda: self[j - 1]))
564+
554565
@method(unextractable=True)
555566
def __add__(self, other: TupleIntLike) -> TupleInt:
556567
"""
@@ -826,8 +837,8 @@ def _tuple_int(
826837
yield rewrite(TupleInt.fn(i2, idx_fn).length(), subsume=True).to(i2)
827838
yield rewrite(TupleInt.fn(i2, idx_fn)[i], subsume=True).to(idx_fn(check_index(i2, i)))
828839

829-
yield rewrite(TupleInt(vs).length(), subsume=True).to(Int(vs.length()))
830-
yield rewrite(TupleInt(vs)[Int(k)], subsume=True).to(vs[k])
840+
yield rewrite(TupleInt(vs).length()).to(Int(vs.length()))
841+
yield rewrite(TupleInt(vs)[Int(k)]).to(vs[k])
831842

832843
yield rewrite(TupleInt.if_(TRUE, lt, lf), subsume=True).to(lt())
833844
yield rewrite(TupleInt.if_(FALSE, lt, lf), subsume=True).to(lf())
@@ -1488,22 +1499,29 @@ def shape(self) -> TupleInt:
14881499
[Int(2), Int(1)]
14891500
"""
14901501

1491-
# @method(preserve=True) # type: ignore[prop-decorator]
1492-
# @property
1493-
# def value(self) -> PyTupleValuesRecursive:
1494-
# """
1495-
# Unwraps the RecursiveValue into either a Value or a nested tuple of Values.
1502+
@method(preserve=True) # type: ignore[prop-decorator]
1503+
@property
1504+
def value(self) -> PyTupleValuesRecursive:
1505+
"""
1506+
Unwraps the RecursiveValue into either a Value or a nested tuple of Values.
14961507
1497-
# >>> convert(((1, 2), (3, 4)), RecursiveValue).value
1498-
# ((Value.from_int(Int(1)), Value.from_int(Int(2))), (Value.from_int(Int(3)), Value.from_int(Int(4))))
1499-
# """
1500-
# match get_callable_args(self, RecursiveValue):
1501-
# case (value,):
1502-
# return cast("Value", value)
1503-
# match get_callable_args(self, RecursiveValue.vec):
1504-
# case (vec,):
1505-
# return tuple(v.value for v in cast("Vec[RecursiveValue]", vec))
1506-
# raise ExprValueError(self, "RecursiveValue or RecursiveValue.vec")
1508+
>>> convert(((1, 2), (3, 4)), RecursiveValue).value
1509+
((Value.from_int(Int(1)), Value.from_int(Int(2))), (Value.from_int(Int(3)), Value.from_int(Int(4))))
1510+
"""
1511+
match get_callable_args(self, RecursiveValue):
1512+
case (value,):
1513+
return cast("Value", value)
1514+
match get_callable_args(self, RecursiveValue.vec):
1515+
case (vec,):
1516+
return tuple(v.value for v in cast("Vec[RecursiveValue]", vec))
1517+
raise ExprValueError(self, "RecursiveValue or RecursiveValue.vec")
1518+
1519+
@method(preserve=True)
1520+
def eval(self) -> PyTupleValuesRecursive:
1521+
"""
1522+
Evals to a nested tuple of values representing the RecursiveValue.
1523+
"""
1524+
return try_evaling(self)
15071525

15081526

15091527
PyTupleValuesRecursive: TypeAlias = Value | tuple["PyTupleValuesRecursive", ...]
@@ -1525,15 +1543,20 @@ def _recursive_value(
15251543
lt: Callable[[], RecursiveValue],
15261544
lf: Callable[[], RecursiveValue],
15271545
vi: Vec[Int],
1546+
rv: RecursiveValue,
15281547
):
15291548
yield rewrite(RecursiveValue(v).shape).to(TupleInt(()))
15301549
yield rewrite(RecursiveValue.vec(vs).shape).to(TupleInt((vs.length(),)) + vs[0].shape, vs.length() > 0)
15311550
yield rewrite(RecursiveValue.vec(vs).shape).to(TupleInt((0,)), vs.length() == i64(0))
15321551

15331552
yield rewrite(RecursiveValue(v)[ti], subsume=True).to(v) # Assume ti is empty
1534-
yield rewrite(RecursiveValue.vec(vs)[TupleInt(vi)], subsume=True).to(
1535-
vs[k][TupleInt(vi.remove(0))],
1553+
1554+
yield rule(
1555+
eq(v).to(RecursiveValue.vec(vs)[TupleInt(vi)]),
1556+
vi.length() > 0,
15361557
eq(vi[0]).to(Int(k)),
1558+
).then(
1559+
union(v).with_(vs[k][TupleInt(vi.remove(0))]),
15371560
)
15381561

15391562

@@ -1572,45 +1595,47 @@ def from_tuple_value(cls, tv: TupleValueLike) -> NDArray:
15721595
lambda idx: tv[idx[0]],
15731596
)
15741597

1575-
@method(preserve=True)
1576-
def eval_vecs(self) -> VecValuesRecursive:
1577-
"""
1578-
Evals to a nested Vec of values representing the array. It will be extracted and simplified by the e-graph.
1579-
"""
1580-
# Share an e-graph for the computation of shape and eval
1581-
egraph = _get_current_egraph()
1582-
with set_array_api_egraph(egraph):
1583-
shape = self.shape.eval()
1598+
def to_recursive_value(self) -> RecursiveValue: ...
1599+
1600+
# @method(preserve=True)
1601+
# def eval_vecs(self) -> VecValuesRecursive:
1602+
# """
1603+
# Evals to a nested Vec of values representing the array. It will be extracted and simplified by the e-graph.
1604+
# """
1605+
# # Share an e-graph for the computation of shape and eval
1606+
# egraph = _get_current_egraph()
1607+
# with set_array_api_egraph(egraph):
1608+
# shape = self.shape.eval()
15841609

1585-
def _inner_values(current_index: tuple[int, ...], remaining_dims: tuple[Int, ...]) -> VecValuesRecursive:
1586-
if not remaining_dims:
1587-
return self.index(current_index)
1588-
return Vec(*(_inner_values((*current_index, i), remaining_dims[1:]) for i in range(remaining_dims[0])))
1610+
# def _inner_values(current_index: tuple[int, ...], remaining_dims: tuple[Int, ...]) -> VecValuesRecursive:
1611+
# if not remaining_dims:
1612+
# return self.index(current_index)
1613+
# return Vec(*(_inner_values((*current_index, i), remaining_dims[1:]) for i in range(remaining_dims[0])))
15891614

1590-
res = _inner_values((), shape)
1591-
egraph.register(res)
1592-
egraph.run(array_api_schedule)
1593-
return egraph.extract(res)
1615+
# res = _inner_values((), shape)
1616+
# egraph.register(res)
1617+
# egraph.run(array_api_schedule)
1618+
# return egraph.extract(res)
15941619

15951620
@method(preserve=True)
15961621
def eval(self) -> PyTupleValuesRecursive:
15971622
"""
15981623
Evals to a nested tuple of values representing the array.
15991624
"""
1625+
return self.to_recursive_value().eval()
16001626

1601-
def _to_tuple(v: VecValuesRecursive) -> PyTupleValuesRecursive:
1602-
if isinstance(v, Value):
1603-
return v
1604-
return tuple(_to_tuple(i) for i in v)
1627+
# def _to_tuple(v: VecValuesRecursive) -> PyTupleValuesRecursive:
1628+
# if isinstance(v, Value):
1629+
# return v
1630+
# return tuple(_to_tuple(i) for i in v)
16051631

1606-
return _to_tuple(self.eval_vecs())
1632+
# return _to_tuple(self.eval_vecs())
16071633

16081634
@method(preserve=True)
16091635
def eval_numpy(self, dtype: np.dtype | None = None) -> np.ndarray:
16101636
"""
16111637
Evals to a numpy ndarray.
16121638
"""
1613-
print(self.eval()[0])
16141639
return np.array(self.eval(), dtype=dtype)
16151640

16161641
@method(preserve=True)
@@ -1775,6 +1800,17 @@ def index(self, indices: TupleIntLike) -> Value:
17751800
@classmethod
17761801
def if_(cls, b: BooleanLike, i: Callable[[], NDArray], j: Callable[[], NDArray]) -> NDArray: ...
17771802

1803+
def partial_index(self, i: IntLike) -> NDArray:
1804+
"""
1805+
Partially index into the array, returning a sub-array.
1806+
1807+
>>> NDArray(((1, 2), (3, 4))).partial_index(0).eval_numpy("int64")
1808+
array([1, 2])
1809+
"""
1810+
return NDArray.fn(
1811+
self.shape.drop(1), self.dtype, lambda idx: self.index(idx.append_start(check_index(self.shape[0], i)))
1812+
)
1813+
17781814

17791815
VecValuesRecursive: TypeAlias = "Value | Vec[VecValuesRecursive]"
17801816

@@ -1825,6 +1861,22 @@ def _ndarray(
18251861
# if_
18261862
rewrite(NDArray.if_(TRUE, xt, x1t), subsume=True).to(xt()),
18271863
rewrite(NDArray.if_(FALSE, xt, x1t), subsume=True).to(x1t()),
1864+
# to RecursiveValue
1865+
rewrite(NDArray(rv).to_recursive_value(), subsume=True).to(rv),
1866+
rewrite(NDArray.fn((), dtype, idx_fn).to_recursive_value(), subsume=True).to(
1867+
RecursiveValue(idx_fn(TupleInt(())))
1868+
),
1869+
rule(
1870+
eq(x).to(NDArray.fn(TupleInt(vi), dtype, idx_fn)),
1871+
x.to_recursive_value(),
1872+
vi.length() > i64(0),
1873+
eq(vi[0]).to(Int(i)),
1874+
).then(
1875+
union(x.to_recursive_value()).with_(
1876+
RecursiveValue.vec(i.range().map(lambda j: x.partial_index(j).to_recursive_value()))
1877+
),
1878+
subsume(x.to_recursive_value()),
1879+
),
18281880
]
18291881

18301882

python/egglog/exp/vecdot_example.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,45 @@
11
from egglog.exp.array_api import *
22

3-
# smaller example
43
v = NDArray([[1, 2], [3, 4]])
54
n = NDArray([3, 4])
65
res = vecdot(v, n)
6+
egraph = EGraph()
7+
egraph.register(res.to_recursive_value())
8+
egraph.run(array_api_schedule)
9+
10+
_RecursiveValue_1 = RecursiveValue.vec(
11+
Vec(
12+
RecursiveValue.vec(
13+
Vec(
14+
RecursiveValue(Value.from_int(Int(1))),
15+
RecursiveValue(Value.from_int(Int(2))),
16+
)
17+
),
18+
RecursiveValue.vec(
19+
Vec(
20+
RecursiveValue(Value.from_int(Int(3))),
21+
RecursiveValue(Value.from_int(Int(4))),
22+
)
23+
),
24+
)
25+
)
26+
egraph.let("im_value", _RecursiveValue_1[TupleInt(Vec(Int(0), Int(0)))])
27+
new_res = egraph.extract(res.to_recursive_value())
28+
29+
30+
egraph.debug_print()
31+
print(new_res)
32+
33+
# new_egraph = EGraph()
34+
# new_egraph.register(new_res)
35+
# new_egraph.run(array_api_schedule)
36+
# print(new_egraph.extract(new_res))
37+
38+
39+
# smaller example
40+
741

8-
print(res.eval_numpy("int64"))
42+
# print(res.eval_numpy("int64"))
943
# This fails with EggSmolError: Panic: Illegal merge attempted for function egglog_exp_array_api_Int_to_i64
1044
# assert str(res.eval_numpy("float64")) == "array([ 3., 8., 10.])"
1145

src/egraph.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,5 +263,5 @@ impl EGraph {
263263

264264
/// Wrapper around Egglog Value. Represents either a primitive base value or a reference to an e-class.
265265
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Clone)]
266-
#[pyclass(eq, frozen, hash, str = "{0:?}")]
266+
#[pyclass(eq, frozen, ord, hash, str = "{0:?}")]
267267
pub struct Value(pub egglog::Value);

src/freeze.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub struct FrozenRow {
2020
pub struct FrozenFunction {
2121
input_sorts: Vec<String>,
2222
output_sort: String,
23+
is_let_binding: bool,
2324
rows: Vec<FrozenRow>,
2425
}
2526

@@ -56,6 +57,7 @@ impl FrozenEGraph {
5657
.collect(),
5758
output_sort: func.schema().output.name().to_string(),
5859
rows,
60+
is_let_binding: func.is_let_binding(),
5961
};
6062
functions.insert(fname.clone(), frozen_function);
6163
}

0 commit comments

Comments
 (0)