例如,我使用自己的数据集对VGG网络进行了微调,仅使用两个标签foo
和bar
。我已经通过this link通过示例将图像转换为tf.record:
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
我将基于此新模型构建一个API来预测图像,我的问题是:是否有任何正式方法可从检查点文件或数据集中获取标签字符串(例如{{1 }}返回predict_image("abc.png")
字符串)?由于我不知道logits层中哪个节点代表标签foo
,哪个节点代表foo
我已经尝试过搜索,但是没有帮助,我仍然是一个tensorflow noobie。
答案 0 :(得分:0)
该模型(以及附带的检查点文件)没有每个类的名称。 它所具有的只是一定数量的输出神经元,第一个对应于第一类,第二个对应于第二类,依此类推。
如果您想知道哪一个,请查看此行创建的标签文件(最有可能命名为labels.txt):
getParameters()
或者,您可以检查dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
字典的内容:
labels_to_class_names
->模型输出中索引0处的值=类'aaa'等,