Tensorflow python:访问张量中的各个元素

时间:2016-02-02 06:21:56

标签: python python-2.7 tensorflow

这个问题是关于访问张量中的单个元素,比如[[1,2,3]]。我需要访问内部元素[1,2,3](这可以使用.eval()或sess.run()执行)但是当张量的大小很大时需要更长时间)

有没有什么方法可以更快地做到这一点?

先谢谢。

4 个答案:

答案 0 :(得分:51)

有两种主要方法可以访问张量中元素的子集,其中任何一个都适用于您的示例。

  1. 使用索引运算符(基于tf.slice())从张量中提取连续切片。

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    
    output = input[0, :]
    print sess.run(output)  # ==> [1 2 3]
    

    索引运算符支持许多与NumPy相同的切片规范。

  2. 使用tf.gather() op从张量中选择一个非连续切片。

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    
    output = tf.gather(input, 0)
    print sess.run(output)  # ==> [1 2 3]
    
    output = tf.gather(input, [0, 2])
    print sess.run(output)  # ==> [[1 2 3] [7 8 9]]
    

    请注意,tf.gather()仅允许您选择第0维中的整个切片(矩阵示例中的整行),因此您可能需要输入tf.reshape()tf.transpose()获得适当的元素。

答案 1 :(得分:1)

我怀疑这是计算的其余部分需要时间,而不是访问一个元素。

此外,结果可能需要从存储的任何内存中复制,因此如果它在显卡上,则需要先将其复制回RAM,然后才能访问您的元素。如果是这种情况,您可以通过添加tensorflow操作来跳过它来获取第一个元素,并且只返回它。

答案 2 :(得分:0)

你没有得到[[1,2,3]]的第0个元素的而没有run() - ning或eval() - 正在获得它的操作。因为在'run'或'eval'之前,你只有一个描述如何获得这个内部元素(因为TF使用符号图/计算)。因此,即使您使用tf.gather / tf.slice,您仍然必须通过eval / run获得这些操作的。请参阅@ mrry的回答。

答案 3 :(得分:0)

希望我能很好地理解你的问题。您可以通过.numpy()访问TensorFlow 2中的张量元素。

import tensorflow as tf
t = tf.constant([[1,2,3]])

print(t.numpy()[0][1]) # This will prints 2

>>> 2