How do I design a network architecture suited to a many-to-one dataset?

Time: 2018-11-29 07:17:54

Tags: tensorflow machine-learning keras deep-learning artificial-intelligence

I have a dataset that is driving me crazy. It is structured like this:

A    B
1    x
2    x
3    x
4    y
5    y
6    y
7    y
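
To make the many-to-one shape concrete, here is a minimal sketch that summarizes the toy table above. The helper name `summarize` is hypothetical, used only for illustration; it counts the distinct values in each column and the number of rows per B value.

```python
from collections import Counter

# Toy rows copied from the table in the question: (A, B) pairs.
rows = [(1, "x"), (2, "x"), (3, "x"), (4, "y"), (5, "y"), (6, "y"), (7, "y")]


def summarize(pairs):
    """Return (distinct A count, distinct B count, rows per B value)."""
    a_values = {a for a, _ in pairs}
    b_counts = Counter(b for _, b in pairs)
    return len(a_values), len(b_counts), dict(b_counts)


n_a, n_b, per_b = summarize(rows)
print(n_a, n_b, per_b)  # prints: 7 2 {'x': 3, 'y': 4}
```

Every A value occurs once, while each B value is shared by several rows, which is the many-to-one pattern the question is about.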

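I want to use this dataset to train a classifier. It has two columns, A and B. Every value in A is distinct, whereas the values in B repeat: there are 5000 distinct values of A and only 5 distinct values of B. As a result, when I try to fit this dataset with a neural network, the network seems to look only at column B, which leads to low accuracy. (Using column A alone I get 85% accuracy, but column B also contains a lot of valuable information.) My current network structure: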

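import logging

from keras.layers import (Input, Embedding, Reshape, TimeDistributed, Conv1D,
                          Dropout, Bidirectional, LSTM, Dense)
from keras.models import Model
from keras.optimizers import SGD

# ZeroMaskedEntries, AttLayer and AoA are custom layers defined elsewhere
# (their definitions are not shown in the question).

logger = logging.getLogger(__name__)  # assumed module-level logger


def build_matching_model(opts, vocab_size=0, ew=None):
    # N = sentences per input, L = tokens per sentence
    N, L = opts.max_sents, opts.max_len
    logger.info(
        "params: max_sentlen = %d, embedding dim = %s, nbfilters = %s, filter1_len = %s, dropout rate = %s" % (
            opts.max_len, opts.embedding_dim, opts.nbfilters, opts.filter1_len, opts.dropout))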
    if ew is None:
        embed = Embedding(vocab_size, opts.embedding_dim, input_length=N * L, mask_zero=True, name='embed')
    else:
        embed = Embedding(ew.shape[0], ew.shape[1], weights=[ew], input_length=N * L, mask_zero=True, name='embed')

    zme = ZeroMaskedEntries(name='maskedout')
    resh = Reshape((N, L, opts.embedding_dim), name='resh_W')
    zcnn = TimeDistributed(Conv1D(opts.nbfilters, opts.filter1_len, padding='valid'), name='zcnn')

    att = TimeDistributed(AttLayer(), name='avg_zcnn')
    dropout = Dropout(opts.dropout, name='dropout')
    bilstm = Bidirectional(LSTM(opts.hidden_units, return_sequences=True, name='lstm'))
    # softmax so that categorical_crossentropy receives class probabilities
    output_layer = Dense(units=3, activation='softmax', name='output_layer')

    in_a = Input(shape=(N * L,), dtype='int32', name='in_a')
    in_b = Input(shape=(N * L,), dtype='int32', name='in_b')
    a_maskedout, b_maskedout = zme(embed(in_a)), zme(embed(in_b))
    a_resh, b_resh = resh(a_maskedout), resh(b_maskedout)
    a_cnn, b_cnn = zcnn(dropout(a_resh)), zcnn(dropout(b_resh))  # (S W V)
    a_att, b_att = att(a_cnn), att(b_cnn)  # (S V)
    a_feat, b_feat = bilstm(a_att), bilstm(b_att)
    aoa_feat = AoA(name='AOA')([a_feat, b_feat])
    output = output_layer(aoa_feat)
    model = Model(inputs=[in_a, in_b], outputs=output)
    optimizer = SGD()
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['acc'])

    return model

I would appreciate any help with handling this kind of dataset, or any suggestions on network design. Thanks.

0 answers:

No answers yet.