TensorFlow - Tflearning错误feed_dict

时间:2016-06-27 22:53:09

标签: python tensorflow deep-learning

我正在研究python中的分类问题。事实是,我在TensorFlow中还不是很好。所以我很久以来就遇到了同样的问题,我不知道如何解决它。我希望你能帮助我:)。

这是我的数据:

X:8000张图片:32 * 32px和3种颜色(rgb),所以我加载一个矩阵X.shape =(8000,32,32,3)

Y:4个等级(1,2,3和4):Y。形状=(8000,1)

这是我的代码:

network = input_data(shape=[None, 32, 32, 3], name='iput')                   
# Step 1: Convolution
network = conv_2d(network, 32, 3, activation='relu')
# Step 2: Max pooling
network = max_pool_2d(network, 2)
# Step 3: Convolution again
network = conv_2d(network, 64, 3, activation='relu')
# Step 4: Convolution yet again
network = conv_2d(network, 64, 3, activation='relu')
# Step 5: Max pooling again
network = max_pool_2d(network, 2)
# Step 6: Fully-connected 512 node neural network
network = fully_connected(network, 512, activation='relu')
# Step 7: Dropout - throw away some data randomly during training to prevent over-fitting
network = dropout(network, 0.5)
# Step 8: Fully-connected neural network with 4 outputs
network = fully_connected(network, 4, activation='softmax')
# Tell tflearn how we want to train the network
network = regression(network, optimizer='adam',
                     loss='categorical_crossentropy',
                     learning_rate=0.001)
model = tflearn.DNN(network)                
model.fit(X, Y)

这是我的错误

  

追踪(最近一次呼叫最后一次):

     

文件“”,第3行,

     

model.fit(X,Y)

     

文件“/home/side/anaconda3/lib/python3.5/site-packages/tflearn/models/dnn.py”,

     

第157行,适合

     

self.targets)

     

文件“/home/side/anaconda3/lib/python3.5/site-packages/tflearn/utils.py”,   第267行,在feed_dict_builder中       feed_dict [net_inputs [i]] = x   IndexError:列表索引超出范围

我还尝试将X传递给(8000,3072)Matrix 和Y为(8000,4)矩阵,例如:

[0 0 1 0 < - Y [0] = 3

0 1 0 0&lt; -Y [1] = 2

...

我重复使用此代码:https://github.com/tflearn/tflearn/blob/master/examples/images/convnet_cifar10.py,用于对cifar10数据进行分类。

感谢您的帮助,

西莉亚

3 个答案:

答案 0 :(得分:4)

另一个选择是添加:

tf.reset_default_graph()

作为代码的第一行

答案 1 :(得分:3)

从源代码中解释:

输入数量与预期输入数量不匹配。 如果您使用的是ipython笔记本,请确保您没有多次运行图形构造单元格。或者将图形结构包含在with tf.Graph().as_default()块中。

答案 2 :(得分:1)

正如motjuste所说,当你使用带有TFLearn的笔记本时,你应该在每次运行代码时重新启动你的内核。

在github中查看此问题:

https://github.com/tflearn/tflearn/issues/360