如何将这段代码(Theano)转换为简单的 Python 代码行?
# Scan lstm_Step over `inputs` (tap 0 only), seeding the first two outputs
# with h0 and c0; the third output is given no initial state (None).
# The scan result is indexed with [0] and unpacked into three sequences;
# the second one is discarded via `_`.
h_vals, _, y_vals = theano.scan(
    fn=lstm_Step,
    sequences=[{"input": inputs, "taps": [0]}],
    outputs_info=[h0, c0, None],
    non_sequences=[Whx, Whh, Wcx, Wch, Wyh, bh, bc, by],
    strict=True,
)[0]
答案 0 :(得分:0)
下面是一个例子,用来说明我的意思:
import theano
import theano.tensor as tt
def add_multiply(a, b, k):
    """Return the pair ``(a + b + k, a * b * k)``.

    Only ``+`` and ``*`` are applied to the arguments, so the function
    works for any operands that support those operators.
    """
    return a + b + k, a * b * k
def python_main():
    """Iterate ``add_multiply`` five times in plain Python.

    Starts from ``x = 1``, ``y = 2`` with a constant ``k = 1``; each step
    feeds the previous ``(x, y)`` back into ``add_multiply``.

    Returns:
        A list of five ``(x, y, k)`` tuples, one per step.
    """
    x = 1
    y = 2
    k = 1
    tuples = []
    # The loop index is not used; only the number of iterations matters.
    for _ in range(5):
        x, y = add_multiply(x, y, k)
        tuples.append((x, y, k))
    return tuples
def theano_main():
    """Run the same five-step recurrence as ``python_main`` via ``theano.scan``.

    ``x`` and ``y`` are uint32 constants used as the initial recurrent state,
    and ``k`` is a symbolic uint32 scalar passed to every step as a
    non-sequence. The compiled function also returns the gradient of the
    summed scan outputs with respect to ``k``, which is ignored here.

    Returns:
        A list of ``(x, y, 1)`` tuples, one per scan step.
    """
    x = tt.constant(1, dtype='uint32')
    y = tt.constant(2, dtype='uint32')
    k = tt.scalar(dtype='uint32')
    # Each step receives the previous (x, y) outputs followed by k.
    outputs, _ = theano.scan(add_multiply, outputs_info=[x, y], non_sequences=[k], n_steps=5)
    g = theano.grad(tt.sum(outputs), k)
    f = theano.function(inputs=[k], outputs=outputs + [g])
    # Single source for the value of k, used for both the call and the tuples.
    k_value = 1
    xvs, yvs, _ = f(k_value)
    return [(xv, yv, k_value) for xv, yv in zip(xvs, yvs)]
# Print both results for comparison. A single pre-formatted argument keeps the
# output identical whether print is a statement (Python 2) or a function
# (Python 3).
print('Python: %s' % (python_main(),))
print('Theano: %s' % (theano_main(),))
正如你所说,@Nurzhan,我需要了解这个库的作用,特别是 theano.scan 的含义。