我正在尝试为混合密度网络计算一个简单的损失函数。 输出的形状为(batch_size = 128,2),mu参数的形状为(batch_size,2 * Num_mixes)。为了能够计算它们之间的差异,我需要对输出进行平铺,使其具有与mu相同的形状,但会出现主题错误
我不知道为什么输出的形状为(128,2),但其等级仍为0。请您帮忙?
def calc_pdf(y, mu, var):
"""Calculate component density"""
print(y.shape)
print(mu.shape)
value = tf.subtract(y, mu)**2
value = (1/tf.math.sqrt(2 * np.pi * var)) * tf.math.exp((-1/(2*var)) * value)
print(value.shape)
return value
y shape is (128, 2)
mu shape is (128, 60)
y rank is Tensor("Rank:0", shape=(), dtype=int32)
ValueError: Shape must be rank 1 but is rank 0 for 'Tile' (op: 'Tile') with input shapes: [128,2], [].