Extending tf.Tensor class #59472

@anna-charlotte

Description

Issue Type

Support

Have you reproduced the bug with TF nightly?

Yes

Source

binary

Tensorflow Version

2.11.0

Custom Code

Yes

OS Platform and Distribution

MacOS Ventura 13.0

Mobile device

No response

Python version

3.10.8

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current Behaviour?

I want to subclass the tf.Tensor class, but neither of the following approaches works:

Option 1: subclass tf.Tensor
import tensorflow as tf

class MyTFTensor(tf.Tensor):

    @classmethod
    def _from_native(cls, value: tf.Tensor):
        # Re-label an existing eager tensor as MyTFTensor in place
        value.__class__ = cls
        return value

y = MyTFTensor._from_native(value=tf.zeros((3, 224, 224)))

Fails with:

/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/515175960.py in _from_native(cls, value)
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):
----> 6         value.__class__ = cls
      7         return value

TypeError: __class__ assignment: 'MyTFTensor' object layout differs from 'tensorflow.python.framework.ops.EagerTensor'
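
This error looks like a general CPython rule rather than something TensorFlow-specific: `__class__` can only be reassigned between types whose instance layouts are compatible. `EagerTensor` is implemented in C, so its instances presumably carry storage that a plain Python subclass of `tf.Tensor` does not have. The sketch below reproduces the same kind of failure in pure Python; all class names are hypothetical stand-ins, not TensorFlow types.

```python
# Hypothetical stand-ins (not TensorFlow classes):
#   Base       ~ tf.Tensor, a plain Python base class
#   NativeLike ~ EagerTensor, whose instances carry extra per-object storage
#   MySubclass ~ MyTFTensor, a Python subclass of Base with no extra storage
class Base:
    pass


class NativeLike(Base):
    __slots__ = ("handle",)  # extra fixed field => different instance layout


class MySubclass(Base):
    pass


class OtherSubclass(Base):
    pass


def retype(obj, new_cls):
    """Try to change obj's class in place; return the TypeError message or None."""
    try:
        obj.__class__ = new_cls
    except TypeError as exc:
        return str(exc)
    return None


# Fails, like option 1: NativeLike and MySubclass have different layouts.
print(retype(NativeLike(), MySubclass))

# Succeeds: MySubclass and OtherSubclass add nothing to Base's layout.
print(retype(MySubclass(), OtherSubclass))
```

Carried over to TensorFlow, this suggests a reassignment can only succeed when the target class shares `EagerTensor`'s layout, which in practice means subclassing `EagerTensor` itself.
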

Option 2: subclass EagerTensor
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor

# The class statement itself raises the TypeError below
class MyTFTensor(EagerTensor):

    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

Fails with:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/kb/yxxdttyj4qzcm447np5p22kw0000gp/T/ipykernel_94703/3632871733.py in <cell line: 2>()
      1 from tensorflow.python.framework.ops import EagerTensor
----> 2 class MyTFTensor(EagerTensor):
      3 
      4     @classmethod
      5     def _from_native(cls, value: tf.Tensor):

TypeError: type 'tensorflow.python.framework.ops.EagerTensor' is not an acceptable base type
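
This message also comes from CPython itself: a type can only be used as a base class if it allows subclassing (the `Py_TPFLAGS_BASETYPE` flag at the C level), and `EagerTensor` evidently does not set it. The built-in `bool` has the same restriction, so this pure-Python sketch triggers the identical message without TensorFlow. The helper name `can_subclass` is hypothetical.

```python
def can_subclass(base):
    """Return True if Python accepts `base` as a base class, False otherwise."""
    try:
        # Same check a `class Probe(base): ...` statement performs
        type("Probe", (base,), {})
    except TypeError as exc:
        print(exc)  # for bool: type 'bool' is not an acceptable base type
        return False
    return True


print(can_subclass(int))   # True: int allows subclassing
print(can_subclass(bool))  # False: bool is final
```

If this reading is right, `EagerTensor` cannot be subclassed from Python at all, so neither approach can make an eager tensor an instance of a user-defined class.
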

Our goal, though, is to actually subclass tf.Tensor: we don't want to store a tf.Tensor instance as an attribute of MyTFTensor (a wrapper), we want MyTFTensor itself to be a tf.Tensor!

Standalone code to reproduce the issue

import tensorflow as tf

# 1st option
class MyTFTensor(tf.Tensor):
    
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

y = MyTFTensor._from_native(value=tf.zeros((3, 224, 224)))

# 2nd option
from tensorflow.python.framework.ops import EagerTensor
class MyTFTensor(EagerTensor):
    
    @classmethod
    def _from_native(cls, value: tf.Tensor):
        value.__class__ = cls
        return value

Relevant log output

No response

Metadata

Labels

TF 2.11, comp:ops, stale, stat:awaiting response, type:support
