ValueError: Input 0 is incompatible with layer flatten_2

Date: 2019-11-11 07:27:43

Tags: python tensorflow keras deep-learning

I am trying to build a combined CNN-LSTM classification model, but I get the following error:

ValueError: Input 0 is incompatible with layer flatten_2: expected min_ndim=3, found ndim=2
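For context: before calling a layer, Keras checks the incoming tensor against the layer's declared input spec, and as the message itself shows, Flatten in this Keras version declares min_ndim=3 (the batch axis plus at least two more). A tensor of shape (batch, 4) has ndim=2, which produces exactly this message. The pure-Python sketch below only imitates that check; check_min_ndim is a hypothetical helper for illustration, not a Keras function.

```python
def check_min_ndim(shape, min_ndim, layer_name, input_index=0):
    """Hypothetical helper imitating a Keras min_ndim input check.

    `shape` includes the batch axis, e.g. (None, 4) for a batch of 4-vectors.
    """
    ndim = len(shape)
    if ndim < min_ndim:
        raise ValueError(
            "Input %d is incompatible with layer %s: expected min_ndim=%d, "
            "found ndim=%d" % (input_index, layer_name, min_ndim, ndim))
    return shape


# A 3-D (batch, time, features) tensor passes the check...
print(check_min_ndim((None, 100, 4), 3, "flatten_2"))

# ...but a 2-D (batch, features) tensor is rejected.
try:
    check_min_ndim((None, 4), 3, "flatten_2")
except ValueError as err:
    print(err)
```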

Environment:

Python 3.5
Keras 2.2.0
Tf-GPU 1.6.0

Any ideas on how to fix this? Thanks a lot!

from keras.layers import Convolution2D, MaxPooling2D, Flatten, Reshape
from keras.layers import Activation, Dense, Dropout, LSTM
from keras.models import Sequential
from keras.utils.np_utils import to_categorical
from keras.layers.wrappers import TimeDistributed
from keras.layers.pooling import GlobalAveragePooling1D
import gc
import numpy as np


timesteps = 100
number_of_samples = 2500
nb_samples = number_of_samples
frame_row = 32
frame_col = 32
channels = 3

nb_epoch = 1
batch_size = timesteps

data = np.random.random((2500, timesteps, frame_row, frame_col, channels))
label = np.random.randint(4, size=(2500, 1))

X_train = data[0:2000, :]
y_train = label[0:2000]
y_train = to_categorical(y_train)
X_test = data[2000:, :]
y_test = to_categorical(label[2000:], num_classes=4)  # one-hot, like y_train

# %%

model = Sequential()

model.add(TimeDistributed(Convolution2D(32, (3, 3), padding='same'), input_shape=(100, 32, 32, 3)))
model.add(TimeDistributed(Convolution2D(32, (3, 3), padding='same')))
model.add(TimeDistributed(Activation('relu')))
model.add(TimeDistributed(Convolution2D(32, (3, 3))))
model.add(TimeDistributed(Activation('relu')))
model.add(TimeDistributed(MaxPooling2D(pool_size=(2, 2))))
model.add(TimeDistributed(Dropout(0.25)))

model.add(TimeDistributed(Flatten()))
model.add(TimeDistributed(Dense(512)))

model.add(TimeDistributed(Dense(35, name="first_dense")))

model.add(LSTM(20, return_sequences=True, name="lstm_layer"))

# %%
model.add(TimeDistributed(Dense(4), name="time_distr_dense_one"))
model.add(GlobalAveragePooling1D(name="global_avg"))
model.add(Flatten())
model.add(TimeDistributed(Dense(4, activation="softmax"), name="time_distr_dense"))

# %%

model.compile(loss='categorical_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])
model.fit(X_train, y_train, epochs=3, validation_split=0.1, batch_size=32, verbose=2)

gc.collect()
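A likely cause, traced by hand: GlobalAveragePooling1D averages over the 100 timesteps and returns a 2-D tensor of shape (batch, 4), so the Flatten() right after it (the second Flatten in the model, hence the name flatten_2) receives ndim=2 and is rejected. The TimeDistributed(Dense(4, activation="softmax")) after that would also need a time axis that no longer exists. The sketch below is a simplified pure-Python model of each layer's shape rule, not Keras itself; all function names are hypothetical, Activation and Dropout are omitted because they never change the shape, and Conv2D assumes stride 1 with channels_last data as in the question. It reproduces the failure and shows that ending the model with a plain Dense(4) gives (None, 4), matching the one-hot y_train with 4 classes.

```python
# Simplified shape-only model of the layers in the question (NOT Keras).


def conv2d(shape, filters, kernel, padding):
    b, rows, cols, _ = shape
    if padding == "valid":  # 'same' keeps rows/cols unchanged at stride 1
        rows, cols = rows - kernel + 1, cols - kernel + 1
    return (b, rows, cols, filters)


def max_pool2d(shape, pool):
    b, rows, cols, ch = shape
    return (b, rows // pool, cols // pool, ch)


def flatten(shape, name):
    # Flatten here requires min_ndim=3, as in the question's error message.
    if len(shape) < 3:
        raise ValueError(
            "Input 0 is incompatible with layer %s: expected min_ndim=3, "
            "found ndim=%d" % (name, len(shape)))
    size = 1
    for d in shape[1:]:
        size *= d
    return (shape[0], size)


def dense(shape, units):
    return shape[:-1] + (units,)  # Dense acts on the last axis only


def lstm(shape, units, return_sequences):
    b, t, _ = shape
    return (b, t, units) if return_sequences else (b, units)


def global_avg_pool_1d(shape):
    b, _, features = shape  # averages away the time axis
    return (b, features)


def time_distributed(fn, shape, *args, **kwargs):
    # Apply fn to one timestep (batch, ...) and re-insert the time axis.
    inner = fn((shape[0],) + shape[2:], *args, **kwargs)
    return (shape[0], shape[1]) + inner[1:]


def trace(head):
    """Run the question's CNN-LSTM trunk, then `head` on its output shape."""
    s = (None, 100, 32, 32, 3)
    s = time_distributed(conv2d, s, 32, 3, "same")
    s = time_distributed(conv2d, s, 32, 3, "same")
    s = time_distributed(conv2d, s, 32, 3, "valid")
    s = time_distributed(max_pool2d, s, 2)
    s = time_distributed(flatten, s, name="flatten_1")
    s = time_distributed(dense, s, 512)
    s = time_distributed(dense, s, 35)
    s = lstm(s, 20, return_sequences=True)
    s = time_distributed(dense, s, 4)
    s = global_avg_pool_1d(s)  # (None, 4): only 2 dims left
    return head(s)


def original_head(s):
    s = flatten(s, name="flatten_2")  # fails: ndim is 2
    return time_distributed(dense, s, 4)


def fixed_head(s):
    return dense(s, 4)  # stands in for Dense(4, activation="softmax")


try:
    trace(original_head)
except ValueError as err:
    print("original:", err)
print("fixed:", trace(fixed_head))
```

In the Keras code above, the corresponding change would be to replace the model.add(Flatten()) and model.add(TimeDistributed(Dense(4, activation="softmax"), ...)) lines with a single model.add(Dense(4, activation="softmax")). An alternative, also untested here, is LSTM(20) with return_sequences=False followed directly by Dense(4, activation="softmax"), which removes the need for the pooling step.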
