Multi-label classification error with tf.slim

Posted: 2018-05-29 12:41:35

Tags: tensorflow

When building the TFRecord files, I encoded the labels as N-hot (multi-hot) vectors.
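For reference, N-hot encoding sets a 1 for every class an example belongs to, not just one. A minimal pure-Python sketch follows; the helper `n_hot` is hypothetical and not part of slim or the original script:

```python
def n_hot(class_ids, num_classes):
    """Return a list of length num_classes with 1 at every index in class_ids."""
    vec = [0] * num_classes
    for c in class_ids:
        if not 0 <= c < num_classes:
            raise ValueError(f"class id {c} out of range [0, {num_classes})")
        vec[c] = 1
    return vec


# An image tagged with classes 0 and 3 out of 5 possible classes.
print(n_hot([0, 3], 5))  # [1, 0, 0, 1, 0]
```

Each example's label is then num_classes integers instead of a single class id.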

So I commented out the following code:

labels = slim.one_hot_encoding(
    labels, dataset.num_classes - FLAGS.labels_offset)
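slim.one_hot_encoding expects one integer class id per example and turns a batch of shape (batch,) into (batch, num_classes). The pure-Python sketch below mimics only that shape change; `one_hot_batch` and `shape` are hypothetical helpers, not slim APIs:

```python
def one_hot_batch(labels, num_classes):
    """Pure-Python stand-in for one-hot encoding a batch of integer class ids.

    labels: B integer class ids, i.e. shape (B,)
    returns: B rows of num_classes 0/1 values, i.e. shape (B, num_classes)
    """
    return [[1 if j == label else 0 for j in range(num_classes)] for label in labels]


def shape(x):
    """Shape of a rectangular nested list, read off the first element at each level."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


batch = [2, 0, 62]  # one integer class id per example
print(shape(batch))                     # (3,)
print(shape(one_hot_batch(batch, 63)))  # (3, 63)
```

With N-hot labels each example's label should already be a length-num_classes vector, so encoding it again is not wanted. Removing the call only helps, though, if the input pipeline actually yields labels of shape (batch, num_classes); the traceback shows it still yields (32,).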

But training crashes:

Traceback (most recent call last):
  File "../train_image_classifier_multilabel.py", line 620, in <module>
    tf.app.run()
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "../train_image_classifier_multilabel.py", line 519, in main
    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
  File "/root/Projects/scene-tf/slim/deployment/model_deploy.py", line 193, in create_clones
    outputs = model_fn(*args, **kwargs)
  File "../train_image_classifier_multilabel.py", line 513, in clone_fn
    label_smoothing=FLAGS.label_smoothing, weights=1.0)
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/losses/losses_impl.py", line 676, in sigmoid_cross_entropy
    logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
  File "/root/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py", line 764, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (32, 63) and (32,) are incompatible

The failure happens in sigmoid_cross_entropy:

tf.losses.sigmoid_cross_entropy(
    logits=logits, multi_class_labels=labels,
    label_smoothing=FLAGS.label_smoothing, weights=1.0)

Clearly, logits has shape (32, 63) while labels has shape (32,).
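sigmoid_cross_entropy treats each of the num_classes outputs as an independent binary problem, so it needs exactly one 0/1 target per logit; that is why TensorFlow asserts the two shapes are compatible. The sketch below is a hypothetical pure-Python re-implementation, not TensorFlow's code: it uses the numerically stable per-element formula documented for tf.nn.sigmoid_cross_entropy_with_logits, ignores label_smoothing and weights, and simply averages over all elements.

```python
import math


def sigmoid_cross_entropy(logits, labels):
    """Mean elementwise sigmoid cross-entropy over a batch.

    logits, labels: B rows, each with C values (labels in {0, 1}).
    Per element: max(x, 0) - x * z + log(1 + exp(-|x|)).
    """
    # Every logit needs a matching label: same batch size, same row length.
    if len(logits) != len(labels) or any(
        not isinstance(z, list) or len(x) != len(z) for x, z in zip(logits, labels)
    ):
        raise ValueError("logits and labels must have the same shape")
    total, count = 0.0, 0
    for row_x, row_z in zip(logits, labels):
        for x, z in zip(row_x, row_z):
            total += max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count if count else 0.0


logits = [[0.0, 2.0], [-1.0, 0.5]]          # shape (2, 2)
print(sigmoid_cross_entropy(logits, [[1, 0], [0, 1]]))  # N-hot labels, shape (2, 2): works

try:
    sigmoid_cross_entropy(logits, [1, 0])   # one value per example, shape (2,)
except ValueError as err:
    print("error:", err)
```

The second call fails for the same reason as the traceback: a (batch,) label tensor cannot be paired element by element with (batch, num_classes) logits.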

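It seems that commenting out one_hot_encoding was a mistake: without it, labels keeps the (32,) shape produced by the input pipeline, one value per example. But I don't know how to fix it.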

0 answers:

No answers yet.