Keras训练使用混洗的tf.data:如果训练被中断,如何在最后一次数据迭代/最后保存的检查点的顺序上继续训练

时间:2020-06-23 06:47:53

标签: tensorflow keras tensorflow2.0 tensorflow-datasets tf.keras

我正在使用keras model.fit进行培训,数据来自tf.records,并加载到tf.data对象中,该对象使用.shuffle对数据进行混排。我还使用callbacks.ModelCheckpoint每隔x步/批次保存模型。

有时候我的云实例在某个时期完成之前就断开连接或崩溃,但是在y步骤中的模型已保存到我的驱动器中。

在训练另一个时期之前,我想完成对该时期(我有很长的时期)中数据的训练,因此每个数据示例每个时期都要进行一次训练。

有没有办法获取数据的原始顺序,以及数据在其中最后保存模型的位置?

到目前为止我发现的东西

您似乎可以通过设置种子来在.shuffle中设置特定顺序。但是,改组仅在缓冲区中发生,因此我不确定100%是否设置种子会完美地重现该顺序。另外,我不确定reshuffle_each_iteration会如何工作。在每个时期之后使用不同的种子吗?如果是这样,我想一种变通方法是一次只训练1个纪元,每个纪元都指定种子。

即使我确实获得了训练订单的副本,也不确定如何找到该订单上次保存模型的位置,然后从该点开始训练。我必须获得订单的一个想法是手动遍历数据集,直到找到为止。尽管我不确定model.fit()是从该顺序继续还是重新开始。 F

要从上次保存模型的位置获取步骤/批号,我可以将其记录在某个地方。

这些解决方案似乎是一个粗略的解决方法,我想知道Keras中是否有些功能我可能会忽略以帮助解决此问题。

2 个答案:

答案 0 :(得分:1)

似乎没有构建喀拉拉邦的方法,但是如果我错了,请纠正我。

我的方法

Dataset.shuffle内部使用初始种子值生成种子,以在reshuffle_each_iteration=True迭代期间将其用于改组。因此,针对特定时期重新创建相同的顺序,并在该特定批次上继续训练该时期,我们必须使用相同的种子重新创建数据集,并将数据集迭代器移至相同的时期和相同的批次。

调试

为进行调试并确保按相同顺序生成历元和批次,我们需要一种方法来打印如何在每个历元批次中拾取数据点。这在技巧上是棘手的,因此出于调试目的,我将使用回归问题并将地面实数作为序列号。然后,我可能会遭受自定义损失,在此我可以打印基本事实并使用户的订单正确无误。

模型和数据

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import keras.backend as K


# Data
x_train = np.random.randn(15, 10).astype("float32")
y_train = np.arange(15).astype("float32")

# Custom MSE looss just to track the order in which data is picked up
def my_mse(y_true, y_pred):
    tf.print(tf.keras.backend.flatten(y_true))
    loss = K.square(y_pred - y_true)
    loss = K.sum(loss, axis=1)
    return loss

# Model
def get_model():
    inputs = keras.Input(shape=(10))    
    outputs = layers.Dense(1, activation="linear")(inputs)
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer="rmsprop",
        loss=my_mse,
    )
    return model

数据集

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)

epochs = 2

print ("Runs 1")
for e in range(epochs):
  for i, (x, y) in enumerate(train_dataset):
    print (e, i, y)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)
print ("Runs 2")
for e in range(epochs):
  for i, (x, y) in enumerate(train_dataset):
    print (e, i, y)

输出:

Runs 1
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14.  9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10.  0. 14.  6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0.  1.  5.  6.  9.  3.  7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13.  8.  4. 10.  2. 12. 11.], shape=(7,), dtype=float32)
Runs 2
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14.  9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10.  0. 14.  6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0.  1.  5.  6.  9.  3.  7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13.  8.  4. 10.  2. 12. 11.], shape=(7,), dtype=float32)

是的,可以复制种子。

现在,让我们编写一种将数据集转发到某个特定时期和批次组合的方法

def forward(dataset, n=None):
  if not n:
    return dataset

  i = 0  
  while True:
    for _ in dataset:        
        i += 1
        if i == n:
          return dataset

测试案例:

让它正常运行并观察顺序

从头开始的数据

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)

model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)

输出:

[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

来自数据集第n个状态的数据

让我们的数据集进行第4次迭代并进行训练

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)

model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)

输出:

[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

很好,现在我们知道如何正确转发数据集。现在让我们编写回调来跟踪当前的迭代编号:

用于跟踪迭代的自定义回调(历时批处理组合)

现在,我们需要确定模型要检查的时期和批次组合。如果我们有此信息,我们可以加载最后一个检查点模型,并将我们的数据集转发到其批次和纪元组合,然后继续进行训练。我们将使用回叫

class MyCustomCallback(tf.keras.callbacks.ModelCheckpoint, keras.callbacks.Callback):
    def __init__(self, the_id=0, **args):
      self.the_id = the_id
      self.epoch = 0
      super().__init__(**args)

    def _save_model(self, epoch, logs):
      logs['the_id'] = self.the_id
      super()._save_model(epoch, logs)

    def on_batch_end(self, batch, logs={}):
      self.the_id += 1
      super().on_batch_end(batch, logs)

checkpoint_filepath = 'checkpoint-{the_id}'
model_checkpoint_callback = MyCustomCallback(
    filepath=checkpoint_filepath,
    save_freq=2,
    save_best_only=False)

model = get_model()

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)

model.fit(train_dataset, epochs=5, verbose=0, callbacks=[model_checkpoint_callback], workers=4, shuffle=False)

输出:

[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

我们每两批检查一次。因此,让我们假设它崩溃了,最后一个检查点是checkpoint-4。我们可以加载该模型,并将我们的数据集转发到4,然后继续训练。

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)

model = get_model()
model.fit(train_dataset, epochs=2, verbose=0, workers=4, shuffle=False)

输出:

[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

答案 1 :(得分:0)

我想您想恢复随机播放顺序,以避免在此时期内重复某些样本。

根据shuffle description,在未完成的时期内,您的模型只能访问数据集中的第一个current_step_number + shuffle_buffer_size个样本。

因此,当恢复训练时,如果您知道已处理了多少步骤,则可以跳过此步骤+跳过shuffle_buffer_size步骤,然后将继续进行以下样本的训练,这在当前纪元内尚未观察到。

请注意,在此期间,根本不会观察到来自数据集第一部分的一些随机shuffle_buffer_size样本。正如您所说,您的时期很长,因此,可能您有很多数据,因此丢失shuffle_buffer_size样本对您来说应该不是问题。

因此,在保存检查点期间,还保存了步骤号,然后在加载检查点后,使用跳过的步骤创建数据集副本(使用dataset.skip),然后将model.fit与此较小的数据集一起使用一个时期(以完成当前时期),然后继续您通常的训练方式。