我试图将一个简单的Keras模型转换为使用tf.data api进行数据加载,但是不知何故,在整个10个时间段内,精度仍然保持在10%左右。
相比之下,不使用tf.data api的原始代码可以轻松实现约98%的准确性。我做错了什么吗?
使用tf.data api的版本
import math
import tensorflow as tf
import numpy as np
batch_size = 32
def load_data():
mnist = tf.keras.datasets.mnist
(train_data, train_label), (validation_data, validation_label) = mnist.load_data()
train_data, validation_data = train_data / 255.0, validation_data / 255.0
train_label = train_label.astype(np.float32)
return train_data, train_label
def build_model():
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__(name='my_model')
self.flatten = tf.keras.layers.Flatten()
self.dense_1 = tf.keras.layers.Dense(512, activation=tf.nn.relu)
self.dropout = tf.keras.layers.Dropout(0.2)
self.dense_2 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense_1(x)
x = self.dropout(x)
y = self.dense_2(x)
return y
model = MyModel()
model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
train_data, train_label = load_data()
train_sample_count = len(train_data)
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()
model = build_model()
model.fit(
train_dataset,
epochs=10,
steps_per_epoch=math.ceil(train_sample_count/batch_size)
)
不使用tf.data api的版本
# load_data and build_model are exactly same as those in the tf.data api version
train_data, train_label = load_data()
model = build_model()
model.fit(
train_data,
train_label,
epochs=10
)