How can I use my own input with PyTorch's seq2seq translation tutorial?

Time: 2019-04-08 02:51:09

Tags: machine-learning nlp pytorch data-science seq2seq

I am following the guide here.

This is the model code so far:

import re
import unicodedata

SOS_token = 0
EOS_token = 1


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s


def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines (the with-block closes the file)
    with open('Scribe/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as f:
        lines = f.read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


MAX_LENGTH = 5000

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

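The difference between what I am doing and the guide is that I am trying to pass in my input language as a list of strings instead of reading it from a file: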

pairs=['string one goes like this', 'string two goes like this'] 
input_lang = Lang(pairs[0][0])
output_lang = Lang(pairs[1][1]) 

However, when I try to count the words in the strings with input_lang.n_words, I always get 2.

Am I missing something when calling the Lang class?

Update:
I ran

language = Lang('english')

for sentence in pairs:
    language.addSentence(sentence)

print(language.n_words)

That gave me the number of words in pairs.
However, it does not give me input_lang and output_lang the way the guide does:

for pair in pairs:
    input_lang.addSentence(pair[0])
    output_lang.addSentence(pair[1])

1 answer:

Answer 0: (score: 0)

First of all, you are initialising the Lang objects with pairs[0][0] and pairs[1][1]. Because pairs is a flat list of strings, the second index selects a single character, so this is the same as calling Lang('s') and Lang('t').
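A quick check of that indexing, using only the pairs list from the question (the variable names first_name and second_name are made up for this illustration):

```python
pairs = ['string one goes like this', 'string two goes like this']

# pairs is a flat list of strings, so the second index picks one character.
first_name = pairs[0][0]   # first character of the first string
second_name = pairs[1][1]  # second character of the second string

print(repr(first_name), repr(second_name))  # prints: 's' 't'
```

These single characters only end up as the name attribute of each Lang object; no words are added to the vocabulary.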

A Lang object is meant to store information about one language, so I would expect you to initialise it only once, for example with Lang('english'), and then add the sentences from your dataset to it with the Lang.addSentence method.

Right now you aren't loading your dataset into the Lang object at all, so input_lang.n_words is still the initial value it gets when the object is created: self.n_words = 2  # Count SOS and EOS
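A minimal sketch of that behaviour, using a trimmed copy of the Lang class from the question (word2count is left out here for brevity, which is a simplification of mine and not part of the guide). n_words starts at 2 and only grows once addSentence is called:

```python
class Lang:
    """Trimmed copy of the question's Lang class (word2count omitted)."""

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        # Only unseen words get a new index and increase n_words.
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1


lang = Lang('english')
count_before = lang.n_words  # nothing added yet, only SOS and EOS

lang.addSentence('string one goes like this')
count_after = lang.n_words  # 2 special tokens + 5 distinct words

print(count_before, count_after)  # prints: 2 7
```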

The code in your question does not do what you intend, but I think what you want is the following:

language = Lang('english')

for sentence in pairs:
    language.addSentence(sentence)

print(language.n_words)
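If the goal is to get separate input_lang and output_lang objects as in the guide, one possible sketch follows. It assumes pairs is a list of [input_sentence, output_sentence] pairs rather than a flat list of strings. The helper name prepare_data_from_pairs and the two sample pairs are made up for illustration and do not come from the guide. The Lang class is a trimmed copy of the one in the question.

```python
class Lang:
    """Trimmed copy of the question's Lang class (word2count omitted)."""

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1


def prepare_data_from_pairs(pairs, input_name, output_name):
    """Hypothetical helper: build both vocabularies from in-memory pairs."""
    input_lang = Lang(input_name)
    output_lang = Lang(output_name)
    for source, target in pairs:
        input_lang.addSentence(source)   # first element feeds the input side
        output_lang.addSentence(target)  # second element feeds the output side
    return input_lang, output_lang, pairs


# Made-up sample data: each entry is [input_sentence, output_sentence].
pairs = [
    ['je suis froid .', 'i am cold .'],
    ['il est grand .', 'he is tall .'],
]

input_lang, output_lang, pairs = prepare_data_from_pairs(pairs, 'fra', 'eng')
print(input_lang.name, input_lang.n_words)    # prints: fra 9
print(output_lang.name, output_lang.n_words)  # prints: eng 9
```

This sketch skips normalizeString and filterPairs. Note that filterPair in the question keeps only pairs whose second sentence starts with one of eng_prefixes, so custom data passed through it may be filtered down to nothing.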