DMN神经网络具有较差的验证结果-仅50%

时间:2018-07-03 12:43:11

标签: python-3.x recurrent-neural-network pytorch

我的神经网络遇到了一个问题。我正在尝试为babi数据集实现所谓的DMN(动态记忆网络,Dynamic Memory Network)。关于DMN模型的论文见:http://arxiv.org/abs/1506.07285 。另一篇介绍DMN模型的文章见:https://yerevann.github.io/2016/02/05/implementing-dynamic-memory-networks/

这是我的问题。顺便说一句,我正在使用PyTorch。

我将训练和测试数据分为几个部分,分别用于训练、测试和验证:1000条样本用于训练,500条用于测试,500条用于验证。问题在于:训练可以顺利完成,但进入验证步骤后,准确率从未超过50%。已有结果表明,在babi数据集的第一个测试集上应该能够达到100%的准确率(共有20个测试集)。训练期间我可以达到100%的准确率,但验证时只能达到50%。我想问的是,程序的哪一部分会导致这种现象?换句话说,你能告诉我为什么我总是只得到50%吗?感谢你花时间阅读。目前,我的实验仅限于第一个babi测试集。

我以为我已经弄明白了,但是我的问题又出现了。我真的不知道这是什么。这是代码的链接。如果您可以看看,我将不胜感激。 https://github.com/radiodee1/awesome-chatbot/blob/master/model/babi_iv.py

下面包含一些代码。

