Commit 2158dc4 (1 parent: f60f997)

full fix for as_strided in torch backend (tinygrad#9257)

* fixes from chatgpt for torch backend
* shrink support
* add stride support
* comment cleanup
* a few more
* work
* import the stream hack
* llvm multi auto

5 files changed: 189 additions & 50 deletions
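The headline change: aten::as_strided on the tiny backend no longer pattern-matches a handful of reshape/expand/permute special cases; it builds a ShapeTracker for the requested view and replays it as tinygrad movement ops (see extra/torch_backend/backend.py below). A minimal sketch of what this enables, assuming, as in the repo's examples, that importing extra/torch_backend/backend.py registers the "tiny" device:

import torch
import extra.torch_backend.backend  # noqa: F401  (assumed to register the "tiny" device)

a = torch.arange(6., device="tiny").reshape(2, 3)
b = a.as_strided((3, 2), (1, 3))  # a transpose expressed purely through strides
print(b.cpu().numpy())            # expected: [[0. 3.] [1. 4.] [2. 5.]]

Before this commit, a strided view like this fell through all the special cases and raised NotImplementedError.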

.github/workflows/test.yml

Lines changed: 4 additions & 3 deletions
@@ -155,7 +155,8 @@ jobs:
       with:
         key: torch-backend-pillow-torchvision-et-pt
         deps: testing_minimal
-        pydeps: "pillow torchvision expecttest pytest"
+        pydeps: "pillow torchvision expecttest"
+        llvm: 'true'
     - name: Install ninja
       run: |
         sudo apt update || true
@@ -169,11 +170,11 @@ jobs:
     - name: Test one op in torch tests
       run: PYTHONPATH=. DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
     - name: Test Ops with TINY_BACKEND (expect failure)
-      run: PYTHONPATH=. TINY_BACKEND=1 python3 -m pytest test/test_ops.py || true
+      run: PYTHONPATH=. LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py || true
     - name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)
       run: PYTHONPATH=. TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py || true
     - name: Test some torch tests (expect failure)
-      run: PYTHONPATH=. pytest extra/torch_backend/torch_tests.py -v --tb=no || true
+      run: PYTHONPATH=. python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
 
   tc:
     name: Tensor Core tests

extra/to_movement_ops.py

Lines changed: 13 additions & 7 deletions
@@ -6,23 +6,26 @@
 from tinygrad.helpers import prod, tqdm
 from tinygrad.ops import UOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.ops import sym_infer, Node
+from tinygrad.ops import sym_infer
+from tinygrad.tensor import Tensor
 
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
 
-def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
+def apply_mop(st: Tensor|ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
   mop, arg = mop_arg
   if mop == MovementOps.RESHAPE:
     # shapetracker doesn't allow flattening with -1 but required for MovementOps.RESHAPE
-    if arg == (-1,): return st.reshape((prod(st.views[-1].shape),))
+    if arg == (-1,): return st.reshape((prod(st.shape),))
     return st.reshape(arg)
   if mop == MovementOps.PERMUTE: return st.permute(arg)
   if mop == MovementOps.EXPAND:
     if len(arg) != len(st.shape): st = st.reshape((1,*st.shape))
     return st.expand(arg)
   if mop == MovementOps.PAD: return st.pad(arg)
   if mop == MovementOps.SHRINK: return st.shrink(arg)
-  if mop == MovementOps.STRIDE: return st.stride(arg)
+  if mop == MovementOps.STRIDE:
+    assert all(x in [-1, 1] for x in arg)
+    return st.flip(tuple(i for i,x in enumerate(arg) if x == -1))
   raise ValueError("invalid mop")
 
 def make_scratch_st(st: ShapeTracker) -> ShapeTracker:
@@ -36,7 +39,7 @@ def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
     offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
     real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
     real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
-    strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
+    strides: List[int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
     buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
     if i: buffer_size = prod(st.views[i-1].shape) - real_offset
     def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
@@ -80,9 +83,12 @@ def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lamb
     if scratch_st in seen:
       ret = seen[scratch_st][:]
     else:
-      ret.append(mop_arg)
+      if len(ret) and ret[-1][0] == MovementOps.RESHAPE and mop_arg[0] == MovementOps.RESHAPE:
+        ret[-1] = mop_arg
+      else:
+        if mop_arg == (MovementOps.RESHAPE, -1): mop_arg = (MovementOps.RESHAPE, (prod(st.shape),))
+        ret.append(mop_arg)
     seen[scratch_st] = ret[:]
-
   return ret
 
 def get_real_view(shape, strides, offset, mask):
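The STRIDE branch above now lowers to flip: a stride of -1 along an axis is exactly an axis reversal, and the new assert documents that nothing other than ±1 ever reaches this branch. A small sketch of the equivalence on a plain tinygrad Tensor (values are hypothetical):

from tinygrad.tensor import Tensor

t = Tensor([[1, 2, 3], [4, 5, 6]])
arg = (1, -1)  # STRIDE arg: keep axis 0, reverse axis 1
# same expression as the apply_mop branch: flip every axis whose stride is -1
flipped = t.flip(tuple(i for i, x in enumerate(arg) if x == -1))
print(flipped.tolist())  # [[3, 2, 1], [6, 5, 4]]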

extra/torch_backend/backend.py

Lines changed: 64 additions & 33 deletions
@@ -1,5 +1,6 @@
 from tinygrad import Tensor, dtypes
 from tinygrad.helpers import DEBUG, getenv, prod
+import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
 import torch, pathlib, math, operator
 torch.autograd.grad_mode.set_multithreading_enabled(False)
@@ -38,33 +39,19 @@ def masked_select(self, mask):
   # err, bad
   return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
 
+from tinygrad.shape.shapetracker import ShapeTracker, View
+from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps
 @torch.library.impl("aten::as_strided", "privateuseone")
 def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
-  #return tensor.cpu().as_strided(size, stride).tiny()
-  if TORCH_DEBUG >= 1: print("** NOTE: this as_strided might be wrong", tensor.shape, size, stride, storage_offset)
-
-  nz_strides = [st for s,st in zip(size, stride) if s != 1]
-  decending_strides = all(x>=y for x,y in zip(nz_strides[:-1], nz_strides[1:]))
-
-  # this is reshape (squeeze/unsqueeze), strides must be in decending order
-  if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1) and decending_strides:
-    return tensor.reshape(size)
-
-  # this is also expand, hit?
-  if tensor.numel() == 1:
-    assert all(x == 0 for x in stride)
-    return wrap(unwrap(tensor).reshape([1]*len(size)).expand(size))
-
-  # this is expand
-  if len(tensor.shape) == len(size) and all(x == y or x == 1 for x,y in zip(tensor.shape, size)) and decending_strides:
-    return wrap(unwrap(tensor).expand(size))
-
-  # this is permute because we are flipping strides
-  if len(tensor.shape) == 2 and tuple(tensor.shape)[::-1] == tuple(size) and stride == [0, 1]:
-    return wrap(unwrap(tensor).permute(1,0))
-
-  #print(tensor.cpu().numpy())
-  raise NotImplementedError(f"fix as_strided {tensor.shape} -> {size} {stride} {storage_offset}")
+  # TODO: this is heavyweight
+  st = ShapeTracker([View.create(tuple(tensor.shape)), View.create(tuple(size), tuple(stride), 0 if storage_offset is None else storage_offset)])
+  ret = unwrap(tensor)
+  if prod(size) == 1: return wrap(ret.flatten()[storage_offset].reshape(size))
+  if TORCH_DEBUG >= 1: print("**** as_strided", tensor.shape, size, stride, st)
+  mops = to_movement_ops(st)
+  if mops[0] == (MovementOps.RESHAPE, tuple(tensor.shape)): mops = mops[1:]
+  for mo in mops: ret = apply_mop(ret, mo)
+  return wrap(ret)
 
 @torch.library.impl("aten::empty_strided", "privateuseone")
 def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
@@ -85,6 +72,33 @@ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, di
   # TODO: this is wrong
   return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
 
+@torch.library.impl("aten::arange", "privateuseone")
+def arange(end, dtype=None, device=None, pin_memory=None):
+  return wrap(Tensor.arange(0, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
+
+@torch.library.impl("aten::arange.start", "privateuseone")
+def arange_start(start, end, dtype=None, device=None, pin_memory=None):
+  return wrap(Tensor.arange(start, end, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
+
+@torch.library.impl("aten::arange.start_step", "privateuseone")
+def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None):
+  return wrap(Tensor.arange(start, end, step, dtype=_from_torch_dtype(dtype or torch.get_default_dtype())))
+
+@torch.library.impl("aten::topk", "privateuseone")
+def topk(self, k, dim=-1, largest=True, sorted=True):
+  # TODO: move to tinygrad
+  t1, t2 = torch.topk(self.cpu(), k, dim, largest, sorted)
+  return torch.return_types.topk((t1.tiny(), t2.tiny()))
+
+@torch.library.impl("aten::_index_put_impl_", "privateuseone")
+def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
+  # TODO: move to tinygrad
+  return aten._index_put_impl_(self.cpu(), [x.cpu() for x in indices], values.cpu(), accumulate, unsafe).tiny()
+
+@torch.library.impl("aten::index.Tensor", "privateuseone")
+def index_tensor(x, y):
+  return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).tiny()
+
 @torch.library.impl("aten::convolution_overrideable", "privateuseone")
 def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
   if TORCH_DEBUG >= 1:
@@ -103,21 +117,23 @@ def _copy_from(src, dest):
     dest.copy_(torch.from_numpy(unwrap(src).numpy()))
   elif str(src.device) == "cpu" and str(dest.device) == "tiny":
     unwrap(dest).assign(Tensor(src.numpy()))
+    #if 0 in dest.stride():
+    #  print(dest.shape, dest.stride())
+    #  exit(0)
   else:
     raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
 
 @torch.library.impl("aten::cat.out", "privateuseone")
-def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
-
-@torch.library.impl("aten::index.Tensor", "privateuseone")
-def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
+def cat_out(tensors, dim=0, out=None):
+  unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
 
 # register some decompositions
 from torch._decomp import get_decompositions
 aten = torch.ops.aten
 decomps = {
   "post_autograd": [
     aten.native_batch_norm, aten.native_batch_norm_backward,
+    aten.native_layer_norm_backward,
     aten.addmm,
     aten.addcmul,
     aten.addcdiv,
@@ -142,6 +158,10 @@ def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
     aten.nan_to_num,
     aten.logit,
     aten.rsub,
+    aten.index_select,
+    aten.native_dropout, aten.native_dropout_backward,
+    aten._softmax_backward_data, aten.embedding_dense_backward,
+    aten.linalg_vector_norm,
     # activations
     aten.hardswish, aten.hardswish_backward,
     aten.hardtanh, aten.hardtanh_backward,
@@ -194,11 +214,13 @@ def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
   "aten.add.out": lambda input,other,alpha=1: input+alpha*other,
   "aten.sub.out": lambda input,other,alpha=1: input-alpha*other, # NOTE: this is also needed to handle reverse
   "aten.mul.out": operator.mul,
+  "aten.bmm.out": operator.matmul,
   "aten.leaky_relu.out": Tensor.leakyrelu, # TODO: this should be renamed in tinygrad
   # NOTE: because these methods have a name with "Tensor" in them, they can't go in simple tensor methods
   "aten.remainder.Tensor_out": Tensor.mod,
   "aten.pow.Tensor_Tensor_out": Tensor.pow,
   "aten.pow.Tensor_Scalar_out": Tensor.pow,
+  "aten.pow.Scalar_out": lambda x,y: x**y,
   "aten.bitwise_and.Tensor_out": Tensor.bitwise_and,
   "aten.bitwise_or.Tensor_out": Tensor.bitwise_or,
   "aten.bitwise_xor.Tensor_out": lambda x,y: x^y, # TODO: tinygrad lacks bitwise_xor, add it
@@ -229,10 +251,15 @@ def _wrap_out(*args, **kwargs):
 
 tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.view": Tensor.reshape,
+  "aten._unsafe_view": Tensor.reshape, # when are views unsafe, and do we care?
+  "aten.remainder.Scalar_Tensor": lambda x,y: x%y,
   "aten.floor_divide": lambda x,y: x//y,
+  "aten.floor_divide_.Tensor": lambda x,y: x.assign(x//y),
   # TODO: use tinygrad methods, but they require x to be unsigned
   "aten.__lshift__.Scalar": lambda x,y: x*(2**y),
+  "aten.__ilshift__.Scalar": lambda x,y: x.assign(x*(2**y)),
   "aten.__rshift__.Scalar": lambda x,y: x//(2**y),
+  "aten.__irshift__.Scalar": lambda x,y: x.assign(x//(2**y)),
   # relu doesn't have an out form?
   "aten.relu": Tensor.relu,
   "aten.relu_": lambda x: x.assign(x.relu()),
@@ -251,7 +278,7 @@ def _wrap_out(*args, **kwargs):
   "aten.var_mean.correction": lambda self, dims, keepdim=False, correction=1: (self.var(dims, keepdim, correction), self.mean(dims, keepdim)),
   # NOTE: axis=[] in torch means all, change tinygrad?
   "aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
-    out.replace(Tensor.sum(self, axis if len(axis) else None, keepdim), allow_shape_mismatch=True),
+    out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
   "aten.scatter.value": Tensor.scatter,
   "aten.gather": Tensor.gather,
   "aten.where.self": Tensor.where,
@@ -266,11 +293,14 @@ def _wrap_out(*args, **kwargs):
   # these don't work in out form, they have size 0
   "aten.abs": Tensor.abs,
   "aten.logical_not": Tensor.logical_not,
+  "aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
+  "aten.multinomial": Tensor.multinomial,
 }}
 
 def wrap_fxn(k,f):
   def nf(*args, **kwargs):
-    if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
+    if TORCH_DEBUG:
+      print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
         {k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
     args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
     kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
@@ -300,8 +330,9 @@ def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):
       tinygrad_tensors.append(param.data)
   for state_dict in optimizer.state.values():
     for key, value in state_dict.items():
-      if torch.is_tensor(value) and str(value.device) == "tiny": tinygrad_tensors.append(value)
-  Tensor.realize(*[unwrap(x) for x in tinygrad_tensors])
+      if torch.is_tensor(value): tinygrad_tensors.append(value)
+  real_tinygrad_tensors = [unwrap(x) for x in tinygrad_tensors if str(x.device) == "tiny"]
+  if len(real_tinygrad_tensors): Tensor.realize(*real_tinygrad_tensors)
 
 _optimizer_init = torch.optim.Optimizer.__init__
 def _optimizer_patched_init(self, *args, **kwargs):
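The new as_strided is the consumer of the to_movement_ops changes: it describes the requested view as a two-view ShapeTracker, converts that into a list of movement ops, and replays them on the tinygrad Tensor (which apply_mop now accepts directly). A standalone sketch of that pipeline outside torch, with hypothetical shapes and an expected (not verified here) result:

from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker, View
from extra.to_movement_ops import to_movement_ops, apply_mop, MovementOps

t = Tensor.arange(6).reshape(2, 3)
size, stride, offset = (3, 2), (1, 3), 0  # a transpose, expressed as strides
st = ShapeTracker((View.create((2, 3)), View.create(size, stride, offset)))
mops = to_movement_ops(st)
# same skip as the backend: drop a leading no-op reshape to the input shape
if mops[0] == (MovementOps.RESHAPE, (2, 3)): mops = mops[1:]
ret = t
for mop in mops: ret = apply_mop(ret, mop)
print(ret.tolist())  # expected: [[0, 3], [1, 4], [2, 5]]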

extra/torch_backend/test.py

Lines changed: 6 additions & 3 deletions
@@ -50,6 +50,11 @@ def test_permute(self):
     np.testing.assert_equal(perm.cpu().numpy(), [[1,3],[2,4]])
     np.testing.assert_equal(back.cpu().numpy(), [[1,2],[3,4]])
 
+  def test_shrink(self):
+    a = torch.Tensor([1,2,3,4]).to(device)
+    np.testing.assert_equal(a[:3].cpu().numpy(), [1,2,3])
+    np.testing.assert_equal(a[1:].cpu().numpy(), [2,3,4])
+
   def test_plus_inplace(self):
     a = torch.ones(4, device=device)
     b = torch.ones(4, device=device)
@@ -66,15 +71,13 @@ def test_isfinite(self):
     a = torch.ones(4, device=device)
     np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])
 
-  @unittest.skip("broken")
   def test_eq(self):
     a = torch.ones(4, device=device)
     b = torch.ones(4, device=device)
     c = a == b
     print(c.cpu().numpy())
 
-  # TODO: why
-  @unittest.skip("broken")
+  @unittest.skip("meh")
   def test_str(self):
     a = torch.ones(4, device=device)
     print(str(a))
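The new test_shrink exercises the path that used to raise NotImplementedError: slicing lowers to as_strided with a non-trivial size and offset, which the decomposition turns into a SHRINK. The same check outside the test harness, again assuming the backend module registers the "tiny" device on import:

import torch
import extra.torch_backend.backend  # noqa: F401  (assumed to register the "tiny" device)

a = torch.Tensor([1, 2, 3, 4]).to("tiny")
print(a[1:3].cpu().numpy())  # expected: [2. 3.]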
