Skip to content

Commit 06f2d82

Browse files
before change
1 parent 338649a commit 06f2d82

9 files changed

Lines changed: 138 additions & 78 deletions

File tree

python/egglog/conversion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def convert(source: object, target: type[V]) -> V:
118118
"""
119119
Convert a source object to a target type.
120120
"""
121+
target = cast("RuntimeClass", target)
122+
# TODO: for some reason this breaks things
121123
# if not issubclass(target, RuntimeClass):
122124
# raise TypeError(f"Expected target type to be a egglog type, got {target} of type {type(target)}")
123125
return cast("V", resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__))

python/egglog/egraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,7 @@ def debug_print(self) -> None:
13761376
print(call)
13771377
print("=== End EGraph Debug Print ===")
13781378

1379-
def _values_to_expr(self, args: list[bindings._Value], name: str) -> RuntimeExpr:
1379+
def _values_to_expr(self, args: list[bindings.Value], name: str) -> RuntimeExpr:
13801380
(callable_ref,) = self._state.egg_fn_to_callable_refs[name]
13811381
signature = self.__egg_decls__.get_callable_decl(callable_ref).signature
13821382
assert isinstance(signature, FunctionSignature)

python/egglog/exp/array_api.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def if_(cls, b: BooleanLike, i: Callable[[], Boolean], j: Callable[[], Boolean])
9090
"""
9191

9292

93-
BooleanLike = Boolean | BoolLike
93+
BooleanLike: TypeAlias = Boolean | BoolLike
9494

9595
TRUE = Boolean(True)
9696
FALSE = Boolean(False)
@@ -416,6 +416,7 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
416416
rewrite(Float(f)).to(Float.rational(BigRat(f.to_i64(), 1)), eq(f64.from_i64(f.to_i64())).to(f)),
417417
# always convert from int to rational
418418
rewrite(Float.from_int(Int(i))).to(Float.rational(BigRat(i, 1))),
419+
rewrite(Float.rational(r)).to(Float(r.to_f64())),
419420
rewrite(Float(f) + Float(f2)).to(Float(f + f2)),
420421
rewrite(Float(f) - Float(f2)).to(Float(f - f2)),
421422
rewrite(Float(f) * Float(f2)).to(Float(f * f2)),
@@ -1221,6 +1222,9 @@ def _value(
12211222
yield rewrite(v ** Value.from_float(Float.rational(BigRat(1, 1)))).to(v)
12221223
yield rewrite(Value.from_float(Float.from_int(i))).to(Value.from_int(i))
12231224

1225+
# Upcast binary op
1226+
yield rewrite(Value.from_int(i) * Value.from_float(f)).to(Value.from_float(Float.from_int(i)) * Value.from_float(f))
1227+
12241228

12251229
class TupleValue(Expr, ruleset=array_api_ruleset):
12261230
def __init__(self, vec: VecLike[Value, ValueLike] = ()) -> None: ...
@@ -1575,6 +1579,7 @@ def fn(cls, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Valu
15751579

15761580
NEVER: ClassVar[NDArray]
15771581

1582+
@method(unextractable=True)
15781583
@classmethod
15791584
def from_tuple_value(cls, tv: TupleValueLike) -> NDArray:
15801585
"""
@@ -1590,6 +1595,13 @@ def from_tuple_value(cls, tv: TupleValueLike) -> NDArray:
15901595
lambda idx: tv[idx[0]],
15911596
)
15921597

1598+
@method(unextractable=True)
1599+
def to_tuple_values(self) -> TupleValue:
1600+
"""
1601+
Turns a vector array into a tuple value.
1602+
"""
1603+
return TupleValue.fn(self.shape[0], lambda i: self.index((i,)))
1604+
15931605
def to_recursive_value(self) -> RecursiveValue: ...
15941606

15951607
# @method(preserve=True)
@@ -2026,7 +2038,7 @@ def isfinite(x: NDArray) -> NDArray: ...
20262038

20272039

