I have run into a small problem when combining random numbers with scan.
Here is a small example of what I am trying to do.
import theano as th
import numpy as np
from theano import tensor as T
stream=th.tensor.shared_randomstreams.RandomStreams()
avg = T.vector()
initial_values = np.array([1,2,3,4,5], dtype=th.config.floatX)
initials = th.shared(initial_values)
def get_output(prev_rand):
    rand = stream.normal(size=prev_rand.shape, avg=prev_rand.mean())
    random_fn = th.function([], rand)
    random_numbers = random_fn()
    return random_numbers
result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)
f = th.function([], result)
print f()
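This code is supposed to do the following:
- Start with an array, in this case [1,2,3,4,5].
- Generate random numbers sampled from a normal distribution whose mean is the mean of the previous output (or of the initial values). In this case, the mean for the first step is 3.
- Say the sampled numbers are [2,3,3.5,4,5]; the new mean is then 3.5.
- Repeat the steps above 10 times.
Instead, I get the following error output: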
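For reference, the intended recurrence can be sketched in plain Python without Theano. This is only an illustration of the steps listed above, not the Theano solution: the function name `simulate_chain`, the standard deviation of 1.0, and the seed are assumptions made for the example.

```python
import random


def simulate_chain(initial_values, n_steps=10, std=1.0, seed=None):
    """Repeatedly resample a vector from a normal distribution whose
    mean is the mean of the previous vector (hypothetical helper)."""
    rng = random.Random(seed)
    current = list(initial_values)
    history = []
    for _ in range(n_steps):
        # Mean of the previous output (or of the initial values on step 1)
        mean = sum(current) / float(len(current))
        # Draw one new value per element, all centred on that mean
        current = [rng.gauss(mean, std) for _ in current]
        history.append(current)
    return history


samples = simulate_chain([1, 2, 3, 4, 5], n_steps=10, seed=0)
print(len(samples), len(samples[0]))
```

Each row of `samples` corresponds to one scan step, so the result has the same 10-by-5 layout that the scan is expected to return.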
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Anaconda\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 585, in runfile
execfile(filename, namespace)
File "C:/Users/Main/Documents/Python Scripts/untitled13.py", line 24, in <module>
result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)
File "C:\Anaconda\lib\site-packages\theano\scan_module\scan.py", line 737, in scan
condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
File "C:/Users/Main/Documents/Python Scripts/untitled13.py", line 21, in get_output
random_numbers = th.function([], rand, givens={avg:rand})
File "C:\Anaconda\lib\site-packages\theano\compile\function.py", line 265, in function
profile=profile)
File "C:\Anaconda\lib\site-packages\theano\compile\pfunc.py", line 511, in pfunc
on_unused_input=on_unused_input)
File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 1545, in orig_function
on_unused_input=on_unused_input).create(
File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 1224, in __init__
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 141, in std_fgraph
fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs)
File "C:\Anaconda\lib\site-packages\theano\gof\fg.py", line 135, in __init__
self.__import_r__(outputs, reason="init")
File "C:\Anaconda\lib\site-packages\theano\gof\fg.py", line 257, in __import_r__
self.__import__(apply_node, reason=reason)
File "C:\Anaconda\lib\site-packages\theano\gof\fg.py", line 353, in __import__
detailed_err_msg)
theano.gof.fg.MissingInputError: A variable that is an input to the graph was neither provided as an input to the function nor given a value. A chain of variables leading from this input to an output is [<TensorType(float32, vector)>, Shape.0, Elemwise{Cast{int32}}.0, RandomFunction{normal}.1]. This chain may not be unique
Backtrace when the variable is created:
File "C:\Anaconda\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 585, in runfile
execfile(filename, namespace)
File "C:/Users/Main/Documents/Python Scripts/untitled13.py", line 24, in <module>
result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)
File "C:\Anaconda\lib\site-packages\theano\scan_module\scan.py", line 597, in scan
arg = safe_new(init_out['initial'])
File "C:\Anaconda\lib\site-packages\theano\scan_module\scan_utils.py", line 75, in safe_new
nw_x = x.type()
I am probably missing something simple and obvious here again.
Any help is much appreciated, thanks!
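Answer 0 (score: 0)
You need to pass the updates dictionary returned by scan to your function, so that the state of the random stream advances between steps: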
f=th.function([initials], result, updates=updates)
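Also, a shared variable cannot be used as an input to a function, so initials has to be an ordinary symbolic variable if it is listed in the inputs.
The step function should also return the symbolic expression itself rather than compiling a function inside it. You can achieve what you want like this, for example: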
import theano as th
import numpy as np
from theano import tensor as T

stream = th.tensor.shared_randomstreams.RandomStreams()
# Symbolic input with dtype floatX (not a shared variable), so it can be a function input
initials = T.vector()

def get_output(prev_rand):
    # Return the symbolic expression; do not compile a function inside the step
    return stream.normal(size=prev_rand.shape, avg=prev_rand.mean())

result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)
# Pass the updates from scan so the random stream state is updated
f = th.function([initials], result, updates=updates)

initial_values = np.array([1,2,3,4,5], dtype=th.config.floatX)
print(f(initial_values))