tf.contrib.learn Quickstart: changing n_classes to 2 does not work

Asked: 2016-10-26 08:43:53

Tags: python tensorflow

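I am trying out the tf.contrib.learn Quickstart, and it works with the code given in the tutorial. However, if I change the training and test sets to contain only 2 classes (i.e. only 2 iris species), I get the following output and error: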

WARNING:tensorflow:Change warning: default value of `enable_centered_bias` will change after 2016-10-09. It will be disabled by default.Instructions for keeping existing behaviour:
Explicitly set `enable_centered_bias` to 'True' if you want to keep existing behaviour.
WARNING:tensorflow:Using default config.
Traceback (most recent call last):
  File "test.py", line 34, in <module>
    steps=2000)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/estimators/dnn.py", line 435, in fit
    max_steps=max_steps)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 333, in fit
    max_steps=max_steps)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 662, in _train_model
    train_op, loss_op = self._get_train_ops(features, targets)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 963, in _get_train_ops
    _, loss, train_op = self._call_model_fn(features, targets, ModeKeys.TRAIN)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 944, in _call_model_fn
    return self._model_fn(features, targets, mode=mode, params=self.params)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/estimators/dnn.py", line 258, in _dnn_classifier_model_fn
    weight=_get_weight_tensor(features, weight_column_name))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/losses/python/losses/loss_ops.py", line 329, in sigmoid_cross_entropy
    logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_shape.py", line 750, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (?, 1) and (?,) are incompatible

The only code I changed is the classifier construction (n_classes changed from 3 to 2):

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=2,
                                            model_dir="/tmp/iris_model")

Can someone explain why this does not work?

1 Answer:

Answer 0 (score: 1)

I ran into the same error. Apparently this is a bug in TensorFlow; see the link below for more information:

Shape error using Tensorflow (tf.learn, DNNClassifier)

I worked around it by setting n_classes to 3, even though I only have 2 classes.
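
For context, the last frames of the traceback show a static shape check failing between the logits, shape (?, 1), and the labels, shape (?,). The sketch below is plain Python, not TensorFlow code: `shapes_compatible` is a hypothetical helper that only approximates such a check, with `None` standing in for the unknown batch dimension `?`. It is meant to show why the two shapes are rejected, not to reproduce the library's implementation.

```python
def shapes_compatible(a, b):
    """Hypothetical, simplified shape check (an assumption, not the TensorFlow API).

    Two shapes are treated as compatible when they have the same rank and
    every pair of dimensions is equal or at least one of them is unknown (None).
    """
    if len(a) != len(b):
        return False
    return all(x is None or y is None or x == y for x, y in zip(a, b))


# Shapes taken from the error message: "(?, 1)" and "(?,)".
logits_shape = (None, 1)  # one logit per example, as reported for n_classes=2
labels_shape = (None,)    # flat vector of class labels

# Rank 2 versus rank 1, so the check fails regardless of the batch size.
print(shapes_compatible(logits_shape, labels_shape))  # False

# A label shape with an explicit trailing dimension would pass this check.
print(shapes_compatible(logits_shape, (None, 1)))  # True
```

Whether supplying column-shaped labels is accepted by DNNClassifier in that TensorFlow version is not confirmed here; the only workaround reported in this thread is keeping n_classes at 3.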