How to provide 3 inputs to a bidirectional LSTM

Time: 2019-03-15 06:51:37

Tags: python-3.x lstm bidirectional

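My project has 3 inputs (passage, question, option), and I want to get the output "correct answer" from them. Could you help me understand how to feed these inputs to the model? I tried zipping the 3 inputs together, but I am not getting the answer. My code is: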

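import gc

import numpy as np
from keras.preprocessing.sequence import pad_sequences  # assumed import; not shown in the original post


def create_train_dev_set(tokenizer, pair, is_similar, max_sequence_length, validation_split_ratio):
    # Each element of pair is a (question, article, option) tuple.
    quest = [x[0] for x in pair]
    art = [x[1] for x in pair]
    opt = [x[2] for x in pair]
    train_sequences_1 = tokenizer.texts_to_sequences(quest)
    train_sequences_2 = tokenizer.texts_to_sequences(art)
    train_sequences_3 = tokenizer.texts_to_sequences(opt)
    # Leak features: unique-token counts per input, plus the number of tokens shared by all three.
    leaks = [[len(set(x1)), len(set(x2)), len(set(x3)),
              len(set(x1).intersection(x2, x3))]
             for x1, x2, x3 in zip(train_sequences_1, train_sequences_2, train_sequences_3)]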

    train_padded_data_1 = pad_sequences(train_sequences_1, maxlen=max_sequence_length)
    train_padded_data_2 = pad_sequences(train_sequences_2, maxlen=max_sequence_length)
    train_padded_data_3 = pad_sequences(train_sequences_3, maxlen=max_sequence_length)
    train_labels = np.array(is_similar)
    leaks = np.array(leaks)

    shuffle_indices = np.random.permutation(np.arange(len(train_labels)))
    train_data_1_shuffled = train_padded_data_1[shuffle_indices]
    train_data_2_shuffled = train_padded_data_2[shuffle_indices]
    train_data_3_shuffled = train_padded_data_3[shuffle_indices]
    train_labels_shuffled = train_labels[shuffle_indices]
    leaks_shuffled = leaks[shuffle_indices]

    dev_idx = max(1, int(len(train_labels_shuffled) * validation_split_ratio))

    gc.collect()

    train_data_1, val_data_1 = train_data_1_shuffled[:-dev_idx], train_data_1_shuffled[-dev_idx:]
    train_data_2, val_data_2 = train_data_2_shuffled[:-dev_idx], train_data_2_shuffled[-dev_idx:]
    train_data_3, val_data_3 = train_data_3_shuffled[:-dev_idx], train_data_3_shuffled[-dev_idx:]
    labels_train, labels_val = train_labels_shuffled[:-dev_idx], train_labels_shuffled[-dev_idx:]
    leaks_train, leaks_val = leaks_shuffled[:-dev_idx], leaks_shuffled[-dev_idx:]

    return (train_data_1, train_data_2, train_data_3, labels_train, leaks_train,
            val_data_1, val_data_2, val_data_3, labels_val, leaks_val)


tokenizer, embedding_matrix = word_embed_meta_data(quest + art + opt, EMBEDDING_DIM)

embedding_meta_data = {
    'tokenizer': tokenizer,
    'embedding_matrix': embedding_matrix
}

pair = [(x1, x2, x3) for x1, x2, x3 in zip(quest, art, opt)]

My inputs are 3 CSV files, and the output is another CSV.
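
For reference, a common way to give a Keras model several text inputs is the functional API: one `Input` per sequence, a shared embedding and bidirectional LSTM encoder, and a concatenation before the classifier. The following is only a minimal sketch under assumptions that are not in the code above: TensorFlow's bundled Keras, hypothetical sizes (`MAX_LEN`, `VOCAB_SIZE`, `LSTM_UNITS`), a randomly initialised embedding instead of `embedding_matrix`, and a binary label that is 1 when the option is the correct answer.

```python
from tensorflow.keras.layers import (Input, Embedding, LSTM, Bidirectional,
                                     Dense, concatenate)
from tensorflow.keras.models import Model

# Hypothetical sizes, chosen only for illustration.
MAX_LEN = 20        # padded sequence length (max_sequence_length)
VOCAB_SIZE = 1000   # number of tokens known to the tokenizer, plus one
EMBEDDING_DIM = 50
LSTM_UNITS = 32
N_LEAKS = 4         # width of one row of the leak features


def build_model():
    # One input per text, plus one for the hand-made leak features.
    question_in = Input(shape=(MAX_LEN,), name="question")
    passage_in = Input(shape=(MAX_LEN,), name="passage")
    option_in = Input(shape=(MAX_LEN,), name="option")
    leaks_in = Input(shape=(N_LEAKS,), name="leaks")

    # The same embedding and BiLSTM weights encode all three texts.
    embed = Embedding(VOCAB_SIZE, EMBEDDING_DIM)
    encoder = Bidirectional(LSTM(LSTM_UNITS))
    encoded = [encoder(embed(x)) for x in (question_in, passage_in, option_in)]

    # Join the three encodings and the leak features, then classify.
    merged = concatenate(encoded + [leaks_in])
    hidden = Dense(64, activation="relu")(merged)
    output = Dense(1, activation="sigmoid")(hidden)

    model = Model(inputs=[question_in, passage_in, option_in, leaks_in],
                  outputs=output)
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model
```

With the arrays returned by `create_train_dev_set`, training would pass the inputs as a list in the same order, for example `model.fit([train_data_1, train_data_2, train_data_3, leaks_train], labels_train, validation_data=([val_data_1, val_data_2, val_data_3, leaks_val], labels_val))`.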

0 answers:

No answers