20282040
@function
2029-
def sum(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray:
2041+
def sum(x: NDArray, axis: OptionalIntOrTupleLike = OptionalIntOrTuple.none) -> NDArray:
20302042
"""
20312043
https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum
20322044
"""
@@ -2198,7 +2210,7 @@ def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ...
21982210

21992211

22002212
@function
2201-
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ...
2213+
def mean(x: NDArray, axis: OptionalIntOrTupleLike = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ...
22022214

22032215

22042216
# TODO: Possibly change names to include modules.
@@ -2207,7 +2219,7 @@ def sqrt(x: NDArray) -> NDArray: ...
22072219

22082220

22092221
@function
2210-
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...
2222+
def std(x: NDArray, axis: OptionalIntOrTupleLike = OptionalIntOrTuple.none) -> NDArray: ...
22112223

22122224

22132225
@function
@@ -2684,7 +2696,15 @@ def try_evaling(expr: ExprWithValue[T_co]) -> T_co:
26842696
egraph = _get_current_egraph()
26852697
egraph.register(expr) # type: ignore[arg-type]
26862698
egraph.run(array_api_schedule)
2687-
return egraph.extract(expr).value # type: ignore[call-overload]
2699+
# run on another e-graph to get around bug
2700+
# https://github.com/egraphs-good/egglog/issues/801
2701+
# return egraph.extract(expr).value # type: ignore[call-overload]
2702+
extracted_expr = egraph.extract(expr) # type: ignore[call-overload]
2703+
new_egraph = EGraph()
2704+
new_egraph.register(extracted_expr)
2705+
new_egraph.run(array_api_schedule)
2706+
return new_egraph.extract(extracted_expr).value
2707+
26882708
# try:
26892709
# return egraph.extract(prim_expr).value # type: ignore[attr-defined]
26902710
# except EggSmolError:

python/egglog/exp/array_api_numba.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,19 @@
1717
# Rewrite mean(x, <int>, <expand dims>) to use sum b/c numba cant do mean with axis
1818
# https://github.com/numba/numba/issues/1269
1919
@array_api_numba_ruleset.register
20-
def _mean(y: NDArray, x: NDArray, i: Int):
21-
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
22-
res = sum(x, axis) / NDArray.scalar(Value.from_int(x.shape[i]))
20+
def _mean(y: NDArray, x: NDArray, axis: Int):
21+
res = sum(x, axis) / x.shape[axis]
2322

2423
yield rewrite(mean(x, axis, FALSE), subsume=True).to(res)
25-
yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, i))
24+
yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, axis))
2625

2726

2827
# Rewrite std(x, <int>) to use mean and sum b/c numba cant do std with axis
2928
@array_api_numba_ruleset.register
30-
def _std(y: NDArray, x: NDArray, i: Int):
31-
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
29+
def _std(y: NDArray, x: NDArray, axis: Int):
3230
# https://numpy.org/doc/stable/reference/generated/numpy.std.html
3331
# "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
34-
yield rewrite(
35-
std(x, axis),
36-
subsume=True,
37-
).to(
32+
yield rewrite(std(x, axis), subsume=True).to(
3833
sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)),
3934
)
4035

@@ -47,14 +42,16 @@ def count_values(x: NDArrayLike, values: TupleValueLike) -> TupleValue:
4742
"""
4843
x = cast(NDArray, x)
4944
values = cast(TupleValue, values)
50-
return TupleValue(values.length(), lambda i: sum(x == values[i]).to_value())
45+
return TupleValue.fn(values.length(), lambda i: sum(x == values[i]).index(()))
5146

5247

5348
@array_api_numba_ruleset.register
5449
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
5550
return [
5651
# The unique counts are the count of all the unique values
57-
rewrite(unique_counts(x)[1], subsume=True).to(NDArray.vector(count_values(x, unique_values(x).to_values()))),
52+
rewrite(unique_counts(x)[1], subsume=True).to(
53+
NDArray.from_tuple_value(count_values(x, unique_values(x).to_tuple_values()))
54+
),
5855
]
5956

6057

@@ -63,7 +60,5 @@ def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
6360
def _unique_inverse(x: NDArray, i: Int):
6461
return [
6562
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
66-
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.from_int(i)), subsume=True).to(
67-
x == NDArray.scalar(unique_values(x).index((i,)))
68-
),
63+
rewrite(unique_inverse(x)[1] == i, subsume=True).to(x == unique_values(x).index((i,))),
6964
]

0 commit comments

Comments
 (0)