@@ -545,12 +545,23 @@ def append(self, i: IntLike) -> TupleInt:
545545 >>> ti = TupleInt.range(3)
546546 >>> ti2 = ti.append(3)
547547 >>> list(ti2)
548- [i64 (0), i64 (1), i64 (2), i64 (3)]
548+ [Int (0), Int (1), Int (2), Int (3)]
549549 """
550550 return TupleInt .fn (
551551 self .length () + 1 , lambda j : Int .if_ (j == self .length (), lambda : cast ("Int" , i ), lambda : self [j ])
552552 )
553553
554+ @method (unextractable = True )
555+ def append_start (self , i : IntLike ) -> TupleInt :
556+ """
557+ Prepend an integer to the start of the tuple.
558+ >>> ti = TupleInt.range(3)
559+ >>> ti2 = ti.append_start( -1)
560+ >>> list(ti2)
561+ [Int(-1), Int(0), Int(1), Int(2)]
562+ """
563+ return TupleInt .fn (self .length () + 1 , lambda j : Int .if_ (j == 0 , lambda : cast ("Int" , i ), lambda : self [j - 1 ]))
564+
554565 @method (unextractable = True )
555566 def __add__ (self , other : TupleIntLike ) -> TupleInt :
556567 """
@@ -826,8 +837,8 @@ def _tuple_int(
826837 yield rewrite (TupleInt .fn (i2 , idx_fn ).length (), subsume = True ).to (i2 )
827838 yield rewrite (TupleInt .fn (i2 , idx_fn )[i ], subsume = True ).to (idx_fn (check_index (i2 , i )))
828839
829- yield rewrite (TupleInt (vs ).length (), subsume = True ).to (Int (vs .length ()))
830- yield rewrite (TupleInt (vs )[Int (k )], subsume = True ).to (vs [k ])
840+ yield rewrite (TupleInt (vs ).length ()).to (Int (vs .length ()))
841+ yield rewrite (TupleInt (vs )[Int (k )]).to (vs [k ])
831842
832843 yield rewrite (TupleInt .if_ (TRUE , lt , lf ), subsume = True ).to (lt ())
833844 yield rewrite (TupleInt .if_ (FALSE , lt , lf ), subsume = True ).to (lf ())
@@ -1488,22 +1499,29 @@ def shape(self) -> TupleInt:
14881499 [Int(2), Int(1)]
14891500 """
14901501
1491- # @method(preserve=True) # type: ignore[prop-decorator]
1492- # @property
1493- # def value(self) -> PyTupleValuesRecursive:
1494- # """
1495- # Unwraps the RecursiveValue into either a Value or a nested tuple of Values.
1502+ @method (preserve = True ) # type: ignore[prop-decorator]
1503+ @property
1504+ def value (self ) -> PyTupleValuesRecursive :
1505+ """
1506+ Unwraps the RecursiveValue into either a Value or a nested tuple of Values.
14961507
1497- # >>> convert(((1, 2), (3, 4)), RecursiveValue).value
1498- # ((Value.from_int(Int(1)), Value.from_int(Int(2))), (Value.from_int(Int(3)), Value.from_int(Int(4))))
1499- # """
1500- # match get_callable_args(self, RecursiveValue):
1501- # case (value,):
1502- # return cast("Value", value)
1503- # match get_callable_args(self, RecursiveValue.vec):
1504- # case (vec,):
1505- # return tuple(v.value for v in cast("Vec[RecursiveValue]", vec))
1506- # raise ExprValueError(self, "RecursiveValue or RecursiveValue.vec")
1508+ >>> convert(((1, 2), (3, 4)), RecursiveValue).value
1509+ ((Value.from_int(Int(1)), Value.from_int(Int(2))), (Value.from_int(Int(3)), Value.from_int(Int(4))))
1510+ """
1511+ match get_callable_args (self , RecursiveValue ):
1512+ case (value ,):
1513+ return cast ("Value" , value )
1514+ match get_callable_args (self , RecursiveValue .vec ):
1515+ case (vec ,):
1516+ return tuple (v .value for v in cast ("Vec[RecursiveValue]" , vec ))
1517+ raise ExprValueError (self , "RecursiveValue or RecursiveValue.vec" )
1518+
1519+ @method (preserve = True )
1520+ def eval (self ) -> PyTupleValuesRecursive :
1521+ """
1522+ Evals to a nested tuple of values representing the RecursiveValue.
1523+ """
1524+ return try_evaling (self )
15071525
15081526
15091527PyTupleValuesRecursive : TypeAlias = Value | tuple ["PyTupleValuesRecursive" , ...]
@@ -1525,15 +1543,20 @@ def _recursive_value(
15251543 lt : Callable [[], RecursiveValue ],
15261544 lf : Callable [[], RecursiveValue ],
15271545 vi : Vec [Int ],
1546+ rv : RecursiveValue ,
15281547):
15291548 yield rewrite (RecursiveValue (v ).shape ).to (TupleInt (()))
15301549 yield rewrite (RecursiveValue .vec (vs ).shape ).to (TupleInt ((vs .length (),)) + vs [0 ].shape , vs .length () > 0 )
15311550 yield rewrite (RecursiveValue .vec (vs ).shape ).to (TupleInt ((0 ,)), vs .length () == i64 (0 ))
15321551
15331552 yield rewrite (RecursiveValue (v )[ti ], subsume = True ).to (v ) # Assume ti is empty
1534- yield rewrite (RecursiveValue .vec (vs )[TupleInt (vi )], subsume = True ).to (
1535- vs [k ][TupleInt (vi .remove (0 ))],
1553+
1554+ yield rule (
1555+ eq (v ).to (RecursiveValue .vec (vs )[TupleInt (vi )]),
1556+ vi .length () > 0 ,
15361557 eq (vi [0 ]).to (Int (k )),
1558+ ).then (
1559+ union (v ).with_ (vs [k ][TupleInt (vi .remove (0 ))]),
15371560 )
15381561
15391562
@@ -1572,45 +1595,47 @@ def from_tuple_value(cls, tv: TupleValueLike) -> NDArray:
15721595 lambda idx : tv [idx [0 ]],
15731596 )
15741597
1575- @method (preserve = True )
1576- def eval_vecs (self ) -> VecValuesRecursive :
1577- """
1578- Evals to a nested Vec of values representing the array. It will be extracted and simplified by the e-graph.
1579- """
1580- # Share an e-graph for the computation of shape and eval
1581- egraph = _get_current_egraph ()
1582- with set_array_api_egraph (egraph ):
1583- shape = self .shape .eval ()
1598+ def to_recursive_value (self ) -> RecursiveValue : ...
1599+
1600+ # @method(preserve=True)
1601+ # def eval_vecs(self) -> VecValuesRecursive:
1602+ # """
1603+ # Evals to a nested Vec of values representing the array. It will be extracted and simplified by the e-graph.
1604+ # """
1605+ # # Share an e-graph for the computation of shape and eval
1606+ # egraph = _get_current_egraph()
1607+ # with set_array_api_egraph(egraph):
1608+ # shape = self.shape.eval()
15841609
1585- def _inner_values (current_index : tuple [int , ...], remaining_dims : tuple [Int , ...]) -> VecValuesRecursive :
1586- if not remaining_dims :
1587- return self .index (current_index )
1588- return Vec (* (_inner_values ((* current_index , i ), remaining_dims [1 :]) for i in range (remaining_dims [0 ])))
1610+ # def _inner_values(current_index: tuple[int, ...], remaining_dims: tuple[Int, ...]) -> VecValuesRecursive:
1611+ # if not remaining_dims:
1612+ # return self.index(current_index)
1613+ # return Vec(*(_inner_values((*current_index, i), remaining_dims[1:]) for i in range(remaining_dims[0])))
15891614
1590- res = _inner_values ((), shape )
1591- egraph .register (res )
1592- egraph .run (array_api_schedule )
1593- return egraph .extract (res )
1615+ # res = _inner_values((), shape)
1616+ # egraph.register(res)
1617+ # egraph.run(array_api_schedule)
1618+ # return egraph.extract(res)
15941619
15951620 @method (preserve = True )
15961621 def eval (self ) -> PyTupleValuesRecursive :
15971622 """
15981623 Evals to a nested tuple of values representing the array.
15991624 """
1625+ return self .to_recursive_value ().eval ()
16001626
1601- def _to_tuple (v : VecValuesRecursive ) -> PyTupleValuesRecursive :
1602- if isinstance (v , Value ):
1603- return v
1604- return tuple (_to_tuple (i ) for i in v )
1627+ # def _to_tuple(v: VecValuesRecursive) -> PyTupleValuesRecursive:
1628+ # if isinstance(v, Value):
1629+ # return v
1630+ # return tuple(_to_tuple(i) for i in v)
16051631
1606- return _to_tuple (self .eval_vecs ())
1632+ # return _to_tuple(self.eval_vecs())
16071633
16081634 @method (preserve = True )
16091635 def eval_numpy (self , dtype : np .dtype | None = None ) -> np .ndarray :
16101636 """
16111637 Evals to a numpy ndarray.
16121638 """
1613- print (self .eval ()[0 ])
16141639 return np .array (self .eval (), dtype = dtype )
16151640
16161641 @method (preserve = True )
@@ -1775,6 +1800,17 @@ def index(self, indices: TupleIntLike) -> Value:
17751800 @classmethod
17761801 def if_ (cls , b : BooleanLike , i : Callable [[], NDArray ], j : Callable [[], NDArray ]) -> NDArray : ...
17771802
1803+ def partial_index (self , i : IntLike ) -> NDArray :
1804+ """
1805+ Partially index into the array, returning a sub-array.
1806+
1807+ >>> NDArray(((1, 2), (3, 4))).partial_index(0).eval_numpy("int64")
1808+ array([1, 2])
1809+ """
1810+ return NDArray .fn (
1811+ self .shape .drop (1 ), self .dtype , lambda idx : self .index (idx .append_start (check_index (self .shape [0 ], i )))
1812+ )
1813+
17781814
17791815VecValuesRecursive : TypeAlias = "Value | Vec[VecValuesRecursive]"
17801816
@@ -1825,6 +1861,22 @@ def _ndarray(
18251861 # if_
18261862 rewrite (NDArray .if_ (TRUE , xt , x1t ), subsume = True ).to (xt ()),
18271863 rewrite (NDArray .if_ (FALSE , xt , x1t ), subsume = True ).to (x1t ()),
1864+ # to RecursiveValue
1865+ rewrite (NDArray (rv ).to_recursive_value (), subsume = True ).to (rv ),
1866+ rewrite (NDArray .fn ((), dtype , idx_fn ).to_recursive_value (), subsume = True ).to (
1867+ RecursiveValue (idx_fn (TupleInt (())))
1868+ ),
1869+ rule (
1870+ eq (x ).to (NDArray .fn (TupleInt (vi ), dtype , idx_fn )),
1871+ x .to_recursive_value (),
1872+ vi .length () > i64 (0 ),
1873+ eq (vi [0 ]).to (Int (i )),
1874+ ).then (
1875+ union (x .to_recursive_value ()).with_ (
1876+ RecursiveValue .vec (i .range ().map (lambda j : x .partial_index (j ).to_recursive_value ()))
1877+ ),
1878+ subsume (x .to_recursive_value ()),
1879+ ),
18281880 ]
18291881
18301882
0 commit comments