如何在model_fn中设置WALSModel损失以初始化EstimatorSpec?

时间:2019-05-02 20:15:55

标签: python tensorflow tensorflow-estimator

我正在尝试将WALSModel设置为Estimator对象,因此可以对其调用train_and_evaluate。为此,我需要通过将WALSModel传递给model_fn中的EstimatorSpec来将其转换为Estimator。我现在专注于训练操作。

要设置培训操作,我在线找到了指南,并将其定义为:

    train_op = tf.group(model.row_update_prep_gramian_op,
                        model.initialize_row_update_op,
                        model.update_row_factors(sp_input=input_tensor)[1],
                        model.col_update_prep_gramian_op,
                        model.initialize_col_update_op,
                        model.update_col_factors(sp_input=input_tensor)[1],
                        )

    loss = ?

    return tf.estimator.EstimatorSpec(mode=mode, train_op=train_op, loss=loss)

目前sp_input是一个tf.SparseTensor,其中包含我的所有数据(将来我可能会分别对列和行使用批处理)。

我现在必须定义损失,然后将其传递给EstimatorSpec。我不确定如何执行此操作:在文档https://www.tensorflow.org/api_docs/python/tf/contrib/factorization/WALSModel中,他们提到了通过评估update_row_factors和update_col_factors来计算损失的方法,但是我的理解是model_fn不应该这样做:它应该列出自己操作。

  • 设置培训操作的方式正确吗?
  • 如何计算损失?

0 个答案:

没有答案