在一些使用tf2的Tensorflow教程中(例如Neural Machine Translation with Attention和Eager essentials),他们定义了自定义的tf.keras.Model
而不是tf.keras.layers.Layer
(例如BahdanauAttention(tf.keras.Model):
)< / p>
此外,Models: composing layers文档明确使用了tf.keras.Model
。该部分显示:
创建包含其他图层的类似图层的东西时使用的主要类是tf.keras.Model。实现方法之一是通过继承tf.keras.Model。
听起来我们需要继承tf.keras.Model
来定义组成子图层的图层。
但是,据我检查,即使我将ResnetIdentityBlock
定义为tf.keras.layers.Layer
的子类,此代码仍然有效。其他两个教程也可以与Layer
一起使用。
除此之外,another tutorial说
模型就像一个图层,但是增加了训练和序列化实用程序。
因此,我不知道tf.keras.Model
和tf.keras.layers.Layer
之间的真正区别是什么,为什么那三个急切执行的教程尽管不使用 training却使用tf.keras.Model
和tf.keras.Model
的序列化实用程序。
为什么在那些教程中我们需要继承tf.keras.Model
?
其他评论
Model
的实用程序仅适用于Layer
(Layers whose call
receive only one input)的特殊子集。因此,我认为像“始终扩展模型,因为模型具有更多功能” 这样的想法是不正确的。而且,它违反了诸如SRP之类的基本编程程序。
答案 0 :(得分:1)
更新
所以评论是:Yes, I know training and serialization utilities exist in Model as I wrote in the question. My question is why TF tutorials need to use Model though they don't use these methods.
在这种情况下,作者可以提供最佳答案,因为您的问题询问为什么他们选择一种方法而不是另一种方法,因为他们俩都可以很好地完成这项工作。为什么做得同样好?好吧,因为Model is just like a Layer, but with added training and serialization utilities.
我们可以争辩说,仅在图层可以胜任工作时使用模型是一个过大的杀伤力,但这可能是一个品味问题。
希望有帮助
PS。
在您提供的eager example和custom layer writing教程中,我们无法将模型替换为图层,因此这些教程不适用于您的问题
使用模型,您可以训练,但是只有图层,您不能。请参见下面的方法列表(不包括内部方法和继承方法):
tf.keras.layers.Layer
activity_regularizer
activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_mask
compute_output_shape
count_params
dtype
dynamic
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
trainable
trainable
trainable_variables
trainable_weights
updates
variables
weights
看到了吗?没有合适的方法或评估方法。
tf.keras.Model
compile
evaluate
evaluate_generator
fit
fit_generator
get_weights
load_weights
metrics
metrics_names
predict
predict_generator
predict_on_batch
reset_metrics
run_eagerly
run_eagerly
sample_weights
test_on_batch
train_on_batch