tensorflow.python.framework.errors_impl.InvalidArgumentError in a Keras LSTM model

Time: 2018-10-02 16:19:19

Tags: python tensorflow keras nlp lstm

I have the following LSTM model:

import datetime

import gensim
import pandas as pd
from keras.layers import Activation, Dense, Embedding, LSTM
from keras.models import Sequential

class LSTM_model():

    def __init__(self):
        # Train word2vec on the corpus (`sentences` is defined elsewhere)
        # and reuse its vectors as the initial embedding weights.
        w2v_model = gensim.models.Word2Vec(sentences, size=150, window=10, min_count=2, workers=10)
        pretrained_weights = w2v_model.wv.syn0
        vocab_size, embedding_size = pretrained_weights.shape
        self.w2v_model = w2v_model
        self.keras_lstm_model = Sequential()
        self.keras_lstm_model.add(Embedding(input_dim=vocab_size, output_dim=embedding_size, weights=[pretrained_weights]))
        self.keras_lstm_model.add(LSTM(units=embedding_size))
        self.keras_lstm_model.add(Dense(units=vocab_size))
        self.keras_lstm_model.add(Activation('sigmoid'))
        self.keras_lstm_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['mae', 'acc'])
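A side note on the output layer, separate from the error itself: `sparse_categorical_crossentropy` is normally paired with a softmax output, whose values sum to 1 across the vocabulary, whereas independent sigmoid outputs do not. A minimal pure-Python sketch of the difference follows; the `logits` values are made up for illustration and are not taken from the model above.

```python
import math

def softmax(logits):
    """Turn raw scores into a probability distribution that sums to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x):
    """Squash a single score into (0, 1), independently of the other scores."""
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical scores for a 3-word vocabulary (assumption, for illustration only)
logits = [2.0, 1.0, 0.1]
softmax_probs = softmax(logits)
sigmoid_scores = [sigmoid(x) for x in logits]

print(sum(softmax_probs))   # a proper distribution over the vocabulary
print(sum(sigmoid_scores))  # not normalized across the vocabulary
```

If the goal is a probability for each candidate next word, swapping `Activation('sigmoid')` for `Activation('softmax')` may be worth trying.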

My goal is to predict the probability of a given word in its context. I have a list of sentences on which I want to train this model:

    def train_lstm_model(self, sentences):
        sentences_as_indexes = []

        # Keep only the sentences that contain no OOV words
        filtered_sentences = list(filter(lambda x: all([w in self.w2v_model.wv.vocab for w in x]), sentences))
        for sentence in filtered_sentences:
            if all([w in self.w2v_model.wv.vocab for w in sentence]):  # todo use filter
                indexes_row = []
                for word in sentence:
                    idx = self.w2v_model.wv.vocab.get(word).index
                    indexes_row.append(idx)
                sentences_as_indexes.append(indexes_row)

        # Inputs are all words but the last; the target is the last word
        X = pd.DataFrame([sentence[:-1] for sentence in sentences_as_indexes])
        y = pd.DataFrame([sentence[-1] for sentence in sentences_as_indexes])
        print(datetime.datetime.now(), ": Fitting LSTM model , size of X is ", X.shape)

        self.keras_lstm_model.fit(X, y)  # HERE the error

Here `sentences` is a list of about one million sentences (my training set). I get the following error in fit():

  

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[7,38] = -2147483648 is not in [0, 694415)
         [[Node: embedding_1/embedding_lookup = GatherV2[Taxis=DT_INT32, Tindices=DT_INT32, Tparams=DT_FLOAT, _class=["loc:@training/Adam/Assign_2"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](embedding_1/embeddings/read, embedding_1/Cast, training/Adam/gradients/embedding_1/embedding_lookup_grad/concat/axis)]]

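I thought it might be because my data contains OOV words, but as you can see, I filter the sentences to keep only those in which every word is in the vocabulary.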

What could cause this error?

Thanks!
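One possible cause worth checking, offered as an assumption rather than a confirmed diagnosis: the sentences have different lengths, so `pd.DataFrame` fills the shorter rows with NaN, and a NaN cast to int32 typically shows up as -2147483648, the value in the error message. The sketch below uses plain Python lists to find rows of uneven length and left-pad them to a common length. `find_ragged_rows` and `pad_rows` are hypothetical helper names, and using 0 as the padding value is an assumption: it collides with whichever word has index 0 unless an index is reserved for padding.

```python
def find_ragged_rows(rows):
    """Return the positions of rows that are shorter than the longest row."""
    if not rows:
        return []
    longest = max(len(row) for row in rows)
    return [i for i, row in enumerate(rows) if len(row) != longest]

def pad_rows(rows, pad_value=0, max_len=None):
    """Left-pad (and, if needed, left-truncate) every row to a common length."""
    if max_len is None:
        max_len = max((len(row) for row in rows), default=0)
    padded = []
    for row in rows:
        row = list(row)
        if len(row) > max_len:
            # Keep the most recent words when a row is too long
            row = row[len(row) - max_len:]
        padded.append([pad_value] * (max_len - len(row)) + row)
    return padded

# Toy index rows standing in for sentences_as_indexes (made-up values)
rows = [[5, 6, 7], [8, 9]]
print(find_ragged_rows(rows))  # positions of rows a DataFrame would fill with NaN
print(pad_rows(rows))          # rectangular data with no missing cells
```

Keras also provides `keras.preprocessing.sequence.pad_sequences`, which pads on the left by default and may be a more convenient way to get a rectangular integer array before calling fit().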

0 answers:

No answers