自定义训练步骤时处理 tf.data 对象

时间:2021-03-16 17:59:29

标签: tensorflow keras tf.keras

我有以下代码,几乎取自 https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch

class BiLSTM(tf.keras.Model):
    
    def train_step(self, data):
        
    x, y = data

        
    with tf.GradientTape() as tape:
            
         y_pred = self(x, training=True)  # Forward pass
            
         # Compute our own loss
            
         loss = self.compiled_loss(y, y_pred)

        
    # Compute gradients
        
    trainable_vars = self.trainable_variables
        
    gradients = tape.gradient(loss, trainable_vars)

        
    # Update weights
        
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        
    # Compute our own metrics
        
    loss_tracker.update_state(loss)
        
    rmse_metric.update_state(y, y_pred)
        
    return {"loss": loss_tracker.result(), "rmse": rmse_metric.result()}


然而,我的 data<_GroupByWindowDataset shapes: ((None, None, 46), (None, None, 1)), types: (tf.float64, tf.float64)>,所以我相信不能解压为 x, y = data

。谁能给我一些关于在自定义训练步骤时如何处理 tf.data 对象的指示?

0 个答案:

没有答案
相关问题