如何使用检查点文件或tf-record获得经过微调的网络的标签字符串?

时间:2018-06-28 15:55:56

标签: tensorflow tensorflow-slim

例如,我使用自己的数据集对VGG网络进行了微调,仅使用两个标签foobar。我已经通过this link通过示例将图像转换为tf.record:

labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

我将基于此新模型构建一个API来预测图像,我的问题是:是否有任何正式方法可从检查点文件或数据集中获取标签字符串(例如{{1 }}返回predict_image("abc.png")字符串)?由于我不知道logits层中哪个节点代表标签foo,哪个节点代表foo

我已经尝试过搜索,但是没有帮助,我仍然是一个tensorflow noobie。

1 个答案:

答案 0 :(得分:0)

该模型(以及附带的检查点文件)没有每个类的名称。 它所具有的只是一定数量的输出神经元,第一个对应于第一类,第二个对应于第二类,依此类推。

如果您想知道哪一个,请查看此行创建的标签文件(最有可能命名为labels.txt):

getParameters()

或者,您可以检查dataset_utils.write_label_file(labels_to_class_names, dataset_dir) 字典的内容:

labels_to_class_names

->模型输出中索引0处的值=类'aaa'等,