class WrapMemRNN(nn.Module):
    """Dynamic Memory Network (DMN) wrapper used for the bAbI experiments.

    The forward pass chains four stages:

    1. ``new_input_module`` encodes every input sequence and every question.
    2. ``new_episodic_module`` runs ``memory_hops`` attention/memory updates
       per input sequence.
    3. ``new_answer_module_simple`` maps the collected memories to answer
       scores.

    Intermediate results are stored on the instance (``inp_c_seq``, ``q_q``,
    ``q_q_seq``, ``last_mem``) so the stages can be called one after another.

    NOTE(review): ``Encoder``, ``MemRNN``, ``EpisodicAttn``, ``AnswerModule``
    and ``hparams`` are defined elsewhere in the file; the tensor shapes
    mentioned in the comments below are inferred from the indexing done in
    this class and should be confirmed against those definitions.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, n_layers, dropout=0.3,
                 do_babi=True, bad_token_lst=None, freeze_embedding=False,
                 embedding=None, print_to_screen=False):
        """Build the sub-modules and initialise their parameters.

        Args:
            vocab_size: forwarded to both encoders and the answer module.
            embed_dim: forwarded to both encoders.
            hidden_size: hidden width shared by all sub-modules; also sets the
                range of the uniform initialisation in ``reset_parameters``.
            n_layers: forwarded to both encoders.
            dropout: dropout for the input encoder and the answer module.
            do_babi: stored on the instance; not read inside this class.
            bad_token_lst: stored on the instance; ``None`` (the default)
                becomes a fresh empty list so instances never share one.
            freeze_embedding: when true, encoder embeddings are frozen.
            embedding: forwarded to both encoders; a non-``None`` value also
                triggers freezing of the encoder embeddings.
            print_to_screen: enables the debug ``print`` calls.
        """
        super(WrapMemRNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.do_babi = do_babi
        self.print_to_screen = print_to_screen
        # A literal ``[]`` default would be shared by every instance.
        self.bad_token_lst = [] if bad_token_lst is None else bad_token_lst
        self.embedding = embedding
        self.freeze_embedding = freeze_embedding
        self.teacher_forcing_ratio = hparams['teacher_forcing_ratio']

        # Multiplying by zero switches dropout off for the question encoder,
        # both memory GRUs and the attention module.
        gru_dropout = dropout * 0

        self.model_1_enc = Encoder(vocab_size, embed_dim, hidden_size, n_layers, dropout=dropout, embedding=embedding, bidirectional=False)
        self.model_2_enc = Encoder(vocab_size, embed_dim, hidden_size, n_layers, dropout=gru_dropout, embedding=embedding, bidirectional=False)
        self.model_3_mem_a = MemRNN(hidden_size, dropout=gru_dropout)
        self.model_3_mem_b = MemRNN(hidden_size, dropout=gru_dropout)
        self.model_4_att = EpisodicAttn(hidden_size, dropout=gru_dropout)
        self.model_5_ans = AnswerModule(vocab_size, hidden_size, dropout=dropout)

        self.input_var = None  # for input
        self.q_var = None  # for question
        self.answer_var = None  # for answer
        self.q_q = None  # encoding of the LAST question (kept for compatibility)
        self.q_q_seq = None  # one question encoding per entry of question_variable
        self.inp_c = None  # encoder output of the last input sequence
        self.inp_c_seq = None  # encoder output of every input sequence
        self.all_mem = None
        self.last_mem = None  # output of mem unit
        self.prediction = None  # final single word prediction
        self.memory_hops = hparams['babi_memory_hops']

        self.reset_parameters()

        if self.freeze_embedding or self.embedding is not None:
            self.new_freeze_embedding()

    def reset_parameters(self):
        """Re-initialise every parameter of the model in place.

        Tensors with more than one dimension get Xavier-normal values; all
        others (e.g. biases) get uniform values in ``[-stdv, stdv]`` with
        ``stdv = 1 / sqrt(hidden_size)``.  The final distributions match the
        previous implementation, which drew the uniform values for matrices
        too and then immediately overwrote them.

        NOTE(review): this walks over *all* parameters, so any values the
        sub-modules loaded from ``embedding`` in their constructors would be
        overwritten here as well -- confirm that this is intended.
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            if weight.dim() > 1:
                init.xavier_normal_(weight)
            else:
                init.uniform_(weight, -stdv, stdv)

    def forward(self, input_variable, question_variable, target_variable, criterion=None):
        """Run input, episodic and answer stages.

        ``target_variable`` and ``criterion`` are accepted for interface
        compatibility but are not used here.

        Returns:
            A 4-tuple ``(outputs, None, ans, None)`` where ``outputs`` and
            ``ans`` come from ``new_answer_module_simple``.
        """
        self.new_input_module(input_variable, question_variable)
        self.new_episodic_module()
        outputs, ans = self.new_answer_module_simple()

        return outputs, None, ans, None

    def new_freeze_embedding(self):
        """Stop gradient updates for the embedding weights of both encoders."""
        self.model_1_enc.embed.weight.requires_grad = False
        self.model_2_enc.embed.weight.requires_grad = False
        print('freeze embedding')

    def new_input_module(self, input_variable, question_variable):
        """Encode all input sequences and all questions.

        Sets ``inp_c_seq`` (one encoder result per input sequence), ``inp_c``
        (the last of them), ``q_q_seq`` (one encoding per question, taken at
        the last position along dim 1) and ``q_q`` (the last question's
        encoding, as before).

        Raises:
            ValueError: if either argument yields no sequences.  Previously
                this surfaced as an ``IndexError`` / ``UnboundLocalError``.
        """
        input_states = []
        for seq in input_variable:
            seq = self.prune_tensor(seq, 2)
            _, hidden = self.model_1_enc(seq, None)
            input_states.append(hidden)

        if not input_states:
            raise ValueError('input_variable must contain at least one sequence')

        question_states = []
        for seq in question_variable:
            seq = self.prune_tensor(seq, 2)
            _, hidden = self.model_2_enc(seq, None)
            # Keep only the final position of each encoded question.
            question_states.append(hidden[:, -1, :])

        if not question_states:
            raise ValueError('question_variable must contain at least one sequence')

        self.inp_c_seq = input_states
        self.inp_c = input_states[-1]
        self.q_q_seq = question_states
        self.q_q = question_states[-1]

        return

    def new_episodic_module(self):
        """Compute the final memory for every encoded input sequence.

        For each sequence the memory starts as a copy of its question
        encoding and is updated ``memory_hops`` times: attention gates ->
        gated episode -> ``model_3_mem_a`` memory update.  The final memories
        are concatenated along dim 1 into ``self.last_mem``.

        Each sequence is paired with its own question when ``q_q_seq`` holds
        exactly one question per sequence.  The earlier code always used the
        last question for every sequence, so in a batch every example except
        the last was answered against the wrong question.  When the counts do
        not match, the old behaviour (``self.q_q`` for all) is kept.
        """
        sequences = self.inp_c_seq
        # getattr: instances restored from older pickles may lack q_q_seq.
        questions = getattr(self, 'q_q_seq', None)
        per_sequence = questions is not None and len(questions) == len(sequences)

        final_memories = []
        for idx, seq in enumerate(sequences):
            assert len(seq.size()) == 3

            question = questions[idx] if per_sequence else self.q_q
            memories = [question.clone()]

            for hop in range(self.memory_hops):
                gates = self.new_attention_step(seq, None, memories[hop], question)

                if self.print_to_screen and not self.training:
                    print(gates,'x -- after', len(gates), seq.size())

                episode, _ = self.new_episode_small_step(seq, gates.permute(1,0), None)

                # Last entry along the final axis of the single episode slot.
                episode_vec = episode[:, 0, -1]

                _, new_memory = self.model_3_mem_a(episode_vec.unsqueeze(0), self.prune_tensor(memories[hop], 3))
                memories.append(new_memory)

            final_memories.append(memories[self.memory_hops])

        self.last_mem = torch.cat(final_memories, dim=1)

        return None

    def new_episode_small_step(self, ct, g, prev_h):
        """Run ``model_3_mem_b`` over one sequence and gate its last output.

        Args:
            ct: 3-D tensor; only ``ct[0]`` is read, one row per step.
            g: attention gates, indexed by step after ``squeeze(0)``.
            prev_h: optional initial hidden state; pruned to 3-D before use
                (previously the pruned value was computed but never passed
                on).  ``None`` lets ``model_3_mem_b`` start from scratch.

        Returns:
            ``(h, gru)``: ``h`` is the gated output of the final step with a
            new axis inserted at dim 1; ``gru`` is the final step's second
            return value of ``model_3_mem_b`` after ``squeeze(0)`` and
            ``permute(1, 0)``.

        Raises:
            ValueError: if ``ct`` has no steps along dim 1.
        """
        assert len(ct.size()) == 3
        _, sen, _ = ct.size()
        if sen == 0:
            raise ValueError('ct must contain at least one step')

        if prev_h is not None:
            prev_h = self.prune_tensor(prev_h, 3)

        hidden = prev_h
        for step in range(sen):
            c = ct[0, step, :].unsqueeze(0)

            hidden, gru = self.model_3_mem_b(c, hidden)

            # Kept inside the loop so that repeated squeezing behaves exactly
            # as before for gate tensors with several leading size-1 axes.
            g = g.squeeze(0)
            gru = gru.squeeze(0).permute(1,0)

            # Only the value from the final step survives the loop.
            h = torch.mul(g[step], gru)

        return h.unsqueeze(1), gru

    def new_attention_step(self, ct, prev_g, mem, q_q):
        """Compute one sigmoid attention gate row per step of ``ct``.

        For every step the feature row is the concatenation (along dim 1) of
        ``c``, ``mem``, ``q_q``, ``c*q_q``, ``c*mem``, ``|c-q_q|`` and
        ``|c-mem|``; the rows are stacked along dim 0, passed through
        ``model_4_att`` and squashed with a sigmoid.

        ``prev_g`` is accepted for interface compatibility and is not used.
        """
        q_q = self.prune_tensor(q_q,3)
        mem = self.prune_tensor(mem,3)

        assert len(ct.size()) == 3
        _, sen, _ = ct.size()

        # These two do not depend on the step, so build them once.
        mem_row = mem.squeeze(0)
        q_row = q_q.squeeze(0)

        rows = []
        for step in range(sen):
            c = ct[0, step, :]

            parts = [
                c.unsqueeze(0),
                mem_row,
                q_row,
                (c * q_q).squeeze(0),
                (c * mem).squeeze(0),
                (torch.abs(c - q_q)).squeeze(0),
                (torch.abs(c - mem)).squeeze(0)
            ]
            rows.append(torch.cat(parts, dim=1))

        att = torch.cat(rows, dim=0)
        z = self.model_4_att(att)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(z)

    def prune_tensor(self, input, size):
        """Nudge ``input`` towards ``size`` dimensions by at most one axis.

        Adds a leading axis when there are too few dimensions; otherwise
        applies ``squeeze(0)`` when there are too many (which only removes
        the leading axis if its length is 1).  The parameter name ``input``
        shadows the builtin but is kept so keyword callers keep working.
        """
        if len(input.size()) < size:
            input = input.unsqueeze(0)
        if len(input.size()) > size:
            input = input.squeeze(0)
        return input

    def new_answer_module_simple(self):
        """Score answers from ``self.last_mem`` with ``model_5_ans``.

        Returns:
            ``([None], ansx)`` where ``ansx`` is the raw output of the answer
            module.  With ``print_to_screen`` set, the scores and the argmax
            along dim 0 of every column are printed for debugging.
        """
        ansx = self.model_5_ans(self.last_mem, None)

        if self.print_to_screen:
            print(ansx, 'ansx printed')
            print(ansx.size(), 'ansx')
            vocab, sen = ansx.size()
            aa = torch.argmax(ansx, dim=0)
            print(aa.size(),'aa')
            for i in range(sen):
                zz = aa[i]
                z = ansx[:, i]
                a = torch.argmax(z, dim=0)
                print(a.item(), zz.item())
            print('----')

        return [None], ansx

0 个答案:

没有答案