我想对tf.Tensor
类进行子类化。这个想法是,子类的对象的行为应类似于Tensors(即,我可以使用它们来进行任何类型的tf操作),但它们还应该拥有其他属性,这些属性可以在框架内为它们提供特定的行为。
到目前为止,我是在Graph模式下工作的,我只是做这样的事情:
class EnrichedTensor(tf.Tensor):
def __init__(self, tensor, other_stuff):
super(EnrichedTensor, self).__init__((
op=tensor.op,
value_index=tensor.value_index,
dtype=tensor.dtype)
self.other_stuff = other_stuff
现在,我只想在热切的模式下执行相同的操作,但是我真的不知道(而且我什么也没找到)EagerTensor
实例化。显然,op
属性不再有意义。
我尝试通过__new__
方法处理对象的 creation ,但是在将EnrichedTensor
子类化并遵循创建路径时发现了问题。
因此,我想知道是否有任何方法可以以“声音”方式干净地进行此操作。