我有n
个网络,每个网络都具有相同的输入/输出。我想根据分类分布随机选择输出之一。 Tfp.Categorical仅输出整数,而我尝试做类似
act_dist = tfp.distributions.Categorical(logits=act_logits) # act_logits are all the same, so the distribution is uniform
rand_out = act_dist.sample()
x = nn_out1 * tf.cast(rand_out == 0., dtype=tf.float32) + ... # for all my n networks
但是rand_out == 0.
以及其他条件总是错误的。
实现我所需要的任何想法吗?
答案 0 :(得分:1)
您可能还会看到MixtureSameFamily,它为您在幕后进行了聚拢。
nn_out1 = tf.expand_dims(nn_out1, axis=2)
...
outs = tf.concat([nn_out1, nn_nout2, ...], axis=2)
probs = tf.tile(tf.reduce_mean(tf.ones_like(nn_out1), axis=1, keepdims=True) / n, [1, n]) # trick to have ones of shape [None,1]
dist = tfp.distributions.MixtureSameFamily(
mixture_distribution=tfp.distributions.Categorical(probs=probs),
components_distribution=tfp.distributions.Deterministic(loc=outs))
x = dist.sample()
答案 1 :(得分:0)
我认为您需要使用tf.equal,因为Tensor == 0始终为False。
不过,您可能要单独使用OneHotCategorical。对于培训,您还可以尝试使用RelaxedOneHotCategorical。