我有这个烦人的问题,我不知道如何解决。
我正在使用数据集读取器从CSV中批量读取数据,并希望收集某些列。读取器返回张量的元组,并且根据我使用的读取器,通过整数或字符串对列进行索引。
我可以轻松地在python中执行for循环并将所需的列切片,但是我想在tf.while_loop中执行此操作以利用并行执行。
这是我的问题所在-while循环中的迭代器基于张量,并且我无法使用它来索引我的数据集。如果我尝试对其进行评估,则会收到有关会话不相同等的错误消息
我该如何使用while循环(或映射函数)并使该函数能够在不评估或运行迭代器张量的情况下索引到python列表/字典中?
简单的例子:
#green
答案 0 :(得分:0)
这符合您的要求吗?除了打印值外,它什么都不做。
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
some_data = [11,222,33,4,5,6,7,8]
def func( v ):
print (some_data[v])
return some_data[v]
with tf.Session() as sess:
r = tf.while_loop(
lambda i, v: i < 4,
lambda i, v: [i + 1, tf.py_func(func, [i], [tf.int32])[0]],
[tf.constant(0), tf.constant(2, tf.int32)],
[tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
r[1].eval()
它打印
11 4 222 33
每次都会更改顺序,但我想tf.control_dependencies
可能会有用。