张量流排名模块的慢速预测

时间:2019-12-30 10:04:11

标签: python tensorflow machine-learning ranking tensorflow-estimator

我当前正在使用TensorFlow Ranking module进行推荐任务。具体来说,我正在出于自己的目的修改this tutorial file。我不能说该教程对新用户非常友好。因为这是我第一次与TensorFlow进行交互,所以我只是想使其运行。

您可能会注意到,the example file没有说出如何进行预测,因此在完成模型训练后,我修改了train_and_eval()函数以进行预测。这是我的代码。

def _train_op_fn(loss):
        """Defines train op used in ranking head."""
        update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        minimize_op = optimizer.minimize(
                loss = loss, global_step = tf.compat.v1.train.get_global_step())
        train_op = tf.group([minimize_op, update_ops])
        return train_op

ranking_head = tfr.head.create_ranking_head(
        loss_fn = tfr.losses.make_loss_fn(loss),
        eval_metric_fns = get_eval_metric_fns(),
        train_op_fn = _train_op_fn
)

estimator =tf.estimator.Estimator(
        model_fn=tfr.model.make_groupwise_ranking_fn(
                group_score_fn=make_score_fn(),
                group_size = group_size,
                transform_fn=make_transform_fn(),
                ranking_head=ranking_head),
        config=tf.estimator.RunConfig(
                output_dir, save_checkpoints_steps=1000))

def predict_(feature_dict = {}):
    if feature_dict == {}:
        feature_dict = input_fn()
    pred_fn, pred_hook = get_eval_inputs(feature_dict)
    generator_ = estimator.predict(input_fn = pred_fn, hooks = [pred_hook])
    pred_list = list(generator_)

    return pred_list

predict_函数使用字典

{'feature 1':[doc 1 score, doc 2 score...], 
 'feature 2':[doc 1 score, doc 2 score...], 
 ...}

并返回一个列表,该列表按该顺序为所有文档评分。 (或者至少我认为应该这样做)

预测结果非常好。问题是,它真的很慢。预测400个文档需要1秒钟以上的时间(我只有4个功能)。我怎样才能更快?

0 个答案:

没有答案