tensorflow.data batch method does not batch the text lines in the dataset

Time: 2019-01-12 08:15:40

Tags: python tensorflow

I have been following the Stanford TensorFlow tutorial (link), but I've run into trouble.

I am reading data from a text file that contains the following lines:

I use Tensorflow
You use PyTorch
Tensorflow is better
By a lot

The batch method works fine when I use a one-shot iterator:

# Reading the file with tf.data
import tensorflow as tf

dataset = tf.data.TextLineDataset("file.txt")

iterator = dataset.make_one_shot_iterator() # iter can loop through data once
next_element = iterator.get_next() 


#---TRANSFORMING DATA---

# create batches 
batch_size = 2
dataset = dataset.batch(batch_size) 

# prefetching data (prepare the next batch while the current one is being consumed)
dataset = dataset.prefetch(1)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    for _ in range(4//batch_size):
        print(sess.run(next_element))

This returns the expected output (note that the comments are not part of the output):

[b'I use Tensorflow' b'You use PyTorch'] # first batch
[b'Tensorflow is better' b'By a lot'] # second batch

However, when I do a similar exercise with an initializable iterator...

# Using initializable iterators
import tensorflow as tf

dataset = tf.data.TextLineDataset("file.txt")
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer # this allows you to reset iterator --> you can iterate multiple times (epochs)

epochs = 2
batch_size = 2
num_examples = 4

# This code block doesn't seem to work


dataset = dataset.batch(batch_size) # doesn't batch up lines


dataset = dataset.prefetch(1)
with tf.Session() as sess:
    # Initialize the iterator
    for i in range(epochs):
        sess.run(init_op)
        for _ in range(num_examples//batch_size): # loops through all batches
            print(sess.run(next_element))
        print("\n")

...I get this result:

# first epoch
b'I use Tensorflow' #??
b'You use PyTorch'

# second epoch
b'I use Tensorflow'
b'You use PyTorch'

What I expected:

# First epoch
[b'I use Tensorflow' b'You use PyTorch'] # first batch
[b'Tensorflow is better' b'By a lot'] # second batch

# Second epoch
[b'I use Tensorflow' b'You use PyTorch'] # first batch
[b'Tensorflow is better' b'By a lot'] # second batch

Can someone help me figure out what I'm doing wrong? I've checked the documentation for batch() and everything seems to check out.

Thanks.

2 answers:

Answer 0 (score: 0):

Maybe it's because you call sess.run(init_op) in every epoch? Call it once, before the loop.
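A minimal sketch of what this answer suggests, reusing the next_element and init_op names from the question; note that once the data is exhausted the iterator raises tf.errors.OutOfRangeError, so the loop below simply stops at that point instead of counting epochs:

# Sketch: initialize the iterator once, before the loop (names reused from the question)
with tf.Session() as sess:
    sess.run(init_op)  # called a single time, not once per epoch
    while True:
        try:
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            break  # data exhausted; without re-initialization there is no second pass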

Answer 1 (score: 0):

Move the dataset.batch() and dataset.prefetch() lines to before the point where you create the iterator. The iterator is built from whatever the dataset looks like at that moment, so if you create it before batching, it yields single lines.
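A minimal sketch of that reordering, assuming the same file.txt and TF 1.x API from the question: batch() and prefetch() are applied first, and only then is the initializable iterator created, so get_next() yields batches of lines:

# Sketch of the suggested ordering: transform the dataset first, then build the iterator
import tensorflow as tf

batch_size = 2
epochs = 2
num_examples = 4

dataset = tf.data.TextLineDataset("file.txt")
dataset = dataset.batch(batch_size)    # batching now happens before the iterator exists
dataset = dataset.prefetch(1)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer

with tf.Session() as sess:
    for _ in range(epochs):
        sess.run(init_op)  # reset the iterator at the start of each epoch
        for _ in range(num_examples // batch_size):
            print(sess.run(next_element))  # prints an array of 2 lines per step
        print("\n")

With this ordering the output should match the expected result in the question: two batches of two lines per epoch.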