TensorFlow:对图像进行分类

时间:2019-10-11 19:43:22

标签: python tensorflow classification tensorflow-datasets tensorflow2.0

我正在遵循有关TensorFlow 2.0的图像分类的教程:https://www.tensorflow.org/tutorials/images/classification

该教程显示了如何构建和训练模型,但我不了解如何实际使用模型。

我正在寻找一种传递图像(最好是其路径)并获得某种分类结果的方法。像这样:

result = model.evaluate('path/to/image.jpg')
# result == {'cat': 0.92, 'dog': 0.08}

如何实施?另外,模型保存在哪里以及训练完成后如何访问?

1 个答案:

答案 0 :(得分:2)

对于打印出图像是X%猫,%Y狗,this的特定几率结果的特定情况,特定的tensorflow教程可能更有用。

在其中,他们确实介绍了如何绘制百分比可能性以及使用张量流的大多数基础知识。

训练模型后,可以使用更多代码以图形方式显示结果,例如本教程中的以下代码:

def plot_image(i, predictions_array, true_label, img):
  predictions_array, true_label, img = predictions_array, true_label[i], img[i]
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])

  plt.imshow(img, cmap=plt.cm.binary)

  predicted_label = np.argmax(predictions_array)
  if predicted_label == true_label:
    color = 'blue'
  else:
    color = 'red'

  plt.xlabel("{} {:2.0f}% ({})".format(class_names[predicted_label],
                                100*np.max(predictions_array),
                                class_names[true_label]),
                                color=color)

def plot_value_array(i, predictions_array, true_label):
  predictions_array, true_label = predictions_array, true_label[i]
  plt.grid(False)
  plt.xticks(range(10))
  plt.yticks([])
  thisplot = plt.bar(range(10), predictions_array, color="#777777")
  plt.ylim([0, 1])
  predicted_label = np.argmax(predictions_array)

  thisplot[predicted_label].set_color('red')
  thisplot[true_label].set_color('blue')

然后,使用以下代码,可以对结果进行一些绘图: enter image description here

对于访问并保存模型,以下tensorflow tutorial可能有用。

希望有帮助!