tf.data API does not print all the batches

Time: 2018-11-23 16:57:27

Tags: python tensorflow lstm tensorflow-datasets

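I am teaching myself the tf.data API, using the MNIST dataset for binary classification. The training x and y data are zipped together into the full train_dataset, and the batch() dataset method is chained onto that zip call. The data is batched with a batch size of 128. Since my training set has 11623 samples and the batch size is 128, I will have 91 batches. The last batch will have 103 samples, which is fine since this is an LSTM. I am also using dropout, and I turn it off when computing the batch accuracy.

The full code is below: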
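The batch arithmetic above can be checked with a few lines of plain Python. This is only a sanity check of the numbers quoted in the question; the variable names are chosen for illustration.

```python
import math

train_size = 11623   # number of training samples quoted in the question
batch_size = 128     # value passed to batch() in the question's code

# The final, partial batch still counts as a batch, hence the ceiling.
num_batches = math.ceil(train_size / batch_size)

# Whatever is left after the full batches goes into the last batch.
last_batch_size = train_size - (num_batches - 1) * batch_size

print(num_batches, last_batch_size)  # 91 103
```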

#Ignore the warnings
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (8,7)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/")

Xtrain = mnist.train.images[mnist.train.labels < 2]
ytrain = mnist.train.labels[mnist.train.labels < 2]

print(Xtrain.shape)
print(ytrain.shape)

#Data parameters
num_inputs = 28
num_classes = 2
num_steps=28

# create the training dataset
Xtrain = tf.data.Dataset.from_tensor_slices(Xtrain).map(lambda x: tf.reshape(x,(num_steps, num_inputs)))
# apply a one-hot transformation to each label for use in the neural network
ytrain = tf.data.Dataset.from_tensor_slices(ytrain).map(lambda z: tf.one_hot(z, num_classes))
# zip the x and y training data together and batch and Prefetch data for faster consumption
train_dataset = tf.data.Dataset.zip((Xtrain, ytrain)).batch(128).prefetch(128)

iterator = tf.data.Iterator.from_structure(train_dataset.output_types,train_dataset.output_shapes)
X, y = iterator.get_next()

training_init_op = iterator.make_initializer(train_dataset)


#### model is here ####

#Network parameters
num_epochs = 2
batch_size = 128
output_keep_var = 0.5

with tf.Session() as sess:
    init.run()

    print("Initialized")
    # Training cycle
    for epoch in range(0, num_epochs):
        num_batch = 0
        print ("Epoch: ", epoch)
        avg_cost = 0.
        avg_accuracy =0
        total_batch = int(11623 / batch_size + 1)
        sess.run(training_init_op)
        while True:
            try:
                _, miniBatchCost = sess.run([trainer, loss], feed_dict={output_keep_prob: output_keep_var})
                miniBatchAccuracy = sess.run(accuracy, feed_dict={output_keep_prob: 1.0})
                print('Batch %d: loss = %.2f, acc = %.2f' % (num_batch, miniBatchCost, miniBatchAccuracy * 100))
                num_batch +=1
            except tf.errors.OutOfRangeError:
                break

When I run this code, it appears to work and prints:

Batch 0: loss = 0.67276, acc = 0.94531
Batch 1: loss = 0.65672, acc = 0.92969
Batch 2: loss = 0.65927, acc = 0.89062
Batch 3: loss = 0.63996, acc = 0.99219
Batch 4: loss = 0.63693, acc = 0.99219
Batch 5: loss = 0.62714, acc = 0.9765
......
......
Batch 39: loss = 0.16812, acc = 0.98438
Batch 40: loss = 0.10677, acc = 0.96875
Batch 41: loss = 0.11704, acc = 0.99219
Batch 42: loss = 0.10592, acc = 0.98438
Batch 43: loss = 0.09682, acc = 0.97656
Batch 44: loss = 0.16449, acc = 1.00000

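However, as one can easily see, something is wrong. Only 45 batches are printed instead of 91, and I do not know why. I have tried many things, but I think I am missing something.

I could use the repeat() function, but I would rather not, because the last batch holds the leftover observations and I want the LSTM to process them.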

1 Answer:

Answer 0 (score: 3)

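This is an annoying pitfall when a model is defined directly on the get_next() output of a tf.data iterator. Your loop contains two sess.run calls, and both of them advance the iterator by one step. This means that each loop iteration actually consumes two batches (and your loss and accuracy are computed on different batches). With 91 batches, 45 iterations complete normally; in the 46th, the first call consumes the last batch and the second raises OutOfRangeError before anything is printed, which is why you see only 45 lines.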
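The effect can be reproduced without TensorFlow. The sketch below is only an analogy: a Python generator stands in for the dataset iterator, each next() call stands in for a sess.run call that evaluates get_next(), and StopIteration stands in for tf.errors.OutOfRangeError. The helper names make_batches and count_printed are hypothetical.

```python
def make_batches(n_samples, batch_size):
    """Yield batch sizes for one pass over the data (last batch may be smaller)."""
    for start in range(0, n_samples, batch_size):
        yield min(batch_size, n_samples - start)


def count_printed(n_samples, batch_size):
    """Count loop iterations that reach the print statement."""
    batches = make_batches(n_samples, batch_size)
    printed = 0
    while True:
        try:
            next(batches)  # analogue of sess.run([trainer, loss]): consumes a batch
            next(batches)  # analogue of sess.run(accuracy): consumes ANOTHER batch
        except StopIteration:  # analogue of tf.errors.OutOfRangeError
            return printed
        printed += 1  # analogue of the print(...) line in the question


printed_lines = count_printed(11623, 128)
print(printed_lines)  # 45
```

Two batches are consumed per iteration, so 91 batches give 45 completed iterations, and the last batch is used for training but never reported.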

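I am not sure whether there is a "canonical" way of fixing this, but you could:

  • Compute the accuracy in the same run call as the cost/training step. This means the accuracy computation is also affected by the dropout mask, but since it is only an approximation based on a single batch anyway, this should not be a big problem.
  • Define your model on placeholders instead. In each loop iteration, run the get_next op itself, then feed the resulting numpy arrays (i.e. the batch) into the loss/accuracy computation.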
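Both options can be sketched on a toy dataset. This is a minimal illustration under several assumptions, not the asker's model: it uses the tensorflow.compat.v1 module so that it runs in graph mode, it creates the iterator with make_initializable_iterator rather than the question's from_structure/make_initializer pair, and the ops loss_like, acc_like, loss_ph and acc_ph are hypothetical stand-ins for the real loss and accuracy. The keep_prob placeholder only mimics how a dropout keep probability would be fed.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Toy stand-in for the training set: 10 samples in batches of 4 -> sizes 4, 4, 2.
features = np.arange(10, dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(features).batch(4)
iterator = tf.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

# Option 1: ops built directly on the iterator output, fetched in ONE run call.
loss_like = tf.reduce_sum(next_batch)   # hypothetical stand-in for `loss`
acc_like = tf.reduce_mean(next_batch)   # hypothetical stand-in for `accuracy`

# Option 2: ops built on placeholders; batches are fetched separately and fed in.
x_ph = tf.placeholder(tf.float32, shape=[None])
keep_prob = tf.placeholder_with_default(1.0, shape=())  # mimics a dropout keep prob
loss_ph = tf.reduce_sum(x_ph * keep_prob)
acc_ph = tf.reduce_mean(x_ph)


def count_batches_single_run(sess):
    """Option 1: one sess.run per iteration, so one batch per iteration."""
    sess.run(iterator.initializer)
    n = 0
    while True:
        try:
            sess.run([loss_like, acc_like])  # both values come from the same batch
        except tf.errors.OutOfRangeError:
            return n
        n += 1


def count_batches_placeholder(sess):
    """Option 2: only the get_next fetch advances the iterator."""
    sess.run(iterator.initializer)
    n = 0
    while True:
        try:
            batch = sess.run(next_batch)  # numpy array holding the current batch
        except tf.errors.OutOfRangeError:
            return n
        sess.run(loss_ph, feed_dict={x_ph: batch, keep_prob: 0.5})  # "training" pass
        sess.run(acc_ph, feed_dict={x_ph: batch, keep_prob: 1.0})   # evaluation pass
        n += 1


with tf.Session() as sess:
    single_run_batches = count_batches_single_run(sess)
    placeholder_batches = count_batches_placeholder(sess)

print(single_run_batches, placeholder_batches)
```

In both loops every batch, including the smaller final one, is visited exactly once. The placeholder variant costs an extra round trip through numpy per batch, but it allows the accuracy to be evaluated with dropout disabled on the same batch that was just used for training.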