tensorflow cifar10教程失败

时间:2016-12-28 12:04:38

标签: tensorflow

我已从教程here中的链接下载了CIFAR10代码,并尝试运行该教程。我用命令

运行它
python cifar10_train.py

启动正常并按预期下载数据文件。当它尝试打开输入文件时,它会失败并显示以下跟踪:

Traceback (most recent call last):
  File "cifar10_train.py", line 120, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "cifar10_train.py", line 116, in main
    train()
  File "cifar10_train.py", line 63, in train
    images, labels = cifar10.distorted_inputs()
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10.py", line 157, in distorted_inputs
    batch_size=FLAGS.batch_size)
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 161, in distorted_inputs
    read_input = read_cifar10(filename_queue)
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 87, in read_cifar10
    tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
TypeError: strided_slice() takes at least 4 arguments (3 given)

果然,当我调查代码时,在cifar10_input.py中调用strided_slice()只有3个参数:

tf.strided_slice(record_bytes, [0], [label_bytes])

而tensorflow文档确实表明必须至少有4个参数。

出了什么问题?我已经下载了最新的tensorflow(0.12),并且我正在运行cifar代码的主分支。

1 个答案:

答案 0 :(得分:2)

github进行了一些讨论后,我进行了以下更改,似乎可以使其发挥作用:

在cifar10_input.py

-  result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
+  result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)



-  depth_major = tf.reshape( tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),      [result.depth, result.height, result.width])
+  depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), [result.depth, result.height, result.width])

然后在cifar10_input.py和cifar10.py中,我不得不搜索&#34;弃用&#34;无论我在哪里找到它,都要根据我在api指南中读到的内容(希望正确)替换它。例如:

-  tf.contrib.deprecated.image_summary('images', images)
+  tf.summary.image('images', images)

 - tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
 - tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
 + tf.summary.histogram(tensor_name + '/activations', x)
 + tf.summary.scalar(tensor_name + '/sparsity',

现在似乎正在愉快地徘徊。我将查看它是否完成正常,以及我在上面所做的更改是否提供了所需的诊断输出。

我仍然希望听到距离代码更近的人的确切答案。