Tensorflow Estimator-使用权重操纵成本函数

时间:2019-01-21 13:51:06

标签: tensorflow tensorflow-estimator

我正在尝试建立一个自定义估算器,该估算器将执行以下操作。 在每批反向支持步骤中,我都希望对成本进行一些操作,例如:

loss = tf.squared_difference(y_pred, y)
weighted_loss = tf.multiply(weights, loss)
cost = tf.reduce_sum(weighted_loss) / batch_size

这里的“权重”矩阵是一些外部数据(基本上,每行中的某些元素都为零,因为我不想对其进行反向传播),但这是一个外部数据,我必须提供model_fn训练步骤中每批次的功能。 我该怎么做?如何找到当前训练批次中的记录,并为model_fn提供与此记录相对应的权重矩阵?

0 个答案:

没有答案