以下是步骤。
火炬很容易,但是我必须在keras中实现它,所以我不知道。这让我困惑了两个多星期。
答案 0 :(得分:0)
给出一个输入:
from keras.layers import *
inp = Input((5,5,100)) #or the output of a layer coming before this
out = Reshape((25,100))(inp)
out = Lambda(operation, output_shape = (100,100))(out)
out = Reshape((100,100,1))(out)
操作地点:
import keras.backend as K
def operation(x):
xT = K.permute_dimensions(x,(0,2,1)) #batch axes 0 is kept, 1 and 2 are swaped
return K.batch_dot(x,xT,axes=[1,2])
如果您正在使用顺序模型:
model.add(Reshape(25,100))
model.add(Lambda(operation, output_shape = (100,100)))
model.add(Reshape((100,100,1)))
如果你想要一个单层内的所有东西:
def operation(x):
x2 = K.reshape(x,(25,100))
x2T = K.permute_dimensions(x,(0,2,1))
d = K.batch_dot(x2,x2T,axes=[1,2])
return K.reshape(d,(100,100,1))
myLayer = Lambda(operation, output_shape=(100,100,1))
如果您想要该图层的多个实例:
def myLayer():
return Lambda(operation, output_shape=(100,100,1))