如何从训练有素的Tensorflow分类器中获取类预测?

时间:2018-03-15 15:21:06

标签: python tensorflow machine-learning

我已经训练过二元分类器模型。模型类包含self.costself.initial_stateself.final_stateself.logits参数。只需使用tf.train.Saver

保存即可
saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
saver.save(session, 'model.ckpt')

模型训练完成后,我将其加载为:

with tf.variable_scope("Model", reuse=False):
    model = MODEL(config, is_training=False)

with tf.Session() as session:
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(session, 'model.ckpt')

但是,我的model.run函数返回交叉熵丢失,这是图中的最后一个操作。我不需要丢失,我需要每个批处理元素的模型预测

logits = tf.sigmoid(tf.nn.xw_plus_b(last_layer, self.output_w, self.output_b))

其中last_layer800x1矩阵,然后我将其重新整形为32x25x1(batch_size,sequence_length,1)矩阵。正是该矩阵包含[0-1]范围内的模型预测值。

那么,我如何使用这个模型来预测单个元素矩阵1x1x1

1 个答案:

答案 0 :(得分:1)

添加计算准确性所需的OP,就像我在下面复制的那样(只是从我手边最近的模型中复制出来)。

  self.logits_flat = tf.argmax(logits, axis=1, output_type=tf.int32)
  labels_flat = tf.argmax(labels, axis=1, output_type=tf.int32)
  accuracy = tf.cast(tf.equal(self.logits_flat, labels_flat), tf.float32, name='accuracy')

现在,当您运行模型时(在测试或培训期间),将sess.run调用的准确性添加为:

sess.run([train_op, accuracy], feed_dict=...)

sess.run([accuracy, logits], feed_dict=...)

当您致电sess.run时,您所做的就是告诉tensorflow计算您要求的任何值。您需要将它传递给执行这些计算所需的任何数据。 Tensorflow是懒惰的,它不会执行任何明确生成您请求的结果所必需的计算。例如。如果您运行上面列出的sess.run的第二个版本,则不会运行优化程序,因此您的权重将不会更新。

请注意,您可以在训练网络后添加OP,因为它们实际上都没有添加任何变量,因此它们不会影响保存/恢复过程。