在tf.while_loop期间将Tensorflow索引到python列表中

时间:2019-03-06 11:54:52

标签: tensorflow

我有这个烦人的问题,我不知道如何解决。

我正在使用数据集读取器从CSV中批量读取数据,并希望收集某些列。读取器返回张量的元组,并且根据我使用的读取器,通过整数或字符串对列进行索引。

我可以轻松地在python中执行for循环并将所需的列切片,但是我想在tf.while_loop中执行此操作以利用并行执行。

这是我的问题所在-while循环中的迭代器基于张量,并且我无法使用它来索引我的数据集。如果我尝试对其进行评估,则会收到有关会话不相同等的错误消息

我该如何使用while循环(或映射函数)并使该函数能够在不评估或运行迭代器张量的情况下索引到python列表/字典中?

简单的例子:

#green

1 个答案:

答案 0 :(得分:0)

这符合您的要求吗?除了打印值外,它什么都不做。

import tensorflow as tf
from tensorflow.python.framework import tensor_shape

some_data = [11,222,33,4,5,6,7,8]

def func( v ):
    print (some_data[v])
    return some_data[v]

with tf.Session() as sess:
    r = tf.while_loop(
        lambda i, v: i < 4,
        lambda i, v: [i + 1, tf.py_func(func, [i], [tf.int32])[0]],
        [tf.constant(0), tf.constant(2, tf.int32)],
        [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])

    r[1].eval()

它打印

  

11   4   222   33

每次都会更改顺序,但我想tf.control_dependencies可能会有用。