这个LSTM有什么问题

时间:2018-07-23 09:02:21

标签: python machine-learning lstm

最近,我对机器学习算法感兴趣。我想使用LSTM生成文本。因此,请阅读a tutorial并借助以下方案,我在Python中实现了LSTM(或至少我认为LSTM是什么)。 LSTM scheme 这是我写的:

from math import log, tanh, exp
from string import printable as prin

inp = open('/path/to/training/text', 'r').read()
out = open('/path/to/output/file', 'w+')

out.seek(0)
out.truncate()

def encode(x):
    a = [0]*len(prin)
    a[prin.index(x)] = 1
    return a

def decode(x):
    s = ''
    i = 0
   while i<len(x):
        if x[i]:
            return prin[i]

def sigmoid(x):
    return 1/(1+exp(-x))

def tanhh(x):
    if x==1:
        return 0
    return 0.5*log((1+x)/(1-x))

class LSTM:
    def __init__(s, size):
        s.size = size
        s.out1 = [1]*size
        s.out2 = [1]*size

    def run(s, inp):
        if not len(inp)==s.size:
            raise ValueError('Size of list should be exactly '+str(s.size))

        sig = list(map(sigmoid, inp+s.out2))
        res1 = list(map(lambda x,y: x*y, sig, s.out1))
        res2 = list(map(lambda x,y: x*y, sig, map(tanh, s.out2)))
        res3 = list(map(lambda x,y: x+y, res1, res2))
        #list res3 -> out1
        s.out1 = res3

        res4 = list(map(lambda x,y: x*y, sig, map(tanh, res3)))
        #list res4 -> return
        #list res4 -> out2
        s.out2 = res4
        return list(map(tanhh, res4))

l = LSTM( len(prin) )

line = 20

c = 0
for i in inp:
    try:
        r = decode( l.run( encode(i) ) )
        print(r)
        out.write(r)
        if c == line:
            out.write('\n')
            c = 0
        c += 1
    except KeyboardInterrupt:
        out.close()
        break

我使用了 https://norvig.com/big.txt的Arthur Conan Doyle爵士撰写的“福尔摩斯历险记”作为培训文字。当我运行它时,它使用了我的Core-i5的15%左右,因此至少意味着它需要进行一些计算。但是它不会生成任何接近文本的内容,只会给出“ 0”。这是out.txt中的示例:

00000000000000000000
00000000000000000000
00000000000000000000
00000000000000000000

此程序有什么问题?也许我写的只是一个大型工作程序的一小部分?还是所有代码都是完全错误的?

0 个答案:

没有答案