尝试建立条件规范化流时,遇到嵌套** kwargs错误

时间:2019-09-10 13:46:42

标签: tensorflow kwargs

我使用一些看中的机器学习术语,但是我认为对基础python尤其是kwargs如何工作有更好了解的人也可以帮助我,所以不要让这吓到您了。我正在尝试在张量流中建立条件规范化流。 TF设置了一些kwarg以使该任务更容易,但它并未在基本tensorflow中实现(尚未检查2.0,但我正在使用1.13.1的版本)。问题是,存在一些嵌套的kwarg,其中一个函数具有kwargs,而kwargs调用了另一个具有kwargs的函数,依此类推。我试图建立一个像kwargs = {"kwargs" : {"condit" : tf.concat([params,x_i], axis = 0)}}这样的嵌套字典,但是当我调用这样的dist.log_prob(y_i, **kwargs)时,其中dist的定义是这样的:

base_dist1 = tfd.MultivariateNormalDiag(loc=tf.zeros([dim]))

for k in range(4):
    bijector.append(tfb.RealNVP(2, shift_and_log_scale_fn=tfb.real_nvp.real_nvp_conditional_template(
                                     hidden_layers=[16])))
    bijector.append(tfb.BatchNormalization(training=boolval))
    bijector.append(tfb.Permute(permutation=[3,4,1,2,0]))#this needs to change as states and features change

bijector.append(tfb.RealNVP(2, shift_and_log_scale_fn=tfb.real_nvp.real_nvp_conditional_template(
                                     hidden_layers=[16])))


chainedbij = tfb.Chain(list(reversed(bijector)))

dist = tfd.ConditionalTransformedDistribution(
    distribution=base_dist1,
    bijector=chainedbij)

我收到此错误/回溯:

  File "<ipython-input-1-0a984deabb65>", line 1, in <module>
    runfile('/home/cameron/AnacondaProjects/Max_Like_NF/VAR_NF_test.py', wdir='/home/cameron/AnacondaProjects/Max_Like_NF')

  File "/home/cameron/anaconda3/lib/python3.6/site-packages/spyder/utils/site/sitecustomize.py", line 866, in runfile
    execfile(filename, namespace)

  File "/home/cameron/anaconda3/lib/python3.6/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "/home/cameron/AnacondaProjects/Max_Like_NF/VAR_NF_test.py", line 102, in <module>
    dens = -calclik(ysim)

  File "/home/cameron/AnacondaProjects/Max_Like_NF/VAR_NF_test.py", line 98, in calclik
    lprob = lprob + tf.reduce_sum(dist.log_prob(y_i, **kwargs))

  File "/home/cameron/anaconda3/lib/python3.6/site-packages/tensorflow_probability/python/internal/distribution_util.py", line 2073, in _fn
    return fn(*args, **kwargs)

  File "/home/cameron/anaconda3/lib/python3.6/site-packages/tensorflow_probability/python/distributions/conditional_distribution.py", line 45, in log_prob
    return self._call_log_prob(value, name, **condition_kwargs)

  File "/home/cameron/anaconda3/lib/python3.6/site-packages/tensorflow_probability/python/distributions/distribution.py", line 697, in _call_log_prob
    return self._log_prob(value, **kwargs)

  File "/home/cameron/anaconda3/lib/python3.6/site-packages/tensorflow_probability/python/internal/distribution_util.py", line 2073, in _fn
    return fn(*args, **kwargs)

TypeError: _log_prob() got an unexpected keyword argument 'kwargs'

0 个答案:

没有答案