Random number error in scan: theano.gof.fg.MissingInputError

Date: 2015-01-03 16:27:08

Tags: python theano

I have a small problem when using random numbers together with scan.

Here is a small example in which I try to isolate the problem.

import theano as th
import numpy as np
from theano import tensor as T

stream=th.tensor.shared_randomstreams.RandomStreams()

avg = T.vector()

initial_values = np.array([1,2,3,4,5], dtype=th.config.floatX)
initials = th.shared(initial_values)

def get_output(prev_rand):
    rand = stream.normal(size=prev_rand.shape, avg=prev_rand.mean())
    random_fn = th.function([], rand)
    random_numbers = random_fn()
    return random_numbers

result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)

f = th.function([], result)

print f()

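What this code should do is the following:
- Start with an array, in this case [1,2,3,4,5].
- Generate random numbers sampled from a normal distribution whose mean is the mean of the previous output (or of the initial values). In this case the mean for the first step is 3.
- Say the sampled numbers are [2,3,3.5,4,5]; the new mean is then 3.5.
- Repeat the steps above 10 times.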
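For reference, the recurrence described above can be written in plain NumPy without Theano, which gives something to compare the scan output against. This is a sketch under assumptions: the function name `simulate` is made up for illustration, and the standard deviation of 1.0 is assumed to match the default of `stream.normal`.

```python
import numpy as np

def simulate(initial, n_steps=10, std=1.0, seed=None):
    """Plain NumPy reference for the intended recurrence (no Theano).

    Each step draws a vector with the same shape as the previous output
    from a normal distribution centred on that output's mean.
    """
    rng = np.random.RandomState(seed)
    prev = np.asarray(initial, dtype=float)
    steps = []
    for _ in range(n_steps):
        # std=1.0 is an assumption about stream.normal's default
        prev = rng.normal(loc=prev.mean(), scale=std, size=prev.shape)
        steps.append(prev)
    # One row per step, excluding the initial values, like th.scan's result
    return np.array(steps)

result = simulate([1, 2, 3, 4, 5], n_steps=10, seed=0)
print(result.shape)  # (10, 5)
```

Fixing the seed makes the run repeatable, which helps when checking the shape and rough behaviour of the Theano version.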

Instead, I get the following error output:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Anaconda\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 585, in runfile
    execfile(filename, namespace)
  File "C:/Users/Main/Documents/Python Scripts/untitled13.py", line 24, in <module>
    result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)
  File "C:\Anaconda\lib\site-packages\theano\scan_module\scan.py", line 737, in scan
    condition, outputs, updates = scan_utils.get_updates_and_outputs(fn(*args))
  File "C:/Users/Main/Documents/Python Scripts/untitled13.py", line 21, in get_output
    random_numbers = th.function([], rand, givens={avg:rand})
  File "C:\Anaconda\lib\site-packages\theano\compile\function.py", line 265, in function
    profile=profile)
  File "C:\Anaconda\lib\site-packages\theano\compile\pfunc.py", line 511, in pfunc
    on_unused_input=on_unused_input)
  File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 1545, in orig_function
    on_unused_input=on_unused_input).create(
  File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 1224, in __init__
    fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
  File "C:\Anaconda\lib\site-packages\theano\compile\function_module.py", line 141, in std_fgraph
    fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs)
  File "C:\Anaconda\lib\site-packages\theano\gof\fg.py", line 135, in __init__
    self.__import_r__(outputs, reason="init")
  File "C:\Anaconda\lib\site-packages\theano\gof\fg.py", line 257, in __import_r__
    self.__import__(apply_node, reason=reason)
  File "C:\Anaconda\lib\site-packages\theano\gof\fg.py", line 353, in __import__
    detailed_err_msg)
theano.gof.fg.MissingInputError: A variable that is an input to the graph was neither provided as an input to the function nor given a value. A chain of variables leading from this input to an output is [<TensorType(float32, vector)>, Shape.0, Elemwise{Cast{int32}}.0, RandomFunction{normal}.1]. This chain may not be unique
Backtrace when the variable is created:
  File "C:\Anaconda\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 585, in runfile
    execfile(filename, namespace)
  File "C:/Users/Main/Documents/Python Scripts/untitled13.py", line 24, in <module>
    result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)
  File "C:\Anaconda\lib\site-packages\theano\scan_module\scan.py", line 597, in scan
    arg = safe_new(init_out['initial'])
  File "C:\Anaconda\lib\site-packages\theano\scan_module\scan_utils.py", line 75, in safe_new
    nw_x = x.type()

I'm probably missing something simple and obvious here again.

Any help is much appreciated, thanks!

1 Answer:

Answer 0 (score: 0)

You need to pass the updates dictionary you get from scan to your function, so:

f=th.function([initials], result, updates=updates)
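In Theano, `updates` maps each shared variable to an expression for its new value, and th.function writes those new values back after every call. The random streams keep their generator state in shared variables, and scan returns the updates for that state. The toy model below only illustrates the general idea of a stored generator state that is or is not written back after a call. It uses plain NumPy and hypothetical helpers (`sample_step`, `make_function`) and is not how Theano is implemented.

```python
import numpy as np

# Toy model (hypothetical, not Theano's implementation): outputs are computed
# from the current stored state, then the state is optionally overwritten
# with its advanced value, which plays the role of the "updates" step.
shared = {"rng_state": np.random.RandomState(0).get_state()}

def sample_step(state):
    """Draw 5 normal samples from `state`; return (samples, advanced state)."""
    rng = np.random.RandomState()
    rng.set_state(state)
    samples = rng.normal(size=5)
    return samples, rng.get_state()

def make_function(apply_updates):
    def f():
        samples, new_state = sample_step(shared["rng_state"])
        if apply_updates:
            shared["rng_state"] = new_state  # write the advanced state back
        return samples
    return f

f_no_updates = make_function(apply_updates=False)
f_with_updates = make_function(apply_updates=True)

a, b = f_no_updates(), f_no_updates()
print(np.array_equal(a, b))  # True: the stored state never advances
c, d = f_with_updates(), f_with_updates()
print(np.array_equal(c, d))  # False: each call stores the advanced state
```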

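Also, a shared variable cannot be listed as an explicit input of th.function, so in the version below `initials` is an ordinary symbolic vector instead.

You can achieve what you want like this, for example. Note that the step function returns the symbolic random expression instead of compiling and calling a th.function inside it: scan calls the step function on placeholder variables (created by the `safe_new` call visible in the traceback) to build its graph, so a function compiled there depends on an input it has no value for.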

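import theano as th
import numpy as np
from theano import tensor as T

stream = th.tensor.shared_randomstreams.RandomStreams()

# Symbolic input instead of a shared variable; T.vector() uses floatX,
# so its dtype matches initial_values below
initials = T.vector()

def get_output(prev_rand):
    # Return the symbolic sample; do not compile or call th.function in here
    return stream.normal(size=prev_rand.shape, avg=prev_rand.mean())

result, updates = th.scan(get_output, outputs_info=[initials], n_steps=10)

# Pass scan's updates (the random generator state) to the function
f = th.function([initials], result, updates=updates)

initial_values = np.array([1, 2, 3, 4, 5], dtype=th.config.floatX)

print(f(initial_values))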