Tensorflow预测特定ID

时间:2018-10-20 09:14:01

标签: python tensorflow

使用Tensorflow数据集API进行

im,并尝试预测多标签分类。它的工作正常,但最终的预测没有对应的id,所以我不知道谓词属于什么。

我正在使用以下代码创建数据集并进行预测:

def test_input_fn():
  filenames = tf.train.match_filenames_once("../input/test/*_green.png")
  dataset = tf.data.Dataset.from_tensor_slices(filenames)

  def _parse_image_data(filename):
      image_string = tf.read_file(filename)
      image_decoded = tf.image.decode_png(image_string, channels=1)
      image_reshape = tf.reshape(image_decoded, [512*512*1])
      return image_reshape

  return dataset.map(_parse_image_data).batch(8)

pred_result = estimator.predict(test_input)

是否有一种简单的方法可以将ID附加到预测中。

1 个答案:

答案 0 :(得分:0)

您可以在return image_reshape,filename函数中_parse_image_data(filename),并相应地修改模型函数,以便进行预测时,它不仅返回预测的标签,还返回文件名。

更详细地讲,假设您的模型函数为model_fn(features, labels, mode, params),现在您的features不再只是image_reshape,而是image_shape,filename,因此,基本上,无论您在哪里使用{ {1}}至features;您可能还需要在features[0]后面附加features[1](我假设您是这样命名变量的,如果可以向我展示您的模型函数,我可以提供更具体的代码)。

或者,如果您想要一个单独的功能,该功能只打印predictions内容:

filenames

在以上代码中,输出按原样排列,因此它应该与您的预测相同。