运行tf.Strategy与tf.data batch()时的批处理大小

时间:2020-06-11 19:05:29

标签: tensorflow training-data tf.keras

我想在运行tf.distribute策略时显示批处理大小。我是这样创建自定义Keras图层的:

class DebugLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        pass

    def call(self, inputs):
        print_op = tf.print("******Shape is:", tf.shape(inputs) , name='shapey')
        #print_op = tf.print("Debug output:", loss, y_true, y_true.shape)
        with tf.control_dependencies([print_op]):
            return tf.identity(inputs)

第1季度:每批每名工人的示例数

如果我与一个工作人员一起运行,则批处理大小为128,这是我在tf.data数据集流.batch(128)中设置的大小。

如果我有两个工人,每个工人输出128。我想知道在每个工人上运行了多少示例? 同时正在运行多少个示例?

第二季度:正确的steps_per_epoch

在我的Model.fit()调用中,我指定了steps_per_epoch,并且在数据流中有一个.repeat。如果我的训练集包含1024个样本,我有2个工作人员,并且我的.batch设置为128,那么steps_per_epoch设置为一个历时应该是什么?

1 个答案:

答案 0 :(得分:0)

使用tf.data操作时,通常有一种.batch()方法应用于数据。假设该值为128。这将是每批运行的示例总数,而与工作人员数量无关。如果...

  • 使用了1个工作程序,每个训练步骤将运行128个示例。
  • 使用2名工人,每个工人每个培训步骤将运行64个示例。
  • 使用了3名工人,每个工人每个训练步骤将运行约42个示例。

对于3工情况,由于128/3不是整数值,我不确定确切的数字。

要设置steps_per_epoch,请用样本总数除以您在.batch()中设置的批量大小。因此,对于我在问题中的示例,该值为8,即1024/128。

这有点不方便,因为您需要了解训练示例的数量,如果它们发生变化,则需要调整steps_per_epoch值。另外,如果不是整数倍,则需要确定是否应该对steps_per_epoch值进行四舍五入。