在Keras自定义层中乘以3矩阵

时间:2019-02-06 10:00:24

标签: keras matrix-multiplication

我想创建一个自定义Keras层,该层可计算2个输入矩阵和1个权重矩阵(对角矩阵)之间的乘积:x W y

x = Input((8,200)) # (?,8,200)
y = Input((10,200)) # (?,10,200)
W # Weight matrix define with Keras (200,) 

我想要一个输出形状为(?,8,10)的xWy的输出矩阵

我尝试:

K.dot(x*W, K.transpose(Y)) # Raise Dimension error
K.dot(x*W, Permute(2,1))(Y)) # (?, 8, ?, 10)

没有第一个尺寸(批量),我知道如何做,但是有了它我有点迷茫。

2 个答案:

答案 0 :(得分:1)

您可以使用为此目的而制作的K.batch_dot

 K.batch_dot(x*W, K.permute_dimensions(y, (0,2,1)), axes=[2, 1]) # (?, 8, 10)

可以解决问题。

答案 1 :(得分:0)

您可以指定在Keras Dot图层中获取点积所沿的轴。以下代码显示了如何将输入xy相乘。如果要添加权重矩阵W,可以用类似的方法(首先将xW相乘)来完成。

x = Input((8,200)) # (?,8,200)
y = Input((10,200)) # (?,10,200)
output = keras.layers.Dot(axes=-1)([x, y]) # (?,8,10)