Keras-如何正确使用fit()训练模型?

时间:2020-06-25 02:07:06

标签: python tensorflow keras deep-learning

我一直熟悉Keras,从文档开始,我整理了一个基本模型并加载了自己的图像文件夹以进行训练,而不是使用mnist数据集。我已经到了建立模型的地步,但是我不确定如何使用fit()方法调用我的数据集,然后训练模型进行预测。这是到目前为止的代码:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers.experimental.preprocessing import CenterCrop
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras import layers

#Importing the dataset and setting the path
dataset = keras.preprocessing.image_dataset_from_directory(
    'PetImages',
    batch_size = 64,
    image_size = (200, 200)
)

dataset = keras.Input(shape = (None, None, 3))

# PreProcessing layers to better format the datset 
x = CenterCrop(height=150, width=150)(dataset)
x = Rescaling(scale=1.0 / 255)(x)

# Convolution and Pooling Layers
x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)
x = layers.MaxPooling2D(pool_size=(3, 3))(x)
x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)
x = layers.MaxPooling2D(pool_size=(3, 3))(x)
x = layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu")(x)

# Global average pooling to get flat feature vectors
x = layers.GlobalAveragePooling2D()(x)

# Adding a dense classifier 
num_classes = 10
outputs = layers.Dense(num_classes, activation="softmax")(x)

# Instantiates the model once layers have been set
model = keras.Model(inputs = dataset, outputs = outputs)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')

# Problem: Unsure how to further call on dataset to train the model and make a prediction
model.fit()

1 个答案:

答案 0 :(得分:2)

.fit()方法实际上是将对您的网络进行训练的方法,以便它以您希望其训练的方式运行。模型需要数据才能进行训练。我将看看他们的documentation或他们的一些示例,以了解如何将您的模型作为一个很好的起点。

根据您使用的tensorflow.keras的版本,.fit可以采用两个位置参数xy,也可以采用一个生成器对象,该生成器对象就像一个连续活动的功能。您可能还需要设置一个batch_size,这实际上是一次要评估多少个样本。同样,该文档将提供有关可以采用哪种参数的更多信息。

在您的情况下,似乎您在变量dataset中得到了一些不错的输入图像(您会立即覆盖它),但是没有标签。标签定义输入的训练图像的预期输出。您需要的第一步是一组标签,然后,可以对代码进行一些调整以使其运行:

# Add line below
labels = # ... load labels from someplace, like how you loaded the images

# Change this
dataset = keras.Input(shape = (None, None, 3))
# to this
input_layer = layers.Input(shape=(None, None, 3))

# Change this
model = keras.Model(inputs = dataset, outputs = outputs)
# to this
model = keras.models.Model(inputs = input_layer, outputs = outputs)

# and finally you can fit your model using
model.fit(dataset, labels)