初始v3网络中的标签错误(Tensorflow)

时间:2016-12-14 12:05:33

标签: python computer-vision tensorflow deep-learning conv-neural-network

我在这里使用build_image_data.py脚本:https://github.com/tensorflow/models/blob/master/inception/inception/data/build_image_data.py 与将数据集转换为TFRecords格式的文档完全一样。在inception_train.py脚本中,当我打印图像和标签时,标签与图像不对应,因此我无法进行正确的培训。我使用的数据集是不平衡的(类之间的图像数量不同)。我也在类和标签之间使用相同数量的图像进行了测试仍然是错误的。 tensorflow代码不受影响,我所做的唯一更改是不在image_processing.py脚本中应用扭曲。由于我的TFR转换或者因为返回图像和标签的image_processing.py脚本,我不知道标签是否错误。有什么想法吗?

Tensorflow版本:0.10 操作系统:Ubuntu 14.04

用于检查它的inception_train.py脚本中的代码片段是:

labs = sess.run(labels)
imgs = sess.run(images)


for i in range(FLAGS.batch_size):
  print('Label ' + str(labs[i]))
  plt.imshow(imgs[i, :, :, :])
  plt.show()

1 个答案:

答案 0 :(得分:1)

你应该同时运行两个:也就是说,只调用一次sess.run。像这样:

imgs,labs = sess.run([images,labels])# ONLY ONE CALL 

for i in range(FLAGS.batch_size):
    print('Label ' + str(labs[i]))
    plt.imshow(imgs[i, :, :, :])
    plt.show()