即使我们不使用model.fit,我们应该何时继承keras.Model而不是keras.layers.Layer?

时间:2019-09-26 13:50:36

标签: tensorflow keras tensorflow2.0 tf.keras

在一些使用tf2的Tensorflow教程中(例如Neural Machine Translation with AttentionEager essentials),他们定义了自定义的tf.keras.Model而不是tf.keras.layers.Layer(例如BahdanauAttention(tf.keras.Model):)< / p>

此外,Models: composing layers文档明确使用了tf.keras.Model。该部分显示:

  

创建包含其他图层的类似图层的东西时使用的主要类是tf.keras.Model。实现方法之一是通过继承tf.keras.Model。

听起来我们需要继承tf.keras.Model来定义组成子图层的图层。

但是,据我检查,即使我将ResnetIdentityBlock定义为tf.keras.layers.Layer的子类,此代码仍然有效。其他两个教程也可以与Layer一起使用。 除此之外,another tutorial

  

模型就像一个图层,但是增加了训练和序列化实用程序。

因此,我不知道tf.keras.Modeltf.keras.layers.Layer之间的真正区别是什么,为什么那三个急切执行的教程尽管不使用 training却使用tf.keras.Modeltf.keras.Model的序列化实用程序

为什么在那些教程中我们需要继承tf.keras.Model

其他评论

Model

实用程序仅适用于LayerLayers whose call receive only one input)的特殊子集。因此,我认为像“始终扩展模型,因为模型具有更多功能” 这样的想法是不正确的。而且,它违反了诸如SRP之类的基本编程程序。

1 个答案:

答案 0 :(得分:1)

更新

所以评论是:Yes, I know training and serialization utilities exist in Model as I wrote in the question. My question is why TF tutorials need to use Model though they don't use these methods.

在这种情况下,作者可以提供最佳答案,因为您的问题询问为什么他们选择一种方法而不是另一种方法,因为他们俩都可以很好地完成这项工作。为什么做得同样好?好吧,因为Model is just like a Layer, but with added training and serialization utilities.

我们可以争辩说,仅在图层可以胜任工作时使用模型是一个过大的杀伤力,但这可能是一个品味问题。

希望有帮助

PS。

在您提供的eager examplecustom layer writing教程中,我们无法将模型替换为图层,因此这些教程不适用于您的问题


使用模型,您可以训练,但是只有图层,您不能。请参见下面的方法列表(不包括内部方法和继承方法):

tf.keras.layers.Layer

activity_regularizer
activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_mask
compute_output_shape
count_params
dtype
dynamic
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
trainable
trainable
trainable_variables
trainable_weights
updates
variables
weights

看到了吗?没有合适的方法或评估方法。 tf.keras.Model


compile
evaluate
evaluate_generator
fit
fit_generator
get_weights
load_weights
metrics
metrics_names
predict
predict_generator
predict_on_batch
reset_metrics
run_eagerly
run_eagerly
sample_weights
test_on_batch
train_on_batch