如何在计算损失函数时使用keras层的权重?

时间:2018-02-23 22:50:44

标签: python tensorflow keras

我正在尝试构建一个只有一层的自动编码器:

from keras import backend as K

def cost2(y_true, y_pred):
    print "shapes:", model.get_weights()[0].shape
    yy = K.dot( y_pred, model.get_weights()[0].T )
    return np.sum((y_true - yy)**2)

x = Input(shape=(original_dim,))
y = Dense(latent_dim)(x)
model = Model(inputs=x, outputs=y)
model.summary()
model.compile(optimizer='adagrad', loss=cost2)

这给了我错误:

Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 1570      
=================================================================
Total params: 1,570
Trainable params: 1,570
Non-trainable params: 0
_________________________________________________________________

shapes: (784, 2)

追踪(最近一次通话):   文件“vae_kears_gidital_mnist3.py”,第45行,in     model.compile(optimizer ='adagrad',loss = cost2)   文件“/Users/asgharrazavi/anaconda/lib/python2.7/site-packages/keras/engine/training.py”,第830行,编译     sample_weight,mask)   文件“/Users/asgharrazavi/anaconda/lib/python2.7/site-packages/keras/engine/training.py”,第429行,加权     score_array = fn(y_true,y_pred)   在cost2中输入第18行“vae_kears_gidital_mnist3.py”     yy = K.dot(y_pred,model.get_weights()[0] .T)   文件“/Users/asgharrazavi/anaconda/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py”,第1048行,点     如果ndim(x)不是None且(ndim(x)> 2或ndim(y)> 2):   在ndim中输入文件“/Users/asgharrazavi/anaconda/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py”,第606行     dims = x.get_shape()._ dims AttributeError:'numpy.ndarray'对象没有属性'get_shape'

我只是试图将模型的输出乘以模型的转置权重以返回输入维度。 有什么想法吗?

1 个答案:

答案 0 :(得分:1)

您的费用函数应返回keras张量而不是numpdy ndarray。您应该仅在客户损失函数中使用keras.backend函数或特定的后端函数(例如tf.something)(即K.sum而不是np.sum

这是您在问题中提到的错误的原因,但更重要的是,您并未以创建自动编码器的 keras方式制作模型。在keras中,您的模型将使用两层(编码器和解码器)创建,其中图层通过转置和标准MSE损耗共享权重。我建议你阅读this post 的keras博客,看看this issue