@@ -90,7 +90,7 @@ def if_(cls, b: BooleanLike, i: Callable[[], Boolean], j: Callable[[], Boolean])
9090 """
9191
9292
93- BooleanLike = Boolean | BoolLike
93+ BooleanLike : TypeAlias = Boolean | BoolLike
9494
9595TRUE = Boolean (True )
9696FALSE = Boolean (False )
@@ -416,6 +416,7 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
416416 rewrite (Float (f )).to (Float .rational (BigRat (f .to_i64 (), 1 )), eq (f64 .from_i64 (f .to_i64 ())).to (f )),
417417 # always convert from int to rational
418418 rewrite (Float .from_int (Int (i ))).to (Float .rational (BigRat (i , 1 ))),
419+ rewrite (Float .rational (r )).to (Float (r .to_f64 ())),
419420 rewrite (Float (f ) + Float (f2 )).to (Float (f + f2 )),
420421 rewrite (Float (f ) - Float (f2 )).to (Float (f - f2 )),
421422 rewrite (Float (f ) * Float (f2 )).to (Float (f * f2 )),
@@ -1221,6 +1222,9 @@ def _value(
12211222 yield rewrite (v ** Value .from_float (Float .rational (BigRat (1 , 1 )))).to (v )
12221223 yield rewrite (Value .from_float (Float .from_int (i ))).to (Value .from_int (i ))
12231224
1225+ # Upcast binary op
1226+ yield rewrite (Value .from_int (i ) * Value .from_float (f )).to (Value .from_float (Float .from_int (i )) * Value .from_float (f ))
1227+
12241228
12251229class TupleValue (Expr , ruleset = array_api_ruleset ):
12261230 def __init__ (self , vec : VecLike [Value , ValueLike ] = ()) -> None : ...
@@ -1575,6 +1579,7 @@ def fn(cls, shape: TupleIntLike, dtype: DType, idx_fn: Callable[[TupleInt], Valu
15751579
15761580 NEVER : ClassVar [NDArray ]
15771581
1582+ @method (unextractable = True )
15781583 @classmethod
15791584 def from_tuple_value (cls , tv : TupleValueLike ) -> NDArray :
15801585 """
@@ -1590,6 +1595,13 @@ def from_tuple_value(cls, tv: TupleValueLike) -> NDArray:
15901595 lambda idx : tv [idx [0 ]],
15911596 )
15921597
1598+ @method (unextractable = True )
1599+ def to_tuple_values (self ) -> TupleValue :
1600+ """
1601+ Turns a vector array into a tuple value.
1602+ """
1603+ return TupleValue .fn (self .shape [0 ], lambda i : self .index ((i ,)))
1604+
15931605 def to_recursive_value (self ) -> RecursiveValue : ...
15941606
15951607 # @method(preserve=True)
@@ -2026,7 +2038,7 @@ def isfinite(x: NDArray) -> NDArray: ...
20262038
20272039
20282040@function
2029- def sum (x : NDArray , axis : OptionalIntOrTuple = OptionalIntOrTuple .none ) -> NDArray :
2041+ def sum (x : NDArray , axis : OptionalIntOrTupleLike = OptionalIntOrTuple .none ) -> NDArray :
20302042 """
20312043 https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.sum.html?highlight=sum
20322044 """
@@ -2198,7 +2210,7 @@ def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ...
21982210
21992211
22002212@function
2201- def mean (x : NDArray , axis : OptionalIntOrTuple = OptionalIntOrTuple .none , keepdims : Boolean = FALSE ) -> NDArray : ...
2213+ def mean (x : NDArray , axis : OptionalIntOrTupleLike = OptionalIntOrTuple .none , keepdims : Boolean = FALSE ) -> NDArray : ...
22022214
22032215
22042216# TODO: Possibly change names to include modules.
@@ -2207,7 +2219,7 @@ def sqrt(x: NDArray) -> NDArray: ...
22072219
22082220
22092221@function
2210- def std (x : NDArray , axis : OptionalIntOrTuple = OptionalIntOrTuple .none ) -> NDArray : ...
2222+ def std (x : NDArray , axis : OptionalIntOrTupleLike = OptionalIntOrTuple .none ) -> NDArray : ...
22112223
22122224
22132225@function
@@ -2684,7 +2696,15 @@ def try_evaling(expr: ExprWithValue[T_co]) -> T_co:
26842696 egraph = _get_current_egraph ()
26852697 egraph .register (expr ) # type: ignore[arg-type]
26862698 egraph .run (array_api_schedule )
2687- return egraph .extract (expr ).value # type: ignore[call-overload]
2699+ # run on another e-graph to get around bug
2700+ # https://github.com/egraphs-good/egglog/issues/801
2701+ # return egraph.extract(expr).value # type: ignore[call-overload]
2702+ extracted_expr = egraph .extract (expr ) # type: ignore[call-overload]
2703+ new_egraph = EGraph ()
2704+ new_egraph .register (extracted_expr )
2705+ new_egraph .run (array_api_schedule )
2706+ return new_egraph .extract (extracted_expr ).value
2707+
26882708 # try:
26892709 # return egraph.extract(prim_expr).value # type: ignore[attr-defined]
26902710 # except EggSmolError:
0 commit comments