Tensorflow Iris渴望执行TUTORIAL总是预测setosa

时间:2018-09-20 18:03:19

标签: python tensorflow

使用以下代码:https://www.tensorflow.org/tutorials/eager/custom_training_walkthrough

当我运行3个示例对新数据进行3个预测时(该新数据显示在本教程的结尾),它始终是“ setosa”。但是,教程作者声称他们还能得到其他东西。

教程作者声称他们得到了这些预测:

Example 0 prediction: Iris setosa (97.4%)
Example 1 prediction: Iris versicolor (81.9%)
Example 2 prediction: Iris virginica (69.3%)

通过对比我总是得到:

Example 0 prediction: Iris setosa (98.8%)
Example 1 prediction: Iris setosa (96.0%)
Example 2 prediction: Iris setosa (90.6%)

我什至在命令行中重新启动了jupyter守护程序,并且多次重绘了此笔记本,并且始终通过此特定代码段来预测“ setosa”。

这是我正在谈论的Google教程代码段:

predict_dataset = tf.convert_to_tensor([
    [5.1, 3.3, 1.7, 0.5,],
    [5.9, 3.0, 4.2, 1.5,],
    [6.9, 3.1, 5.4, 2.1]
])
class_names = ['Iris setosa', 'Iris versicolor', 'Iris virginica']
predictions = model(predict_dataset)
for i, logits in enumerate(predictions):
    class_idx = tf.argmax(logits).numpy()
    p = tf.nn.softmax(logits)[class_idx]
    name = class_names[class_idx]
    print("Example {} prediction: {} ({:4.1f}%)".format(i, name, 100*p))

您想运行本教程,看看是否遇到同样的问题,然后报告?

我已经花了一整天的时间,不会复制在这里尝试过的所有内容。您只需运行Google的教程代码,看看它是否有效。

以下可能是导致问题的一些原因。

  1. 也许这是一个jupyter问题。待办事项:尝试在 常规python .py文件。输出什么预测?

  2. 也许这是重量问题的初始化,而权重是 总是从零开始。待办事项:使用随机权重初始化 某种。

  3. 也许这是一个jupyter重新启动问题。待办事项:重新启动它,看看是否 产生相同的问题输出。结果:每次输出都相同。这表明模型训练是一致的,初始权重的随机化不会发生,但是如果这样做的话会很好,所以我应该尝试一下。

  4. 也许这是一个conda版本的问题,而这台笔记本是某种方式 使用与我想象不同的环境。 TODO:不知道。

  5. 也许这是Google提供的教程代码错误。待办事项:看是否 其他人在本教程代码中也遇到相同的问题。

  6. 也许这是我的代码问题。待办事项:用 严格来说是Google的教程代码,一点也没有。

  7. 也许cross_entropy应该汇总为一个标量 最小化之前的值,例如使用tf.reduce_mean()。 待办事项:这样做。

  8. 也许这是我TF版本中的错误,并且已在较新版本中修复。 待办事项:看看是否有人在教程代码中遇到了相同的错误。

  9. 也许这是GPU代码错误,并且可以在CPU上正常工作。待办事项:使用CPU。

在本教程代码中,还有其他问题。我一次只关注一个问题。例如,测试集是使用训练集文件生成的,但是应该已经对其进行编码以使用测试集文件。我修复了教程代码以加载测试集文件。

我的库版本是:

Python version 3.5.4 |Anaconda custom (64-bit)| (default, Nov 20 2017, 18:44:38) 
[GCC 7.2.0]
Numpy version 1.14.5
Tensorflow version 1.8.0
Matplotlib version 2.2.2
Pandas version 0.21.1
Scipy version 1.0.0
$ jupyter notebook --version = 4.4.1

我正在使用GPU执行TensorFlow引擎。

0 个答案:

没有答案