从tf.while_loop中的数据集中获取下一个张量

时间:2018-05-04 07:08:29

标签: tensorflow tensorflow-datasets

我想迭代数据集直到满足某些条件,但我不知道如何“迭代”。以下是我的代码。

import tensorflow as tf

c = tf.constant([1,2,6])
d = tf.data.Dataset.from_tensor_slices((c,))
t = d.make_one_shot_iterator().get_next()

def condition(t):
  return t < 5

def body(t):
  # I don't know what to do here to return the next t
  return [t]

t = tf.while_loop(condition, body, [t])

with tf.Session() as sess:
    print(sess.run([t]))

在回答下面亚历克斯的回答时,下面是我想要实现的更现实的例子。

import tensorflow as tf

# I want to "merge" the dataset da to dataset db by "backfilling" da.
# So session.run will return [[1,'a'], [1,'x']], then [[5, 'c'],[3, 'y']]
# note that one element from dataset da is skipped, which is what I want to achieve with the while loop.
ta = tf.constant([1,2,5]) 
va = tf.constant(['a','b','c'])
da = tf.data.Dataset.from_tensor_slices((ta, va))

tb = tf.constant([1,3,6])
vb = tf.constant(['x','y','z'])
db = tf.data.Dataset.from_tensor_slices((tb, vb))

ea = da.make_one_shot_iterator().get_next()
eb = db.make_one_shot_iterator().get_next()

def condition(ea, eb):
  return ea[0] < eb[0]

def body(ea, eb):
  # I don't know what to do here to get the next ea.
  return ea, eb

result = tf.while_loop(condition, body, (ea, eb))

with tf.Session() as sess:
  sess.run([result])

我可以像Alex建议的那样将while循环逻辑移动到python,但我猜测它会在数据流图中留下更好的性能。

2 个答案:

答案 0 :(得分:0)

我相信您还不了解Tensorflow的工作原理。 Tensorflow tf.while_loop在计算图内部创建一个while循环,通过添加控制语句多次重新应用图的某些部分,直到满足某个条件。我建议您开始阅读here以了解图表和会话是什么。

您当然不希望迭代计算图中的数据集,您想要的迭代应该发生在Python中,而不是Tensorflow图中。

这就是你要做的:

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)

更详细地解释了这一点here

答案 1 :(得分:0)

您可以使用Dataset.filter()方法根据用户定义的谓词过滤数据集元素。你必须传入一个过滤函数,它返回一个tf.bool张量,如果你想保留记录,则评估为真,否则为假。