我目前正在阅读我的同事的代码实现。然后我看到了这些代码行
label_embedding = Flatten()(Embedding(y_class, param_embed)(cond))
model_input = multiply([input_img, label_embedding])
label_embedding
的形状为[batch_size, 64]
,而input_img
的形状为[batch_size, H, W]
,其中H和W为图像尺寸(高度和宽度)。调试代码后,我发现model_output
的形状为[batch_size, H, W, 64]
。
基于Keras docs,Multiply
操作在相同大小的张量上进行逐元素乘法。但是代码如何工作并产生上述形状?