GradientDescentOptimizer在急切模式下需要不带参数的损失函数

时间:2019-04-25 08:13:10

标签: tensorflow tensorflow2.0

tf.train.Optimizer 的API说:  “启用急切执行后,损失应该是不带任何参数并计算要最小化的值的Python函数。”

我很困惑,损失函数如何在没有给出预测和标签的情况下计算损失?

我尝试了 tf.losses.mean_squared_error ,但是,正如预期的那样,此操作不起作用,因为它需要参数。

opt = tf.train.GradientDescentOptimizer(learning_rate=.1)

opt_op = opt.minimize(tf.losses.mean_squared_error, var_list=[model.W, model.b])
# TypeError: mean_squared_error() missing 2 required positional arguments: 'labels' and 'predictions'

1 个答案:

答案 0 :(得分:0)

tf.losses.mean_squared_error需要两个位置参数。

给出pred的预测和label的预期结果。

pred  # computed
label # computed
def custom_loss() :
  return tf.losses.mean_squared_error(pred, label)

opt.minimize(custom_loss, var_list=[model.W, model.b])