如何使用张量流数据集训练神经网络?

时间:2020-05-23 15:20:19

标签: python tensorflow keras neural-network

我正在尝试在emnist数据集上训练神经网络,但是当我尝试展平图像时,它将引发以下错误:

警告:tensorflow:使用输入Tensor的形状(None,28,28)构造了模型(“ flatten_input:0”,shape =(None,28,28),dtype = float32),但在一个输入的形状不兼容(无,1、28、28)。

我不知道是什么问题,并试图更改我的预处理,从model.fit和ds.map中删除批处理大小。

这是完整的代码:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

def preprocess(dict):
    image = dict['image']
    image = tf.transpose(image)
    label = dict['label']
    return image, label

train_data, validation_data = tfds.load('emnist/letters', split = ['train', 'test'])
train_data_gen = train_data.map(preprocess).shuffle(1000).batch(32)
validation_data_gen = validation_data.map(preprocess).batch(32)

print(train_data_gen)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = (28, 28)),
    tf.keras.layers.Dense(128, activation = 'relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation = 'softmax')
])

model.compile(optimizer = 'adam',
    loss = 'sparse_categorical_crossentropy',
    metrics = ['accuracy'])

early_stopping = keras.callbacks.EarlyStopping(monitor = 'val_accuracy', patience = 10)
history = model.fit(train_data_gen, epochs = 50, batch_size = 32, validation_data = validation_data_gen, callbacks = [early_stopping], verbose = 1)
model.save('emnistmodel.h5')

2 个答案:

答案 0 :(得分:1)

实际上,这里发生了一些事情,所以让我们一次解决它们。

  1. 输入形状

    因此,要解决您的紧迫问题,您会收到不兼容的形状错误,因为输入的形状与预期的形状不匹配。

    在这一行tf.keras.layers.Flatten(input_shape=(28, 28)),中,我们告诉模型期望输入形状(28,28),但这并不准确。我们的输入实际上具有形状(28、28、1),因为我们要拍摄具有 1个通道的28x28像素图像(与具有3个通道r,g和b的彩色图像相反)。因此,要解决此直接问题,我们只需更新模型以使用输入的形状即可。即tf.keras.layers.Flatten(input_shape=(28, 28, 1)),

  2. 输出节点数

    正如里沙伯(Rishabh)在他的回答中所建议的那样,EMNIST dataset具有十多个平衡类。但是,就您而言,您似乎正在使用具有26个平衡类的EMNIST字母。因此,您的神经网络应该相应地具有27个输出节点(因为类标签从1 .. 26出发,而我们的输出节点对应于0 .. 26)才能对给定数据进行分类。当然,为其提供额外的输出节点将使其也能够运行,但是这将为我们提供不必要的额外权重训练,这将增加我们模型所需的训练时间。简而言之,您的最后一层应该是tf.keras.layers.Dense(27, activation='softmax')

  3. 预处理TensorFlow数据集

    阅读您的preprocess()函数,我相信您正在尝试将训练和验证数据集转换为(图像,标签)的元组。 TensorFlow无需创建自己的函数,而是通过参数as_supervised为我们方便地实现了此功能。

    此外,我看到了您正在尝试实现的一些额外预处理,例如批处理和改组数据。同样,TensorFlow为我们实现了batch_sizeshuffle_files(请参见常见参数)!因此,加载数据集的外观类似于

    train_data, validation_data = tfds.load('emnist/letters',
                                            split=['train', 'test'],
                                            shuffle_files=True,
                                            batch_size=32,
                                            as_supervised=True)
    
  4. 一些其他注释

    另外,建议您考虑将batch_size从model.fit()中排除。在两个不同的地方定义同一件事是导致错误和意外行为的原因。此外,在使用TensorFlow数据集时,没有必要,因为它们already generate batches

整个更新后的程序应该看起来像这样

import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow import keras
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


train_data, validation_data = tfds.load('emnist/letters',
                                        split=['train', 'test'],
                                        shuffle_files=True,
                                        batch_size=32,
                                        as_supervised=True)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(27, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_accuracy', patience=10)

history = model.fit(train_data,
                    epochs=50,
                    validation_data=validation_data,
                    callbacks=[early_stopping],
                    verbose=1)
model.save('emnistmodel.h5')

希望这会有所帮助!

答案 1 :(得分:0)

Hie @Rattandeep我刚刚检查了emnist数据集,它具有47个不同的类,并且在您的密集层中,您提到了10。

如果您从

更改代码

tf.keras.layers.Dense(10,激活='softmax')

对于这个,它将起作用

tf.keras.layers.Dense(47,激活='softmax')

谢谢