Integrating a CNN and an RNN for images

Date: 2017-12-12 10:18:58

Tags: python-3.x tensorflow conv-neural-network lstm tflearn

I am trying to combine a CNN and an LSTM on MNIST images with the following code:

from __future__ import division, print_function, absolute_import
import tensorflow as tf
import tflearn
import numpy as np
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression

import tflearn.datasets.mnist as mnist
height = 128
width = 128
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])

# Building convolutional network
network = tflearn.input_data(shape=[None, 28, 28,1], name='input')
network = tflearn.conv_2d(network, 32, 3, activation='relu',regularizer="L2")
network = tflearn.max_pool_2d(network, 2)
network = tflearn.local_response_normalization(network)
network = tflearn.conv_2d(network, 64, 3, activation='relu',regularizer="L2")
network = tflearn.max_pool_2d(network, 2)
network = tflearn.local_response_normalization(network)
network = fully_connected(network, 128, activation='tanh')
network = dropout(network, 0.8)
network = fully_connected(network, 256, activation='tanh')
network = dropout(network, 0.8)
network = tflearn.reshape(network, [-1, 1, 28*28])
#lstm
network = tflearn.lstm(network, 128, return_seq=True)
network = tflearn.lstm(network, 128)
network = tflearn.fully_connected(network, 10, activation='softmax')
network = tflearn.regression(network, optimizer='adam',
                     loss='categorical_crossentropy', name='target')

#train
model = tflearn.DNN(network, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=1, validation_set=0.1, show_metric=True,snapshot_step=100)

The CNN takes a 4D tensor and the LSTM takes a 3D one, so I reshaped the network with: network = tflearn.reshape(network, [-1, 1, 28*28])

But running it produces this error:

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 16384 values, but the requested shape requires a multiple of 784 [[Node: Reshape/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](Dropout_1/cond/Merge, Reshape/reshape/shape)]]

I don't understand why it expects a tensor with 16384 values. Even if I hard-code 128 * 128 it still doesn't work, and I am stuck.

1 answer:

Answer 0: (score: 2)

The error is in this line:

network = tflearn.reshape(network, [-1, 1, 28*28])

The preceding FC layer has n_units=256, so its output cannot be reshaped to 28*28. Change this line to:

network = tflearn.reshape(network, [-1, 1, 256])

Note that you are feeding the features produced by the CNN, not the input MNIST images, to the LSTM.
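
The shape arithmetic behind the error can be checked without TensorFlow. The sketch below uses plain NumPy as a stand-in for the tensor leaving the last fully connected layer. It assumes a batch of 64 examples, which is not stated in the question but is consistent with the 16384 values in the error message (64 * 256 = 16384). The names batch_size, fc_units and features are illustrative only.

```python
import numpy as np

batch_size = 64   # assumed batch size; 64 * 256 = 16384 matches the error message
fc_units = 256    # width of the last fully connected layer in the question

# Stand-in for the output of the second fully_connected/dropout pair.
features = np.zeros((batch_size, fc_units), dtype=np.float32)

# 16384 values cannot be split into rows of 784 (28 * 28), so this reshape fails.
try:
    features.reshape(-1, 1, 28 * 28)
    reshape_to_784_ok = True
except ValueError:
    reshape_to_784_ok = False

# Using the FC layer's width as the last dimension works:
# one time step of 256 features per example.
sequence = features.reshape(-1, 1, fc_units)

print(reshape_to_784_ok, sequence.shape)
```

The last dimension of the reshape has to match what the preceding layer emits, not the size of the original image, because the convolution, pooling and FC layers have already changed the number of values per example.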