我正在尝试在Keras中使用GradientTape训练模型。这是代码:
@tf.function
def train_step(x,y):
with tf.GradientTape() as tape:
predictions = model.predict(x)
loss = compute_loss(y, predections)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss
history = []
for iter in tqdm(range(num_iters)):
x_batch, y_batch = get_batch(x_train, y_train, batch_dim)
loss = train_step(x_batch, y_batch)
history.append(loss.numpy().mean())
此代码导致以下错误:
ValueError: When using data tensors as input to a model, you should specify the `steps` argument.
但是,如果我尝试按以下方式在函数外部调用预测:
history = []
for iter in tqdm(range(num_iters)):
x_batch, y_batch = get_batch(x_train, y_train, batch_dim)
x_hat = model.predict(x_batch)
我没有错误...
有人可以向我解释为什么我会从Keras得到这种行为吗?
答案 0 :(得分:0)
即使用户在评论中回答了这个问题也是为了社区的利益。
通过将x_batch
和y_batch
的数据类型更改为float32
,然后调用model(x_batch)
来预测输出。
这样,可以通过将train_step
函数保持为@tf.function
来解决问题。