这个问题是关于访问张量中的单个元素,比如[[1,2,3]]。我需要访问内部元素[1,2,3](这可以使用.eval()或sess.run()执行)但是当张量的大小很大时需要更长时间)
有没有什么方法可以更快地做到这一点?
先谢谢。
答案 0 :(得分:51)
有两种主要方法可以访问张量中元素的子集,其中任何一个都适用于您的示例。
使用索引运算符(基于tf.slice()
)从张量中提取连续切片。
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = input[0, :]
print sess.run(output) # ==> [1 2 3]
索引运算符支持许多与NumPy相同的切片规范。
使用tf.gather()
op从张量中选择一个非连续切片。
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = tf.gather(input, 0)
print sess.run(output) # ==> [1 2 3]
output = tf.gather(input, [0, 2])
print sess.run(output) # ==> [[1 2 3] [7 8 9]]
请注意,tf.gather()
仅允许您选择第0维中的整个切片(矩阵示例中的整行),因此您可能需要输入tf.reshape()
或tf.transpose()
获得适当的元素。
答案 1 :(得分:1)
我怀疑这是计算的其余部分需要时间,而不是访问一个元素。
此外,结果可能需要从存储的任何内存中复制,因此如果它在显卡上,则需要先将其复制回RAM,然后才能访问您的元素。如果是这种情况,您可以通过添加tensorflow操作来跳过它来获取第一个元素,并且只返回它。
答案 2 :(得分:0)
你没有得到[[1,2,3]]的第0个元素的值而没有run() - ning或eval() - 正在获得它的操作。因为在'run'或'eval'之前,你只有一个描述如何获得这个内部元素(因为TF使用符号图/计算)。因此,即使您使用tf.gather / tf.slice,您仍然必须通过eval / run获得这些操作的值。请参阅@ mrry的回答。
答案 3 :(得分:0)
希望我能很好地理解你的问题。您可以通过.numpy()
访问TensorFlow 2中的张量元素。
import tensorflow as tf
t = tf.constant([[1,2,3]])
print(t.numpy()[0][1]) # This will prints 2
>>> 2