如果您使用数据集,生成器等,则tf.keras.Model.fit()的文档说不要指定batch_size
参数,因为它们会创建自己的批处理。
看看我的子类tf.keras.Model的call()
函数中输入张量的形状,看来使用数据集批处理将迫使您的模型处理批处理(即完整批处理)一次传递),而使用fit()
设置批次大小会将批次的单个元素传递给模型。
对吗?我问是因为我遇到tf.keras.backend
的{{1}}和batch_dot()
的问题,这似乎是使用tf.data.Dataset加载虹膜数据的产物。一个类似的体系结构,它不使用TF的数据集,而是调用model.fit()来设置批处理大小,但不会产生这些错误。这些模型使用的自定义权重矢量(在模型的sum()
函数中定义)没有“ batch_size”“占位符”(例如,无),但这似乎只是tf.data.Dataset的问题批处理,而不是model.fit()。
我的主要问题是:
1)您是否必须使用tf.data.Dataset的批处理来显式处理批处理的输入,但是可以假设对model.fit()的批处理来说并不必要吗?
2)每种情况下的批处理工作如何?我在网上找不到很多有关此的信息。