我正在尝试使用条件语句从模型中获取结果。
有3个张量都具有相同的形状:mu, sigma, alpha
。
其中,如果alpha[:, -1]
的概率为零,我想得到0值。具体而言,如果alpha[:, -1]
大于0.5,则其他值,而alpha[:, -1]
小于0.5。
这是我的代码。
def get_mdn_result(self, params):
# ouput of each step -> params : 2d
mu = params[:, :self.num_mixture]
sigma = params[:, self.num_mixture:self.num_mixture * 2]
alpha = params[:, self.num_mixture * 2:self.num_mixture * 3]
mu = tf.reduce_sum(mu * alpha, 1, keepdims=True)
sigma = tf.reduce_sum(sigma**2 * alpha**2, 1, keepdims=True)
mu = tf.cond(alpha[:, -1] < 0.5, lambda: tf.reduce_sum(mu * alpha, 1, keepdims=True), lambda: 0)
return mu, sigma
但是我收到如下错误消息。
形状必须为0,但“ while / cond / Switch”的形状为1(操作: 输入形状为[?],[?]的“ Switch”)。
我该如何解决?
mu,阿尔法就是这样
mu
[[1, 2, 3], [2, 3, 4]]
alpha
[[0.1, 0.3, 0.6],[0.5, 0.2, 0.3]]