如何使用tf.cond?

时间:2018-09-25 22:17:44

标签: python tensorflow deep-learning

我正在尝试使用条件语句从模型中获取结果。 有3个张量都具有相同的形状:mu, sigma, alpha
其中,如果alpha[:, -1]的概率为零,我想得到0值。具体而言,如果alpha[:, -1]大于0.5,则其他值,而alpha[:, -1]小于0.5。 这是我的代码。

def get_mdn_result(self, params):
    # ouput of each step -> params : 2d
    mu = params[:, :self.num_mixture]
    sigma = params[:, self.num_mixture:self.num_mixture * 2]
    alpha = params[:, self.num_mixture * 2:self.num_mixture * 3]

    mu = tf.reduce_sum(mu * alpha, 1, keepdims=True)
    sigma = tf.reduce_sum(sigma**2 * alpha**2, 1, keepdims=True)

    mu = tf.cond(alpha[:, -1] < 0.5, lambda: tf.reduce_sum(mu * alpha, 1, keepdims=True), lambda: 0)

    return mu, sigma

但是我收到如下错误消息。

  

形状必须为0,但“ while / cond / Switch”的形状为1(操作:   输入形状为[?],[?]的“ Switch”)。

我该如何解决?

mu,阿尔法就是这样

mu

[[1, 2, 3], [2, 3, 4]]

alpha

[[0.1, 0.3, 0.6],[0.5, 0.2, 0.3]]

0 个答案:

没有答案