Adding metadata to an RNN model after using a shared-weight embedding layer

Time: 2018-03-16 13:04:27

Tags: tensorflow neural-network keras rnn word-embedding

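I have an embedding matrix whose weights are shared across the text posts of a thread, and after it I want to add metadata to my model. However, adding a new layer with the same function I used to initialize the model's layers gives a dimension error. Can anyone tell me how to proceed?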

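import numpy as np
from keras import backend as K
from keras.layers import (Input, Embedding, Conv1D, MaxPooling1D, Flatten,
                          Masking, Lambda, Dense, concatenate)
from keras.models import Model

# vocab_size, EMBEDDING_DIM, MAX_SEQUENCE_LENGTH, WORD_PADDING_VALUE, post_arch,
# use_number_of_posts, metadata_train and num_classes are defined elsewhere in my script.

def build_post_submodel(arch='cnn', isEmbedding = True):
    """
    Create and return a function that applies a chain of layers to an Input. Because the
    layers are created only once, calling the function on several inputs shares their
    weights (including the embedding matrix) across the different posts in a thread.
    """
    if isEmbedding:
        layers = [Embedding(vocab_size, EMBEDDING_DIM, input_length=MAX_SEQUENCE_LENGTH)]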
        if arch == 'cnn':
            layers.extend([
                Conv1D(128, 5, activation='relu'),
                MaxPooling1D(50),
                Flatten()
            ])
        elif arch == 'average':
            layers.extend([
                Masking(mask_value=WORD_PADDING_VALUE),
                # Average()
                Lambda(lambda x: K.mean(x, axis=1), output_shape=lambda s: (s[0], s[2]))
            ])
        else:
            raise ValueError('Unknown post architecture: %s' % arch)
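    else:
        # No embedding: only the convolutional stack. Conv1D expects a 3-D input
        # (batch, steps, channels), and with kernel_size=5 followed by
        # MaxPooling1D(50) it needs at least 54 steps.
        layers = [Conv1D(128, 5, activation='relu'),
                  MaxPooling1D(50),
                  Flatten()]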

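    def composed_layers(x):
        # The layer objects are created once above, so applying this function
        # to several inputs reuses (shares) their weights.
        for layer in layers:
            x = layer(x)
        return x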

    return composed_layers


post_submodel = build_post_submodel(post_arch, True)

# create an input for each post
input_tensors = []
encoded_posts = []
for i in range(use_number_of_posts):
    post_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
    input_tensors.append(post_input)
    encoded_post = post_submodel(post_input)
    encoded_posts.append(encoded_post)

# Trying to add metadata to my model
post_submodel = build_post_submodel(post_arch, False)
metadata_input = Input(shape=np.array(metadata_train).shape)
input_tensors.append(metadata_input)
encoded_post = post_submodel(metadata_input)
encoded_posts.append(encoded_post)

merged_vector = concatenate(encoded_posts, axis=-1)
preds = Dense(num_classes, activation='softmax')(merged_vector)
model = Model(input_tensors, preds)    
model.summary()
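The metadata branch is the likely source of the dimension error. As an illustration only (not an answer from this thread), the plain-Python sketch below reproduces the length arithmetic Keras uses for `Conv1D` and `MaxPooling1D` with their default `padding='valid'` (stride 1 for the convolution, stride equal to `pool_size` for the pooling), plus the shape that `np.array(metadata_train).shape` returns. It assumes the metadata is a 2-D table of samples × features; the helper names and the numbers in the demo are hypothetical. Two points follow: Keras `Input(shape=...)` expects the per-sample shape without the sample axis, and the `Conv1D(128, 5)` + `MaxPooling1D(50)` stack needs at least 54 time steps, so it cannot simply be reused for a short metadata vector.

```python
# Hypothetical helpers that mimic Keras shape arithmetic; no Keras needed.

def conv1d_valid_length(steps, kernel_size):
    """Output steps of Conv1D with padding='valid' and strides=1."""
    return steps - kernel_size + 1


def maxpool1d_valid_length(steps, pool_size):
    """Output steps of MaxPooling1D with padding='valid' and strides=pool_size."""
    if steps < pool_size:
        return 0  # the pooling window does not fit; the real layer fails to build
    return (steps - pool_size) // pool_size + 1


def conv_branch_flat_size(steps, filters=128, kernel_size=5, pool_size=50):
    """Flatten() size after Conv1D -> MaxPooling1D, or None if the shapes are invalid."""
    after_conv = conv1d_valid_length(steps, kernel_size)
    if after_conv < 1:
        return None
    after_pool = maxpool1d_valid_length(after_conv, pool_size)
    if after_pool < 1:
        return None
    return after_pool * filters


def nested_shape(data):
    """Shape of a rectangular nested list, like np.array(data).shape."""
    shape = []
    while isinstance(data, (list, tuple)):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


if __name__ == "__main__":
    # Hypothetical values: MAX_SEQUENCE_LENGTH = 1000, metadata with 4 features.
    print(conv_branch_flat_size(1000))  # 2432
    print(conv_branch_flat_size(4))     # None: too short for the conv branch
    metadata_train = [[0.1, 2.0, 3.0, 1.0], [0.0, 1.0, 0.5, 2.0]]
    print(nested_shape(metadata_train))      # (2, 4): includes the sample axis
    print(nested_shape(metadata_train)[1:])  # (4,): per-sample shape for Input
```

Under these assumptions, a common alternative is to declare the metadata as `Input(shape=(n_features,))` and pass it through a `Dense` layer (or concatenate it directly) instead of the convolutional branch.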

0 answers