Keras:乘(2D-矩阵,向量)= 3D-矩阵

时间:2019-12-17 11:08:57

标签: python keras matrix-multiplication

我目前正在阅读我的同事的代码实现。然后我看到了这些代码行

label_embedding = Flatten()(Embedding(y_class, param_embed)(cond))
model_input = multiply([input_img, label_embedding])

label_embedding的形状为[batch_size, 64],而input_img的形状为[batch_size, H, W],其中H和W为图像尺寸(高度和宽度)。调试代码后,我发现model_output的形状为[batch_size, H, W, 64]

基于Keras docsMultiply操作在相同大小的张量上进行逐元素乘法。但是代码如何工作并产生上述形状?

0 个答案:

没有答案