Writing code to test a seq2seq prediction model

Time: 2017-08-16 06:02:27

Tags: python machine-learning deep-learning classification keras

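I have written training code for language translation. Now I'm running into problems when testing the classifier. It should only require shaping the X and Y values of the test input correctly, but I'm a bit confused.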

Here is the code that trains the model:

def train_seq2seq(self):
    print("Input sequence read, starting training")
    s2s = seq2seq(self.vocab_size + 3, self.maxlen + 2,
                  self.vocab_size + 3)
    self.model = s2s.seq2seq_plain()
    # For testing, train for 10 epochs instead of 10000
    for e in range(10):
        print("epoch %d \n" % e)
        for ind, (X, Y) in enumerate(self.proproces.gen_batch()):
            loss, acc = self.model.train_on_batch(X, Y)  #, batch_size=64, nb_epoch=1)
            #print("Loss is %f, accuracy is %f " % (loss, acc), end='\r')
            # Every 10 batches, test on the first sentence of the current batch
            if ind % 10 == 0:
                testX = X[0, :].reshape(1, self.maxlen + 2)  # keep a batch axis of size 1
                testY = Y[0]
                pred = self.model.predict(testX, batch_size=1)
                self.decode(testX, pred)
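The `reshape(1, self.maxlen + 2)` call in the training loop is there because Keras `predict` expects the batch axis first, so a single sentence has to be shaped `(1, timesteps)` rather than `(timesteps,)`. Below is a small NumPy sketch of equivalent ways to take one sample while keeping that axis. It assumes hypothetical values: maxlen = 6 (consistent with the `(1, 6)` printed in `encode()`) and a batch of 64 random token ids standing in for real `gen_batch()` output.

```python
import numpy as np

# Hypothetical values: maxlen = 6 matches the (1, 6) printed in encode(),
# and 64 matches the commented-out batch_size in train_on_batch.
maxlen = 6
batch_size = 64
timesteps = maxlen + 2  # the model is built with maxlen + 2 time steps

# Random stand-in for one batch of token ids from gen_batch() (illustration only).
rng = np.random.default_rng(0)
X = rng.integers(0, 100, size=(batch_size, timesteps))

# Three equivalent ways to take sample 0 while keeping a batch axis of size 1.
a = X[0, :].reshape(1, timesteps)  # what train_seq2seq() does
b = X[0:1]                         # a slice keeps the leading axis
c = X[np.newaxis, 0]               # index the row, then add the axis back

print(a.shape, b.shape, c.shape)   # (1, 8) (1, 8) (1, 8)
```

Any input passed to `predict` or `predict_on_batch` at test time needs this same `(1, maxlen + 2)` layout.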

Here is the test code I'm having trouble with:

from keras.preprocessing.sequence import pad_sequences

def encode(self):
    #Encodes input sentence into fixed length vector
    #print("Enter sentence in hindi")
    inp = raw_input("Please enter the sentence\n").decode("utf-8")
    tokens = inp.split()
    seq = []
    for token in tokens:
        if token in self.proproces.vocab_tar:
            seq.append(self.proproces.vocab_tar[token])
        else:
            token = "UNK"
            seq.append(self.proproces.vocab_tar[token])
    #seq = map(lambda x:self.proproces.vocab_hind[x], tokens)
    # Normalize seq to maxlen
    X = []
    X.append(seq)
    print(X) #[[400, 23, 400]]
    temp = pad_sequences(X, maxlen=self.maxlen)
    print(temp.shape) #(1, 6)
    temp[0:len(seq)] = seq
    # print(len(temp))
    # temp = np.asarray(temp).reshape(128,)
    # print(temp.shape)
    prob = self.model.predict_on_batch(temp)#, batch_size=1, verbose=0)
    translated = self.decode(prob)
    print("Translated is", translated)
    print("Probabilities are", prob)
    print("Shape of prob tensor is",prob.shape)
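In `train_seq2seq()` the model is built with `self.maxlen + 2` time steps and receives inputs of shape `(1, self.maxlen + 2)`. `encode()`, however, pads to `self.maxlen`, which produces the `(1, 6)` printed above. In addition, `temp[0:len(seq)] = seq` indexes rows of the 2-D array rather than positions within the sentence, so NumPy cannot broadcast a 3-token list into that `(1, 6)` slice. `pad_sequences` has already written the ids into `temp` at that point anyway. The following is a minimal sketch of preparing one sentence. It assumes the model expects exactly `maxlen + 2` steps and that training batches were padded like the Keras `pad_sequences` defaults (`padding='pre'`, `truncating='pre'`, pad value 0). `encode_sentence`, the vocabulary, and the tokens are hypothetical.

```python
import numpy as np

def encode_sentence(sentence, vocab, length, unk="UNK"):
    """Hypothetical helper: map tokens to ids (unknown words -> vocab[unk]),
    then pre-truncate and pre-pad to `length`, mirroring the Keras
    pad_sequences defaults. Returns an int32 array of shape (1, length)."""
    ids = [vocab.get(tok, vocab[unk]) for tok in sentence.split()]
    ids = ids[-length:]                          # keep the last `length` ids (truncating='pre')
    out = np.zeros((1, length), dtype="int32")   # a batch containing one sentence
    if ids:
        out[0, -len(ids):] = ids                 # zeros end up in front (padding='pre')
    return out

# Hypothetical vocabulary; 400 plays the role of the UNK id seen in the question.
vocab = {"UNK": 400, "ghar": 23}
maxlen = 6
x = encode_sentence("mera ghar kahan", vocab, maxlen + 2)
print(x)        # [[  0   0   0   0   0 400  23 400]]
print(x.shape)  # (1, 8)
```

The resulting `x` has the same layout as `testX` in the training loop, so it can be passed directly to `self.model.predict_on_batch`.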

I'm following this tutorial: https://github.com/shashankg7/Seq2Seq/blob/master/seq2seq/seq2seq.py

I'm confused about how to shape X and Y. Any guidance would be much appreciated.
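The question does not show what `seq2seq_plain()` outputs, so the `prob.shape` printed in `encode()` is the first thing to check. Suppose it is `(1, timesteps, vocab_size + 3)`, which is common for Keras seq2seq models with a softmax at every step. Then `Y` in training would typically hold one-hot targets of shape `(batch, timesteps, vocab_size + 3)`, and a prediction can be turned back into words by greedy argmax decoding. The sketch below works under that assumption. `greedy_decode`, the inverted vocabulary `id_to_token`, and `pad_id=0` are hypothetical names, not part of the tutorial.

```python
import numpy as np

def greedy_decode(prob, id_to_token, pad_id=0):
    """Hypothetical helper: turn a (1, timesteps, vocab) probability tensor
    into tokens by taking the most likely id at each step and skipping padding."""
    if prob.ndim != 3 or prob.shape[0] != 1:
        raise ValueError("expected shape (1, timesteps, vocab), got %r" % (prob.shape,))
    ids = prob[0].argmax(axis=-1)  # one id per time step, shape (timesteps,)
    return [id_to_token.get(int(i), "UNK") for i in ids if i != pad_id]

# Toy tensor: 4 time steps over a vocabulary of 5 ids.
prob = np.zeros((1, 4, 5))
prob[0, 0, 0] = 1.0  # padding step, skipped
prob[0, 1, 2] = 1.0
prob[0, 2, 3] = 1.0
prob[0, 3, 4] = 1.0  # id 4 is missing from the toy vocabulary
tokens = greedy_decode(prob, {2: "my", 3: "house"})
print(tokens)  # ['my', 'house', 'UNK']
```

Greedy argmax is the simplest decoding choice. If the printed shape turns out to be different (for example, 2-D), the model's output layer needs to be inspected before X and Y can be aligned with it.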

0 answers:

No answers