 from tinygrad import Tensor, dtypes
 from tinygrad.helpers import DEBUG, getenv, prod
+import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
 import torch, pathlib, math, operator
 torch.autograd.grad_mode.set_multithreading_enabled(False)
@@ -38,33 +39,19 @@ def masked_select(self, mask):
   # err, bad
   return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))

+from tinygrad.shape.shapetracker import ShapeTracker, View
+from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
 @torch.library.impl("aten::as_strided", "privateuseone")
 def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
-  #return tensor.cpu().as_strided(size, stride).tiny()
-  if TORCH_DEBUG >= 1: print("** NOTE: this as_strided might be wrong", tensor.shape, size, stride, storage_offset)
-
-  nz_strides = [st for s,st in zip(size, stride) if s != 1]
-  decending_strides = all(x >= y for x,y in zip(nz_strides[:-1], nz_strides[1:]))
-
-  # this is reshape (squeeze/unsqueeze), strides must be in decending order
-  if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1) and decending_strides:
-    return tensor.reshape(size)
-
-  # this is also expand, hit?
-  if tensor.numel() == 1:
-    assert all(x == 0 for x in stride)
-    return wrap(unwrap(tensor).reshape([1]*len(size)).expand(size))
-
-  # this is expand
-  if len(tensor.shape) == len(size) and all(x == y or x == 1 for x,y in zip(tensor.shape, size)) and decending_strides:
-    return wrap(unwrap(tensor).expand(size))
-
-  # this is permute because we are flipping strides
-  if len(tensor.shape) == 2 and tuple(tensor.shape)[::-1] == tuple(size) and stride == [0, 1]:
-    return wrap(unwrap(tensor).permute(1,0))
-
-  #print(tensor.cpu().numpy())
-  raise NotImplementedError(f"fix as_strided {tensor.shape} -> {size} {stride} {storage_offset}")
+  # TODO: this is heavyweight
+  st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)])
+  ret = unwrap(tensor)
+  if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
+  if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
+  mops = to_movement_ops(st)
+  if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:]
+  for mo in mops: ret = apply_mop(ret, mo)
+  return wrap(ret)

 @torch.library.impl("aten::empty_strided", "privateuseone")
 def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
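
For readers unfamiliar with as_strided: the rewritten handler above has to reproduce raw strided indexing using tinygrad movement ops. A minimal numpy sketch of the semantics being matched (illustrative only, not part of this commit; torch strides are in elements, numpy strides in bytes):

import numpy as np
# as_strided(size, stride, offset) reads view element [i0, i1, ...] from flat
# storage at offset + i0*stride[0] + i1*stride[1] + ...
base = np.arange(6, dtype=np.float32)                                         # flat storage of a (2, 3) tensor
view = np.lib.stride_tricks.as_strided(base, shape=(3, 2), strides=(4, 12))   # element strides (1, 3)
assert (view == base.reshape(2, 3).T).all()                                   # a (1, 3)-strided view is a transpose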
@@ -85,6 +72,33 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
   # TODO: this is wrong
   return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))

+@torch.library.impl("aten::arange", "privateuseone")
+def arange(end, dtype=None, device=None, pin_memory=None):
+  return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
+
+@torch.library.impl("aten::arange.start", "privateuseone")
+def arange_start(start, end, dtype=None, device=None, pin_memory=None):
+  return wrap(Tensor.arange(start, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
+
+@torch.library.impl("aten::arange.start_step", "privateuseone")
+def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
+  return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
+
+@torch.library.impl("aten::topk", "privateuseone")
+def topk(self, k, dim=-1, largest=True, sorted=True):
+  # TODO: move to tinygrad
+  t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
+  return torch.return_types.topk((t1.tiny(), t2.tiny()))
+
+@torch.library.impl("aten::_index_put_impl_", "privateuseone")
+def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
+  # TODO: move to tinygrad
+  return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
+
+@torch.library.impl("aten::index.Tensor", "privateuseone")
+def index_tensor(x, y):
+  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
+
 @torch.library.impl("aten::convolution_overrideable", "privateuseone")
 def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
   if TORCH_DEBUG >= 1:
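
The arange variants above are implemented natively with Tensor.arange, while topk, _index_put_impl_ and index.Tensor round-trip through the CPU for now. A hedged usage sketch, assuming the backend registers its device as "tiny" and the .tiny()/.cpu() helpers defined elsewhere in this file:

x = torch.arange(10, device="tiny")   # dispatches to the aten::arange impl above
vals, idxs = torch.topk(x, 3)         # CPU fallback, results moved back via .tiny()
print(vals.cpu(), idxs.cpu())         # tensor([9, 8, 7]) tensor([9, 8, 7])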
@@ -103,21 +117,23 @@ def _copy_from(src, dest):
     dest.copy_(torch.from_numpy(unwrap(src).numpy()))
   elif str(src.device) == "cpu" and str(dest.device) == "tiny":
     unwrap(dest).assign(Tensor(src.numpy()))
+    #if 0 in dest.stride():
+    #  print(dest.shape, dest.stride())
+    #  exit(0)
   else:
     raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")

 @torch.library.impl("aten::cat.out", "privateuseone")
-def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
-
-@torch.library.impl("aten::index.Tensor", "privateuseone")
-def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
+def cat_out(tensors, dim=0, out=None):
+  unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)

 # register some decompositions
 from torch._decomp import get_decompositions
 aten = torch.ops.aten
 decomps = {
   "post_autograd": [
     aten.native_batch_norm, aten.native_batch_norm_backward,
+    aten.native_layer_norm_backward,
     aten.addmm,
     aten.addcmul,
     aten.addcdiv,
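
get_decompositions (imported above) turns this op list into a table of Python decompositions written in terms of simpler aten ops, so the backend only has to implement the primitives. A sketch of the usual pattern, assuming the actual registration call appears later in the file:

from torch._decomp import get_decompositions
table = get_decompositions(decomps["post_autograd"])
# table maps OpOverloads (e.g. aten.addmm.default) to decomposition functions
# that can then be registered for the "privateuseone" dispatch key.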
@@ -142,6 +158,10 @@ def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
     aten.nan_to_num,
     aten.logit,
     aten.rsub,
+    aten.index_select,
+    aten.native_dropout, aten.native_dropout_backward,
+    aten._softmax_backward_data, aten.embedding_dense_backward,
+    aten.linalg_vector_norm,
     # activations
     aten.hardswish, aten.hardswish_backward,
     aten.hardtanh, aten.hardtanh_backward,
@@ -194,11 +214,13 @@ def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
194214 "aten.add.out" : lambda input ,other ,alpha = 1 : input + alpha * other ,
195215 "aten.sub.out" : lambda input ,other ,alpha = 1 : input - alpha * other , # NOTE: this is also needed to handle reverse
196216 "aten.mul.out" : operator .mul ,
217+ "aten.bmm.out" : operator .matmul ,
197218 "aten.leaky_relu.out" : Tensor .leakyrelu , # TODO: this should be renamed in tinygrad
198219 # NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
199220 "aten.remainder.Tensor_out" : Tensor .mod ,
200221 "aten.pow.Tensor_Tensor_out" : Tensor .pow ,
201222 "aten.pow.Tensor_Scalar_out" : Tensor .pow ,
223+ "aten.pow.Scalar_out" : lambda x ,y : x ** y ,
202224 "aten.bitwise_and.Tensor_out" : Tensor .bitwise_and ,
203225 "aten.bitwise_or.Tensor_out" : Tensor .bitwise_or ,
204226 "aten.bitwise_xor.Tensor_out" : lambda x ,y : x ^ y , # TODO: tinygrad lacks bitwise_xor, add it
@@ -229,10 +251,15 @@ def _wrap_out(*args, **kwargs):

 tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.view": Tensor.reshape,
+  "aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
+  "aten.remainder.Scalar_Tensor": lambda x,y: x%y,
   "aten.floor_divide": lambda x,y: x//y,
+  "aten.floor_divide_.Tensor": lambda x,y: x.assign(x//y),
   # TODO: use tinygrad methods, but they require x to be unsigned
   "aten.__lshift__.Scalar": lambda x,y: x*(2**y),
+  "aten.__ilshift__.Scalar": lambda x,y: x.assign(x*(2**y)),
   "aten.__rshift__.Scalar": lambda x,y: x//(2**y),
+  "aten.__irshift__.Scalar": lambda x,y: x.assign(x//(2**y)),
   # relu doesn't have an out form?
   "aten.relu": Tensor.relu,
   "aten.relu_": lambda x: x.assign(x.relu()),
@@ -251,7 +278,7 @@ def _wrap_out(*args, **kwargs):
251278 "aten.var_mean.correction" : lambda self , dims , keepdim = False , correction = 1 : (self .var (dims , keepdim , correction ), self .mean (dims , keepdim )),
252279 # NOTE: axis=[] in torch means all, change tinygrad?
253280 "aten.sum.IntList_out" : lambda self ,axis ,keepdim = False ,out = None :
254- out .replace (Tensor .sum (self , axis if len (axis ) else None , keepdim ), allow_shape_mismatch = True ),
281+ out .replace (Tensor .sum (self , axis if axis is None or len (axis ) else None , keepdim ), allow_shape_mismatch = True ),
255282 "aten.scatter.value" : Tensor .scatter ,
256283 "aten.gather" : Tensor .gather ,
257284 "aten.where.self" : Tensor .where ,
@@ -266,11 +293,14 @@ def _wrap_out(*args, **kwargs):
   # these don't work in out form, they have size 0
   "aten.abs": Tensor.abs,
   "aten.logical_not": Tensor.logical_not,
+  "aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
+  "aten.multinomial": Tensor.multinomial,
 }}

 def wrap_fxn(k,f):
   def nf(*args, **kwargs):
-    if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
+    if TORCH_DEBUG:
+      print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
         {k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
     args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
     kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
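
A quick illustration (not part of the commit) of the shift emulation registered above, assuming Python's << and >> on a "tiny" tensor dispatch to aten.__lshift__.Scalar / aten.__rshift__.Scalar: left shift becomes multiplication by a power of two and right shift becomes floor division, which matches torch for non-negative integers.

x = torch.tensor([1, 2, 3], device="tiny")
print((x << 3).cpu())   # tensor([ 8, 16, 24])
print((x >> 1).cpu())   # tensor([0, 1, 1])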
@@ -300,8 +330,9 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
       tinygrad_tensors.append(param.data)
   for state_dict in optimizer.state.values():
     for key, value in state_dict.items():
-      if torch.is_tensor(value) and str(value.device) == "tiny": tinygrad_tensors.append(value)
-  Tensor.realize(*[unwrap(x) for x in tinygrad_tensors])
+      if torch.is_tensor(value): tinygrad_tensors.append(value)
+  real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
+  if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)

 _optimizer_init = torch.optim.Optimizer.__init__
 def _optimizer_patched_init(self, *args, **kwargs):
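
The body of the patched __init__ is cut off by this hunk. A hedged guess at the intent, based on the names above: call the saved _optimizer_init, then register realize_optimizer_step as a post-step hook so every optimizer.step() realizes all "tiny" params and state in one batch.

def _optimizer_patched_init(self, *args, **kwargs):
  _optimizer_init(self, *args, **kwargs)
  self.register_step_post_hook(realize_optimizer_step)   # assumption: torch >= 2.0 hook API
torch.optim.Optimizer.__init__ = _optimizer_patched_init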