我正在尝试将WALSModel设置为Estimator对象,因此可以对其调用train_and_evaluate。为此,我需要通过将WALSModel传递给model_fn中的EstimatorSpec来将其转换为Estimator。我现在专注于训练操作。
要设置培训操作,我在线找到了指南,并将其定义为:
train_op = tf.group(model.row_update_prep_gramian_op,
model.initialize_row_update_op,
model.update_row_factors(sp_input=input_tensor)[1],
model.col_update_prep_gramian_op,
model.initialize_col_update_op,
model.update_col_factors(sp_input=input_tensor)[1],
)
loss = ?
return tf.estimator.EstimatorSpec(mode=mode, train_op=train_op, loss=loss)
目前sp_input是一个tf.SparseTensor,其中包含我的所有数据(将来我可能会分别对列和行使用批处理)。
我现在必须定义损失,然后将其传递给EstimatorSpec。我不确定如何执行此操作:在文档https://www.tensorflow.org/api_docs/python/tf/contrib/factorization/WALSModel中,他们提到了通过评估update_row_factors和update_col_factors来计算损失的方法,但是我的理解是model_fn不应该这样做:它应该列出自己操作。