I'm simply following the Keras documentation to fine-tune ResNet50 on the CIFAR-10 dataset.
(https://keras.io/applications/)
However, when I call Model.fit(), it doesn't work: the program just exits without any error message.
Here is the code:
from keras.applications.resnet50 import ResNet50
import keras
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.datasets.cifar10 import load_data
(x_train, y_train), (x_test, y_test) = load_data()
print(x_train.shape)
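y_train = keras.utils.to_categorical(y_train.reshape(-1), num_classes=10)
# the test labels need the same one-hot encoding for categorical_crossentropy in evaluate()
y_test = keras.utils.to_categorical(y_test.reshape(-1), num_classes=10)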
print(y_train.shape)
# create the base pre-trained model
base_model = ResNet50(weights='imagenet', include_top=False)
# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# let's add a fully-connected layer
# x = Dense(1024, activation='relu')(x)
# and a logistic (softmax) layer for the 10 CIFAR-10 classes
predictions = Dense(10, activation='softmax')(x)
# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)
# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional ResNet50 layers
for layer in base_model.layers:
    layer.trainable = False
# compile the model (should be done *after* setting layers to non-trainable)
model.compile(optimizer=keras.optimizers.SGD(lr=0.00001, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=8, epochs=2, shuffle=True, verbose=2)
score = model.evaluate(x_test, y_test, batch_size=100)
print(score)
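For reference: CIFAR-10 images are 32×32, while ResNet50 was designed for 224×224 ImageNet input. The small pure-Python sketch below traces how far a square input shrinks before the pooling/classifier head. It assumes the standard stride layout of Keras's ResNet50 (stride-2 7×7 conv1, a 3×3 stride-2 max-pool, then a stride-2 first block in each of stages 3, 4 and 5). The helper `resnet50_feature_size` is hypothetical and not part of Keras.

```python
import math


def resnet50_feature_size(n: int) -> int:
    """Hypothetical helper: side length of ResNet50's last feature map for an n x n input.

    Assumes the Keras 2.x ResNet50 layout:
      conv1: 7x7, stride 2 (with 3-pixel padding)  -> ceil(n / 2)
      pool1: 3x3 max-pool, stride 2, 'valid'        -> floor((n - 3) / 2) + 1
      stages 3, 4, 5: 1x1 conv with stride 2        -> ceil(n / 2) each
    """
    n = math.ceil(n / 2)      # conv1
    n = (n - 3) // 2 + 1      # max-pool after conv1
    for _ in range(3):        # first block of conv3_x, conv4_x, conv5_x
        n = math.ceil(n / 2)
    return n


for size in (32, 196, 197, 224):
    print(size, "->", resnet50_feature_size(size))
```

Under these assumptions a 32×32 image reaches the last block as a 1×1 feature map, whereas 224×224 gives the usual 7×7. 197 is the smallest side that still yields 7×7.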
After running it, I get the following output:
... ... ...
2018-04-02 18:07:48.153260: I C:\tf_jenkins\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\common_runtime\gpu\gpu_device.cc:1312] Adding visible gpu devices: 0
2018-04-02 18:07:48.803934: I C:\tf_jenkins\workspace\rel-win\M\windows-gpu\PY\36\tensorflow\core\common_runtime\gpu\gpu_device.cc:993] Creating TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 3031 MB memory) -> physical GPU (device: 0, name: GeForce GTX 970, pci bus id: 0000:03:00.0, compute capability: 5.2)
Epoch 1/2
Process finished with exit code -1073741676 (0xC0000094)
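The exit code is the signed form of a 32-bit Windows status value. 0xC0000094 is STATUS_INTEGER_DIVIDE_BY_ZERO, so the process appears to die inside native code (TensorFlow/CUDA) rather than raising a Python exception, which would explain why no traceback is printed. A minimal sketch of the conversion (the helper name `to_ntstatus` is just for illustration):

```python
def to_ntstatus(exit_code: int) -> str:
    """Reinterpret a signed process exit code as an unsigned 32-bit hex status."""
    return hex(exit_code & 0xFFFFFFFF)


print(to_ntstatus(-1073741676))  # 0xc0000094
```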
I'm completely new to Keras and know only a little about TensorFlow, so I really don't know what is happening inside fit().
Thanks a lot!