如何使`fit_generator`与`tf.keras.Model`一起工作

时间:2019-05-27 07:22:54

标签: python tensorflow machine-learning keras

我正在实现一个tf.keras.Model(不是Sequential模型!),应该使用fit_generator进行训练。 但是,fit_generator会引发错误,可能是因为在编译时输入形状不可用。

这是一个最小的示例:

import tensorflow as tf
import numpy as np


class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

    def call(self, inputs, training=None, mask=None):
        return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

    def __len__(self):
        # Number of batches per epoch
        return 1

    def __getitem__(self, _):
        # Generate one batch of data
        x = np.array([[1., 2., 3.]])
        y = np.array([[0., 1., 0.5]])

        return x, y


if __name__ == '__main__':
    m = MyModel()    
    g = MyGenerator()

    m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
    m.fit_generator(g)

最后一行升起

AttributeError: 'MyModel' object has no attribute 'total_loss'

那么在自定义Keras模型中使用fit_generator的正确方法是什么?

1 个答案:

答案 0 :(得分:1)

在Tensorflow 2.x中,默认情况下启用急切执行。 Model.fit_generator已过时,将在以后的版本中删除。因此,您必须使用Model.fit,它支持生成器。

请参考TF 2.4兼容代码,如下所示

import tensorflow as tf
print(tf.__version__)
import numpy as np


class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

    def call(self, inputs, training=None, mask=None):
        return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

    def __len__(self):
        # Number of batches per epoch
        return 1

    def __getitem__(self, _):
        # Generate one batch of data
        x = np.array([[1., 2., 3.]])
        y = np.array([[0., 1., 0.5]])

        return x, y


if __name__ == '__main__':
    m = MyModel()    
    g = MyGenerator()

    m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
    m.fit(g)

输出:

2.4.0
1/1 [==============================] - 0s 224ms/step - loss: 0.4725