检查输入时出错:预期lstm_4_input具有3维,但数组的形状为(8,23,1,16)

时间:2018-07-31 15:05:57

标签: python machine-learning neural-network

我一直遇到这个错误,我不确定如何解决。我知道我的输入期望3维,但是我的形状是4维。有人可以指导我如何缩小范围吗?我已经读过3维的数组,数组的形状小于3个左右。但是我还没有看到更大的形状和维的溢出解决方案...

def generator(batch_size,from_list_x,from_list_y):

    assert len(from_list_x) == len(from_list_y)
    total_size = len(from_list_x)

    while True: #keras generators should be infinite

        for i in range(0,total_size,batch_size):
            yield np.array(from_list_x[i:i+batch_size]), np.array(from_list_y[i:i+batch_size])

model = Sequential()
model.add(LSTM(1, input_shape=(1,16),return_sequences=True))
model.add(Flatten())
model.add(Dense(1, activation='tanh'))
model.compile(loss='mae', optimizer='adam', metrics=['accuracy'])
model.summary()

    # fit network
pyplot.figure(figsize=(16, 25))
# for i in range(len(train)):
history = model.fit_generator(generator(8, train_X, train_Y), epochs=20, steps_per_epoch = 7, verbose=0, shuffle=True)
print('loss ', str(i), history.history['loss'][len(history.history['loss'])-1],'\n')

此行出现错误:

history = model.fit_generator(generator(8, train_X, train_Y), epochs=20, steps_per_epoch = 7, verbose=0, shuffle=True)

跟踪:

ValueError

        Traceback (most recent call last)
<ipython-input-41-17bb96b50588> in <module>()
      2 pyplot.figure(figsize=(16, 25))
      3 # for i in range(len(train)):
----> 4 history = model.fit_generator(generator(8, train_X, train_Y), epochs=20, steps_per_epoch = 7, verbose=0, shuffle=True)
      5 
      6 #     plot history

C:\ProgramData\Anaconda3\envs\dh\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

C:\ProgramData\Anaconda3\envs\dh\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1424             use_multiprocessing=use_multiprocessing,
   1425             shuffle=shuffle,
-> 1426             initial_epoch=initial_epoch)
   1427 
   1428     @interfaces.legacy_generator_methods_support

C:\ProgramData\Anaconda3\envs\dh\lib\site-packages\keras\engine\training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    189                 outs = model.train_on_batch(x, y,
    190                                             sample_weight=sample_weight,
--> 191                                             class_weight=class_weight)
    192 
    193                 if not isinstance(outs, list):

C:\ProgramData\Anaconda3\envs\dh\lib\site-packages\keras\engine\training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1212             x, y,
   1213             sample_weight=sample_weight,
-> 1214             class_weight=class_weight)
   1215         if self._uses_dynamic_learning_phase():
   1216             ins = x + y + sample_weights + [1.]

C:\ProgramData\Anaconda3\envs\dh\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    752             feed_input_shapes,
    753             check_batch_axis=False,  # Don't enforce the batch size.
--> 754             exception_prefix='input')
    755 
    756         if y is not None:

C:\ProgramData\Anaconda3\envs\dh\lib\site-packages\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    124                         ': expected ' + names[i] + ' to have ' +
    125                         str(len(shape)) + ' dimensions, but got array '
--> 126                         'with shape ' + str(data_shape))
    127                 if not check_batch_axis:
    128                     data_shape = data_shape[1:]

ValueError: Error when checking input: expected lstm_4_input to have 3 dimensions, but got array with shape (8, 23, 1, 16)

<Figure size 1152x1800 with 0 Axes>

0 个答案:

没有答案