更改为使用tf.data api后,Keras模型无法学习任何内容

时间:2018-08-24 07:37:52

标签: tensorflow keras

我试图将一个简单的Keras模型改为使用tf.data api加载数据,但不知为何,在全部10个epoch(训练轮次)中,准确率始终停留在10%左右。

相比之下,不使用tf.data api的原始代码可以轻松达到约98%的准确率。我是不是哪里做错了?

使用tf.data api的版本

import math
import tensorflow as tf
import numpy as np

# Mini-batch size: used to batch the tf.data pipeline and to derive
# steps_per_epoch for model.fit().
batch_size = 32


def load_data():
    """Load the MNIST training split prepared for the model.

    Returns:
        A ``(train_data, train_label)`` tuple: the training images divided by
        255.0 and the training labels cast to ``np.float32``.
    """
    # Only the training split is returned, so the validation split is left
    # untouched instead of being normalised into a new array and thrown away.
    (train_data, train_label), _ = tf.keras.datasets.mnist.load_data()
    return train_data / 255.0, train_label.astype(np.float32)


def build_model():
    """Create and compile the digit classifier.

    Returns:
        A compiled ``MyModel`` instance (Flatten -> Dense(512, relu) ->
        Dropout(0.2) -> Dense(10, softmax)) using the Adam optimizer, the
        sparse categorical cross-entropy loss and the accuracy metric.
    """

    class MyModel(tf.keras.Model):
        """Subclassed Keras model with one hidden layer and a softmax head."""

        def __init__(self):
            super(MyModel, self).__init__(name='my_model')
            # Layers are created in the order they are applied in call().
            self.flatten = tf.keras.layers.Flatten()
            self.dense_1 = tf.keras.layers.Dense(512, activation=tf.nn.relu)
            self.dropout = tf.keras.layers.Dropout(0.2)
            self.dense_2 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)

        def call(self, inputs):
            """Pass ``inputs`` through the four layers in sequence."""
            outputs = inputs
            for layer in (self.flatten, self.dense_1, self.dropout, self.dense_2):
                outputs = layer(outputs)
            return outputs

    classifier = MyModel()

    compile_options = dict(
        optimizer=tf.train.AdamOptimizer(),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    classifier.compile(**compile_options)

    return classifier


# Train the model through a tf.data input pipeline.
train_data, train_label = load_data()
train_sample_count = len(train_data)

# Give every label its own trailing axis: shape (N,) -> (N, 1).
# NOTE(review): Keras adds this axis itself when NumPy arrays are passed to
# fit(), but Dataset elements are used as-is. On older TensorFlow 1.x releases
# sparse_categorical_accuracy reduces y_true over its last axis, so rank-1
# labels collapse a whole batch into a single value and the reported accuracy
# stays near chance (~10%). Confirm against the installed TensorFlow version;
# the extra axis is also accepted by the sparse cross-entropy loss.
train_label_column = train_label.reshape(-1, 1)

train_dataset = tf.data.Dataset.from_tensor_slices(
    (train_data, train_label_column)
)
train_dataset = train_dataset.batch(batch_size).repeat()

model = build_model()
model.fit(
    train_dataset,
    epochs=10,
    # The dataset repeats indefinitely, so the epoch length is given explicitly.
    steps_per_epoch=math.ceil(train_sample_count / batch_size),
)

不使用tf.data api的版本

# load_data() and build_model() are identical to the ones defined in the
# tf.data api version.

# Baseline run: hand the arrays returned by load_data() directly to fit().
train_data, train_label = load_data()
model = build_model()
model.fit(x=train_data, y=train_label, epochs=10)

0 个答案:

没有答案