I have the following code, taken almost verbatim from https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch:
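```python
import tensorflow as tf

# Metrics used in train_step; defined here so the snippet is self-contained
loss_tracker = tf.keras.metrics.Mean(name="loss")
rmse_metric = tf.keras.metrics.RootMeanSquaredError(name="rmse")


class BiLSTM(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute our own loss (inside the tape so it can be differentiated)
            loss = self.compiled_loss(y, y_pred)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Compute our own metrics
        loss_tracker.update_state(loss)
        rmse_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "rmse": rmse_metric.result()}

    @property
    def metrics(self):
        # Listing the metrics here lets Keras reset them at the start of each epoch
        return [loss_tracker, rmse_metric]
```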
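The snippet relies on `loss_tracker` and `rmse_metric`. A minimal sketch, assuming they are stateful Keras metric objects (`tf.keras.metrics.Mean` and `tf.keras.metrics.RootMeanSquaredError`, as in the TensorFlow guides) and TF 2.5 or later for `reset_state()`, showing that `update_state` accumulates across calls and `result` returns the running value:

```python
import tensorflow as tf

# Assumed definitions; the snippet above uses these names without defining them
loss_tracker = tf.keras.metrics.Mean(name="loss")
rmse_metric = tf.keras.metrics.RootMeanSquaredError(name="rmse")

# Two updates, as if from two batches: state accumulates across calls
loss_tracker.update_state(2.0)
loss_tracker.update_state(4.0)
rmse_metric.update_state([[0.0], [0.0]], [[3.0], [4.0]])

mean_loss = float(loss_tracker.result())  # mean of 2.0 and 4.0 -> 3.0
rmse = float(rmse_metric.result())        # sqrt((9 + 16) / 2), about 3.5355
print(mean_loss, rmse)

# reset_state() clears the accumulated values, e.g. between epochs
loss_tracker.reset_state()
rmse_metric.reset_state()
```

Because the state persists until it is reset, Keras only resets these metrics between epochs if the model lists them in its `metrics` property.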
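However, my `data` is `<_GroupByWindowDataset shapes: ((None, None, 46), (None, None, 1)), types: (tf.float64, tf.float64)>`, so I believe it cannot be unpacked with `x, y = data`. Can anyone give me some pointers on how to handle `tf.data` objects when customizing the training step?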
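Assuming the dataset is passed to `model.fit`, Keras iterates over it and calls `train_step` once per element, so `data` is a single batch (here an `(x, y)` tuple), not the dataset object. A minimal sketch under TF 2.x (`from_generator` with `output_signature` needs TF 2.4 or later); the generator, the `padded_batch` stand-in for the poster's `group_by_window` pipeline, and the toy `ShapeProbe` model are all hypothetical:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the poster's dataset: variable-length sequences
# with 46 features and 1 target per time step, padded into float64 batches.
def gen():
    rng = np.random.default_rng(0)
    for length in [3, 5, 4, 6]:
        yield rng.random((length, 46)), rng.random((length, 1))

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, 46), dtype=tf.float64),
        tf.TensorSpec(shape=(None, 1), dtype=tf.float64),
    ),
).padded_batch(2)

# Element structure: ((None, None, 46), (None, None, 1)), both float64
print(ds.element_spec)

# Each element of the dataset is an (x, y) tuple, so it unpacks directly
x, y = next(iter(ds))
print(x.shape, y.shape)  # (2, 5, 46) (2, 5, 1)

seen_shapes = []  # filled in by the toy train_step below


class ShapeProbe(tf.keras.Model):
    """Hypothetical toy model that records what fit() passes to train_step."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)

    def train_step(self, data):
        x, y = data  # one padded batch, not the whole dataset
        seen_shapes.append((tuple(x.shape), tuple(y.shape)))
        x = tf.cast(x, tf.float32)  # the Dense layer uses float32 weights
        y = tf.cast(y, tf.float32)
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))  # padding is not masked here
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


model = ShapeProbe()
model(tf.zeros((1, 1, 46)))  # build the weights before fit()
# run_eagerly=True so the Python list is appended on every step, not only when traced
model.compile(optimizer="adam", run_eagerly=True)
model.fit(ds, epochs=1, verbose=0)
print(seen_shapes)
```

The same `x, y = data` line therefore works for a `_GroupByWindowDataset`, as long as each of its elements is an `(x, y)` pair.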