在keras模型函数中使用loss

时间:2017-10-15 22:53:47

标签: tensorflow keras

我正在尝试使用模型函数使用keras构建一个非常简单的模型,如下所示,其中模型函数的输入和输出是[img,labels]和损失。 我很困惑,为什么这个代码不起作用,如果输出不能丢失。 Model函数应该如何工作以及何时应该使用Model函数?感谢。

sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(1)
img = Input((784,),name='img')
labels = Input((10,),name='labels')
# img = tf.placeholder(tf.float32, shape=(None, 784))
# labels = tf.placeholder(tf.float32, shape=(None, 10))

x = Dense(128, activation='relu')(img)
x = Dropout(0.5)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
preds = Dense(10, activation='softmax')(x)

from keras.losses import binary_crossentropy
#loss = tf.reduce_mean(categorical_crossentropy(labels, preds))
loss = binary_crossentropy(labels, preds)
print(type(loss))
model = Model([img,labels], loss, name='squeezenet')
model.summary()

1 个答案:

答案 0 :(得分:3)

正如@ yu-yang指出的,损失是用compile()指定的。 如果你考虑它,它是有道理的,因为模型的实际输出是你的预测,而不是损失,损失只用于训练模型。

您网络的一个工作示例:

import keras
from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input, Dense, Dropout
from keras.losses import categorical_crossentropy

img = Input((784,),name='img')

x = Dense(128, activation='relu')(img)
x = Dropout(0.5)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)
preds = Dense(10, activation='softmax')(x)

model = Model(inputs=img, outputs=preds, name='squeezenet')


model.compile(optimizer=Adam(),
              loss=categorical_crossentropy,
              metrics=['acc'])

model.summary()

输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
img (InputLayer)             (None, 784)               0         
_________________________________________________________________
dense_32 (Dense)             (None, 128)               100480    
_________________________________________________________________
dropout_21 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_33 (Dense)             (None, 128)               16512     
_________________________________________________________________
dropout_22 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_34 (Dense)             (None, 10)                1290      
=================================================================
Total params: 118,282
Trainable params: 118,282
Non-trainable params: 0
_________________________________________________________________

使用MNIST数据集:

from keras.datasets import mnist
from keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(-1, 784)
y_train = to_categorical(y_train, num_classes=10)
x_test = x_test.reshape(-1, 784)
y_test = to_categorical(y_test, num_classes=10)

model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))

输出:

Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 4s - loss: 12.2797 - acc: 0.2360 - val_loss: 11.0902 - val_acc: 0.3116
Epoch 2/10
60000/60000 [==============================] - 4s - loss: 10.4161 - acc: 0.3527 - val_loss: 8.7122 - val_acc: 0.4589
Epoch 3/10
60000/60000 [==============================] - 4s - loss: 9.5797 - acc: 0.4051 - val_loss: 8.9226 - val_acc: 0.4460
Epoch 4/10
60000/60000 [==============================] - 4s - loss: 9.2017 - acc: 0.4285 - val_loss: 8.0564 - val_acc: 0.4998
Epoch 5/10
60000/60000 [==============================] - 4s - loss: 8.8558 - acc: 0.4501 - val_loss: 8.0878 - val_acc: 0.4980
Epoch 6/10
60000/60000 [==============================] - 5s - loss: 8.8239 - acc: 0.4521 - val_loss: 8.2495 - val_acc: 0.4880
Epoch 7/10
60000/60000 [==============================] - 4s - loss: 8.7842 - acc: 0.4547 - val_loss: 7.7146 - val_acc: 0.5211
Epoch 8/10
60000/60000 [==============================] - 4s - loss: 8.7395 - acc: 0.4575 - val_loss: 7.7944 - val_acc: 0.5163
Epoch 9/10
60000/60000 [==============================] - 5s - loss: 8.7109 - acc: 0.4593 - val_loss: 7.8235 - val_acc: 0.5145
Epoch 10/10
60000/60000 [==============================] - 4s - loss: 8.4927 - acc: 0.4729 - val_loss: 7.5933 - val_acc: 0.5288