TF 2.0打印张量值

时间:2019-03-22 15:23:57

标签: python tensorflow tensorflow2.0

我正在学习Tensorflow(2.0)的最新版本,并且我试图运行一个简单的代码来切片矩阵。 使用装饰器@ tf.function我做了以下课程:

class Data:
def __init__(self):
    pass

def back_to_zero(self, input):
    time = tf.slice(input, [0,0], [-1,1])
    new_time = time - time[0][0]
    return new_time

@tf.function
def load_data(self, inputs):
    new_x = self.back_to_zero(inputs)
    print(new_x)

因此,当使用numpy矩阵运行代码时,我无法检索数字。

time = np.linspace(0,10,20)
magntiudes = np.random.normal(0,1,size=20)
x = np.vstack([time, magntiudes]).T


d = Data()
d.load_data(x)

输出:

Tensor("sub:0", shape=(20, 1), dtype=float64)

我需要以numpy格式获取此张量,但是TF 2.0没有类tf.Session可以使用run()或eval()方法。

感谢您能为我提供的任何帮助!

3 个答案:

答案 0 :(得分:1)

同样的问题,如何只打印张量的值?我已经失去了半天。慢慢地但一定要移到Pytorch ...

答案 1 :(得分:0)

在装饰器@tf.function所指示的图形内,您可以使用tf.print打印张量的值。

tf.print(new_x)

这是重写代码的方式

class Data:
    def __init__(self):
        pass

    def back_to_zero(self, input):
        time = tf.slice(input, [0,0], [-1,1])
        new_time = time - time[0][0]
        return new_time

    @tf.function
    def load_data(self, inputs):
        new_x = self.back_to_zero(inputs)
        tf.print(new_x) # print inside the graph context
        return new_x

time = np.linspace(0,10,20)
magntiudes = np.random.normal(0,1,size=20)
x = np.vstack([time, magntiudes]).T

d = Data()
data = d.load_data(x)
print(data) # print outside the graph context

tf.decorator上下文之外的张量类型为tensorflow.python.framework.ops.EagerTensor。要将其转换为numpy数组,可以使用data.numpy()

答案 2 :(得分:0)

问题是您无法直接在图形内部获取张量的值。因此,您可以按照{edkeveked的建议使用tf.print进行操作,或按如下所示更改代码:

class Data:
    def __init__(self):
        pass

    def back_to_zero(self, input):
        time = tf.slice(input, [0,0], [-1,1])
        new_time = time - time[0][0]

        return new_time

    @tf.function
    def load_data(self, inputs):
        new_x = self.back_to_zero(inputs)

        return new_x

time = np.linspace(0,10,20)
magntiudes = np.random.normal(0,1,size=20)
x = np.vstack([time, magntiudes]).T

d = Data()
data = d.load_data(x)
print(data.numpy())