如何使用tf.Print()在张量内打印3个以上的值?

时间:2018-05-12 10:20:43

标签: python tensorflow

我有一个简单的代码,用于将张量分为三个部分axis=1

sess = tf.Session()
a = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

x, y, z = tf.split(a, 3, axis=1)

print_out_x = tf.Print(x, [x], message='Value of x: ', name='x_value')
print_out_y = tf.Print(y, [y], message='Value of y: ', name='y_value')
print_out_z = tf.Print(z, [z], message='Value of z: ', name='z_value')

sess.run([print_out_x, print_out_y, print_out_z])

print(print_out_x)
print(print_out_y)
print(print_out_z)

我得到了如下输出

enter image description here

如何使用...在x,y,z内而不是tf.Print()内填充值?更多的是,我希望输出

Value of z: [[3][6][9][12]]
Value of y: [[2][5][8][11]]
Value of x: [[1][4][7][10]]

1 个答案:

答案 0 :(得分:2)

您可以使用summarize函数中的Argstf.Print()来设置每个输入张量需要打印的参数数量。

对于前,

tf.Print(x, [x],summarize=x.shape[0], message='Value of x: ', name='x_value')
#Value of x: [[1][4][7][10]]