Input dimension error when fitting a SimpleSeq2Seq model

Time: 2018-02-21 08:14:38

Tags: python tensorflow neural-network deep-learning keras

I'm building a model for topic classification and trying to use seq2seq for the model layers, but when I run it, it raises a ValueError:

  

"ValueError: Error when checking input: expected input_4 to have 3 dimensions, but got array with shape (160980, 15)"

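Does anyone know what causes this? My data is only two-dimensional: inputs of shape (201225, 15) and labels of shape (201225, 41), so I don't understand why it expects three dimensions. Here is the code: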

from keras.models import Sequential, save_model
from keras.layers import (Dense, Input, Flatten, Embedding, Dropout, Conv1D,
                          MaxPooling1D, GlobalMaxPooling1D, LSTM)
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from keras.utils import to_categorical
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
import keras.backend as K
from keras.utils import plot_model
from keras.layers.wrappers import TimeDistributed

from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import seaborn as sns
from pandas import Series

import seq2seq
from seq2seq.models import SimpleSeq2Seq

# load data
texts = open('c:\\Users/KW198/Documents/topic_model/keywords.txt', 
encoding='utf8').read().split('\n')
all_labels = open('c:\\Users/KW198/Documents/topic_model/topics.txt', 
encoding='utf8').read().split('\n')

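# Tokenize data
MAX_SEQUENCE_LENGTH = 15  # defined so the snippet runs; matches the (201225, 15) shape printed below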
tok = Tokenizer()
tok.fit_on_texts(texts)        
sequences = tok.texts_to_sequences(texts)
word_index = tok.word_index
print('Found %s unique tokens.' % len(word_index))
data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
labels = to_categorical(np.asarray(all_labels))
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)
  

Found 341826 unique tokens.
Shape of data tensor: (201225, 15)
Shape of label tensor: (201225, 41)
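To see where these two shapes come from, here is a pure-Python sketch of what the two preprocessing calls do with their default arguments. It is a simplified re-creation for illustration, not Keras's own code: `pad_pre` mimics `pad_sequences` with its defaults (`padding='pre'`, `truncating='pre'`, `value=0`), and `one_hot` mimics `to_categorical`, assuming the lines of topics.txt are integer class ids stored as strings. Every row is forced to exactly `maxlen` ids, which is why `data` is 2-D with 15 columns, and the one-hot width defaults to the largest label plus one, so 41 columns correspond to labels 0 to 40.

```python
def pad_pre(sequences, maxlen, value=0):
    """Sketch of pad_sequences' defaults: truncate and pad at the front."""
    padded = []
    for seq in sequences:
        seq = list(seq)[-maxlen:]  # truncating='pre': keep the last maxlen ids
        padded.append([value] * (maxlen - len(seq)) + seq)  # padding='pre'
    return padded


def one_hot(labels, num_classes=None):
    """Sketch of to_categorical: integer class ids -> one-hot rows.

    Assumes every label can be parsed with int(); returns plain ints
    instead of the float array Keras produces.
    """
    ids = [int(x) for x in labels]
    n = num_classes if num_classes is not None else max(ids) + 1
    return [[1 if j == i else 0 for j in range(n)] for i in ids]


print(pad_pre([[5, 8], [1, 2, 3, 4]], maxlen=3))  # [[0, 5, 8], [2, 3, 4]]
print(one_hot(["0", "2"]))                         # [[1, 0, 0], [0, 0, 1]]
```

Neither step adds a third axis, so the resulting arrays stay 2-D.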

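# Hold out a test set (this step is not shown in the original post; test_size=0.2
# is assumed because 201225 - ceil(0.2 * 201225) = 160980, the row count in the error)
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2)

K.clear_session()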

model = SimpleSeq2Seq(input_dim=15, hidden_dim=10, output_length=41, 
output_dim=41)

#plot_model(model, to_file='model.png',show_shapes=True)
model.compile(loss='categorical_crossentropy',
          optimizer='rmsprop',
          metrics=['acc'])

checkporint = EarlyStopping(monitor='val_acc', patience=5, mode='max', 
min_delta=0.003)
model.fit(x_train, y_train, epochs=13, batch_size=128, verbose=1, 
validation_split=0.2, callbacks=[checkporint])

score = model.evaluate(x_test, y_test, batch_size=128, verbose=1)
print('Test score:', score[0])
print('Test accuracy:', score[1])

Here is the error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-31-fbd441ff95e2> in <module>()
      5 
      6 checkporint = EarlyStopping(monitor='val_acc', patience=5, mode='max',  min_delta=0.003)
----> 7 model.fit(x_train, y_train, epochs=13, batch_size=128, verbose=1, validation_split=0.2, callbacks=[checkporint])

~\Anaconda3\envs\ztdl\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, **kwargs)
   1427             class_weight=class_weight,
   1428             check_batch_axis=False,
-> 1429             batch_size=batch_size)
   1430         # Prepare validation data.
   1431         if validation_data:

~\Anaconda3\envs\ztdl\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_batch_axis, batch_size)
   1303                                     self._feed_input_shapes,
   1304                                     check_batch_axis=False,
-> 1305                                     exception_prefix='input')
   1306         y = _standardize_input_data(y, self._feed_output_names,
   1307                                     output_shapes,

~\Anaconda3\envs\ztdl\lib\site-packages\keras\engine\training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    125                                  ' to have ' + str(len(shapes[i])) +
    126                                  ' dimensions, but got array with shape ' +
--> 127                                  str(array.shape))
    128             for j, (dim, ref_dim) in enumerate(zip(array.shape, shapes[i])):
    129                 if not j and not check_batch_axis:

ValueError: Error when checking input: expected input_4 to have 3 dimensions, but got array with shape (160980, 15)
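The check that fails is a plain rank comparison: the model's input placeholder has three axes, while `x_train` has two. The sketch below is a simplified re-creation of that comparison for illustration; the helper name `rank_mismatch_message` is hypothetical and not part of Keras. The 160980 rows are consistent with an 80/20 `train_test_split` of the 201225 samples (201225 - 40245 = 160980), presumably done with the imported `train_test_split`.

```python
def rank_mismatch_message(name, array_shape, expected_shape):
    """Hypothetical helper: return the error text when the number of axes
    differs (a simplified re-creation of the Keras check), else None."""
    if len(array_shape) != len(expected_shape):
        return ("Error when checking input: expected " + name + " to have "
                + str(len(expected_shape)) + " dimensions, but got array with shape "
                + str(array_shape))
    return None


# The model expects (batch, time, features); x_train is only (batch, 15).
print(rank_mismatch_message("input_4", (160980, 15), (None, None, 15)))

# Adding a trailing axis gives three dimensions, so the rank check passes.
print(rank_mismatch_message("input_4", (160980, 15, 1), (None, 15, 1)))  # None
```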

1 answer:

Answer 0 (score: 0)

The input should be a 3-dimensional tensor whose dimensions represent (batch_size, input_length, input_dim).

So in your case, if your sequences have length 15 and the input dimension is 1, you should reshape the input to (?, 15, 1).
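A minimal NumPy sketch of that reshape. A small deterministic array stands in for the real (201225, 15) token matrix, and the variable names are illustrative only.

```python
import numpy as np

# Stand-in for the padded token matrix: 4 sequences of 15 token ids each.
data = np.arange(60).reshape(4, 15)

# Add a trailing feature axis of size 1: (batch, 15) -> (batch, 15, 1).
data_3d = data.reshape(-1, 15, 1)
# Equivalent forms: np.expand_dims(data, axis=-1) or data[..., np.newaxis]

print(data.shape, "->", data_3d.shape)  # (4, 15) -> (4, 15, 1)
```

Reshaping this way feeds each token id in as a single numeric feature. Because the ids are categorical, an Embedding layer (already imported in the question) is the usual way to turn them into vectors, though that goes beyond what this answer proposes.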

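If your word sequences have a fixed length (e.g. 15), you should pass the input_length=15 argument together with input_dim=1. Note that input_dim is the size of the feature vector at each timestep, not the sequence length, so input_dim=15 makes the model expect inputs of shape (?, ?, 15).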
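The difference between the two constructor calls can be made concrete with a small shape calculation. The helpers below are hypothetical: they encode an assumed approximation of how SimpleSeq2Seq turns `input_dim`, `input_length`, `output_length` and `output_dim` into tensor shapes, and they are not part of the seq2seq package.

```python
# Hypothetical helpers: an assumed approximation of how SimpleSeq2Seq maps its
# keyword arguments to tensor shapes. Not part of the seq2seq package.

def seq2seq_input_shape(input_dim, input_length=None, batch_size=None):
    # input_dim is the feature size per timestep; input_length=None leaves
    # the time axis variable-length.
    return (batch_size, input_length, input_dim)


def seq2seq_output_shape(output_length, output_dim, batch_size=None):
    # One output_dim-sized vector for each of the output_length decoder steps.
    return (batch_size, output_length, output_dim)


print(seq2seq_input_shape(input_dim=15))                   # (None, None, 15): the question's call
print(seq2seq_input_shape(input_dim=1, input_length=15))   # (None, 15, 1): the answer's suggestion
print(seq2seq_output_shape(output_length=41, output_dim=41))  # (None, 41, 41)
```

Under this assumption the targets need the same scrutiny: with output_length=41 and output_dim=41 the model would emit (batch, 41, 41), which does not match the 2-D one-hot labels of shape (201225, 41), so the labels (or the choice of a sequence-to-sequence model for single-label topic classification) would probably need revisiting as well.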