-
Notifications
You must be signed in to change notification settings - Fork 238
Bug: Shape info in tensor breaks forward pass in some PyTorch ops #988
Copy link
Copy link
Closed
Labels
DocArray v2 — This issue is part of the rewrite; not to be merged into main
Description
Context
With TorchTensor[512] (a parametrized tensor carrying shape info) the forward pass breaks in some cases, e.g. when an op such as expand_as is called on it.
error
TypeError Traceback (most recent call last)
Cell In[182], line 4
1 for i, batch in enumerate(loader):
2 #assert batch.image.tensor.shape == (2, 3, 224, 224)
3 print(model.encode_vision(batch.image).shape)
----> 4 print(model.encode_text(batch.text))
6 break
Cell In[176], line 12, in CLIP.encode_text(self, texts)
11 def encode_text(self, texts: Text) -> TorchTensor:
---> 12 return self.bert(input_ids = texts.tokens.input_ids, attention_mask = texts.tokens.attention_mask)
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:579, in DistilBertModel.forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
577 if inputs_embeds is None:
578 inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
--> 579 return self.transformer(
580 x=inputs_embeds,
581 attn_mask=attention_mask,
582 head_mask=head_mask,
583 output_attentions=output_attentions,
584 output_hidden_states=output_hidden_states,
585 return_dict=return_dict,
586 )
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:354, in Transformer.forward(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict)
351 if output_hidden_states:
352 all_hidden_states = all_hidden_states + (hidden_state,)
--> 354 layer_outputs = layer_module(
355 x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
356 )
357 hidden_state = layer_outputs[-1]
359 if output_attentions:
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:289, in TransformerBlock.forward(self, x, attn_mask, head_mask, output_attentions)
279 """
280 Parameters:
281 x: torch.tensor(bs, seq_length, dim)
(...)
286 torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
287 """
288 # Self-Attention
--> 289 sa_output = self.attention(
290 query=x,
291 key=x,
292 value=x,
293 mask=attn_mask,
294 head_mask=head_mask,
295 output_attentions=output_attentions,
296 )
297 if output_attentions:
298 sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:215, in MultiHeadSelfAttention.forward(self, query, key, value, mask, head_mask, output_attentions)
213 q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
214 scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
--> 215 mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
216 scores = scores.masked_fill(
217 mask, torch.tensor(torch.finfo(scores.dtype).min)
218 ) # (bs, n_heads, q_length, k_length)
220 weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
TypeError: no implementation found for 'torch.Tensor.expand_as' on types that implement __torch_function__: [<class 'docarray.typing.tensor.abstract_tensor.TorchTensor[512]'>, <class 'docarray.typing.tensor.abstract_tensor.TorchTensor[512]'>]
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
DocArray v2 — This issue is part of the rewrite; not to be merged into main
Type
Projects
Status
Done