Keras Sub分类api风格

时间:2018-05-15 05:07:01

标签: python api tensorflow keras subclass

我坚持使用子类方法制作模型。问题是在这个子类方法中,我们的输入形状方法在哪里,我们的编译步骤在哪里?

请帮我完成作业。

managedObject.localDate = stringDate

这是the link

1 个答案:

答案 0 :(得分:0)

我希望摘自https://www.tensorflow.org/guide/keras的这段代码能对您有所帮助:

class MyModel(keras.Model):

  def __init__(self, num_classes=10):
    super(MyModel, self).__init__(name='my_model')
    self.num_classes = num_classes
    # Define your layers here.
    self.dense_1 = keras.layers.Dense(32, activation='relu')
    self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid')

  def call(self, inputs):
    # Define your forward pass here,
    # using layers you previously defined (in `__init__`).
    x = self.dense_1(inputs)
    return self.dense_2(x)

  def compute_output_shape(self, input_shape):
    # You need to override this function if you want to use the subclassed model
    # as part of a functional-style model.
    # Otherwise, this method is optional.
    shape = tf.TensorShape(input_shape).as_list()
    shape[-1] = self.num_classes
    return tf.TensorShape(shape)


# Instantiates the subclassed model.
model = MyModel(num_classes=10)

# The compile step specifies the training configuration.
model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Trains for 5 epochs.
model.fit(data, labels, batch_size=32, epochs=5)

您可以看到“ model.compile”调用,在拟合阶段,您将把输入数据传递给模型。数据在模型内部的流动方式是在call方法中定义的,因此,如果要进行一些输入大小验证,也可以将其放置在那里。

Seba