检查tf.Tensor是否可变

时间:2019-07-18 16:06:29

标签: python tensorflow

如何检查tf.Tensor是否可变?

我想断言函数的参数具有正确的类型。

tf.tensor可能是可变的:

import tensorflow as tf
import numpy as np
x = tf.get_variable('x', shape=(2,), dtype=np.float32)
print(x[1])  # x[1] is a tf.Tensor
tf.assign(x[1], 1.0)

2 个答案:

答案 0 :(得分:1)

这不是公共API的一部分,但看看tf.assign的实现方式,我想您可以这样做:

import tensorflow as tf

def is_assignable(x):
    return x.dtype._is_ref_dtype or (isinstance(x, tf.Tensor) and hasattr(x, 'assign'))

答案 1 :(得分:0)

您可以检查他们的dtype attributes,例如断言my_tensor.dtype == tf.float32

Tensors are immutable在变量之外:它们描述数量之间的关系。除非将type cast操作添加到图形(添加边),否则数据类型不会更改。如果将值传递给类型与期望类型不同的张量,例如将数据加载到管道中时,会引发错误。您可以通过为张量分配错误的类型来检查这一点-您会得到一个错误。

尝试此代码

import tensorflow as tf
x = tf.get_variable('x', shape=(2,), dtype=tf.float32)
tf.assign(x[1], tf.ones(shape=(2,), dtype=tf.int32))

您应该得到以下错误:“ TypeError:“ StridedSliceAssign”的输入“值” Op的类型为int32的类型与参数“ ref”的float32的类型不匹配。”