(batch_size, 70, 2, 2) -> Linear(2, 2)
(batch_size, 140, 2) -> Linear(2, 2)
(batch_size, 280) -> Linear(280, 2)
有人可以向我解释完全连接的图层如何处理非拼合的输入数据吗?我真的不知道如何考虑> 2D矩阵乘法。以上所有等同吗? (线性是pytorch中完全连接的模块)
答案 0 :(得分:0)
对于在大于2级(矩阵)的张量上进行的乘法,需要满足以下条件,例如,考虑2个张量A
和B
A.shape=[a1,a2,a3...a8]
和B.shape=[b1,b2,b3... b8]
尽管我不知道如何在pytorch中完成矩阵乘法,但我希望它与张量流相似。
如果您在张量流中执行tf.matmul
,则会在(a7,a8)
和(b7,b8)
上进行矩阵乘法,这要求a8
等于b7
才能进行操作A.B
,此外还要求a1..a6
等于b1..b6
输出形状为[a1, a2...a7,b8]
以上3个仅在将它们简单地在轴上展平的意义上是等效的