我想迭代数据集直到满足某些条件,但我不知道如何“迭代”。以下是我的代码。
import tensorflow as tf
c = tf.constant([1,2,6])
d = tf.data.Dataset.from_tensor_slices((c,))
t = d.make_one_shot_iterator().get_next()
def condition(t):
return t < 5
def body(t):
# I don't know what to do here to return the next t
return [t]
t = tf.while_loop(condition, body, [t])
with tf.Session() as sess:
print(sess.run([t]))
在回答下面亚历克斯的回答时,下面是我想要实现的更现实的例子。
import tensorflow as tf
# I want to "merge" the dataset da to dataset db by "backfilling" da.
# So session.run will return [[1,'a'], [1,'x']], then [[5, 'c'],[3, 'y']]
# note that one element from dataset da is skipped, which is what I want to achieve with the while loop.
ta = tf.constant([1,2,5])
va = tf.constant(['a','b','c'])
da = tf.data.Dataset.from_tensor_slices((ta, va))
tb = tf.constant([1,3,6])
vb = tf.constant(['x','y','z'])
db = tf.data.Dataset.from_tensor_slices((tb, vb))
ea = da.make_one_shot_iterator().get_next()
eb = db.make_one_shot_iterator().get_next()
def condition(ea, eb):
return ea[0] < eb[0]
def body(ea, eb):
# I don't know what to do here to get the next ea.
return ea, eb
result = tf.while_loop(condition, body, (ea, eb))
with tf.Session() as sess:
sess.run([result])
我可以像Alex建议的那样将while循环逻辑移动到python,但我猜测它会在数据流图中留下更好的性能。
答案 0 :(得分:0)
我相信您还不了解Tensorflow的工作原理。 Tensorflow tf.while_loop在计算图内部创建一个while循环,通过添加控制语句多次重新应用图的某些部分,直到满足某个条件。我建议您开始阅读here以了解图表和会话是什么。
您当然不希望迭代计算图中的数据集,您想要的迭代应该发生在Python中,而不是Tensorflow图中。
这就是你要做的:
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
更详细地解释了这一点here。
答案 1 :(得分:0)
您可以使用Dataset.filter()方法根据用户定义的谓词过滤数据集元素。你必须传入一个过滤函数,它返回一个tf.bool张量,如果你想保留记录,则评估为真,否则为假。