Skip to content

Bug: Shape info in tensor breaks forward in some PyTorch ops #988

@samsja

Description

@samsja

Context

With TorchTensor[512], this breaks in some cases — e.g. when an op such as `expand_as` is called on the shaped tensor type (see the traceback below).

error

TypeError                                 Traceback (most recent call last)
Cell In[182], line 4
      1 for i, batch in enumerate(loader):
      2     #assert batch.image.tensor.shape == (2, 3, 224, 224)
      3     print(model.encode_vision(batch.image).shape)
----> 4     print(model.encode_text(batch.text))
      6     break

Cell In[176], line 12, in CLIP.encode_text(self, texts)
     11 def encode_text(self, texts: Text) -> TorchTensor:
---> 12     return self.bert(input_ids = texts.tokens.input_ids, attention_mask = texts.tokens.attention_mask)

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:579, in DistilBertModel.forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    577 if inputs_embeds is None:
    578     inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
--> 579 return self.transformer(
    580     x=inputs_embeds,
    581     attn_mask=attention_mask,
    582     head_mask=head_mask,
    583     output_attentions=output_attentions,
    584     output_hidden_states=output_hidden_states,
    585     return_dict=return_dict,
    586 )

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:354, in Transformer.forward(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict)
    351 if output_hidden_states:
    352     all_hidden_states = all_hidden_states + (hidden_state,)
--> 354 layer_outputs = layer_module(
    355     x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
    356 )
    357 hidden_state = layer_outputs[-1]
    359 if output_attentions:

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:289, in TransformerBlock.forward(self, x, attn_mask, head_mask, output_attentions)
    279 """
    280 Parameters:
    281     x: torch.tensor(bs, seq_length, dim)
   (...)
    286     torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
    287 """
    288 # Self-Attention
--> 289 sa_output = self.attention(
    290     query=x,
    291     key=x,
    292     value=x,
    293     mask=attn_mask,
    294     head_mask=head_mask,
    295     output_attentions=output_attentions,
    296 )
    297 if output_attentions:
    298     sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.config/JetBrains/PyCharmCE2022.3/scratches/laion/venv/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py:215, in MultiHeadSelfAttention.forward(self, query, key, value, mask, head_mask, output_attentions)
    213 q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
    214 scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
--> 215 mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
    216 scores = scores.masked_fill(
    217     mask, torch.tensor(torch.finfo(scores.dtype).min)
    218 )  # (bs, n_heads, q_length, k_length)
    220 weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)

TypeError: no implementation found for 'torch.Tensor.expand_as' on types that implement __torch_function__: [<class 'docarray.typing.tensor.abstract_tensor.TorchTensor[512]'>, <class 'docarray.typing.tensor.abstract_tensor.TorchTensor[512]'>]

Metadata

Metadata

Labels

DocArray v2 — This issue is part of the rewrite; not to be merged into main

Type

No type

Projects

Status

Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions