我正在尝试在具有CIFAR10数据的预先训练的CNN(Densenet,VGG和Resnet)的基础上训练GP的混合模型,模仿gpflow文档中的ex2函数。但是测试结果总是在0.1〜0.2之间,这通常意味着随机猜测(Wilson + 2016论文显示CIFAR10数据的混合模型的准确度应为0.7)。谁能给我提示可能有什么问题吗?
我已经用更简单的cnn模型(2个转换层或4个转换层)尝试了相同的代码,并且都具有合理的结果。我尝试使用不同的Keras应用程序:Densenet121,VGG16,ResNet50,都无法使用。我试图冻结仍无法正常工作的预训练模型中的权重。
def cnn_dn(output_dim):
base_model = DenseNet121(weights='imagenet', include_top=False, input_shape=(32,32,3))
bout = base_model.output
fcl = GlobalAveragePooling2D()(bout)
#for layer in base_model.layers:
# layer.trainable = False
output=Dense(output_dim, activation='relu')(fcl)
md=Model(inputs=base_model.input, outputs=output)
return md
#add gp on top, reference:ex2() function in
#https://nbviewer.jupyter.org/github/GPflow/GPflow/blob/develop/doc/source/notebooks/tailor/gp_nn.ipynb
#needs to slightly change build graph part because keras variable #sharing is not the same as tensorflow
#......
## build graph
with tf.variable_scope('cnn'):
md=cnn_dn(gp_dim)
f_X = tf.cast(md(X), dtype=float_type)
f_Xtest = tf.cast(md(Xtest), dtype=float_type)
#......
## predict
res=np.argmax(sess.run(my, feed_dict={Xtest:xts}),1).reshape(yts.shape)
correct = res == yts.astype(int)
print(np.average(correct.astype(float)))
答案 0 :(得分:1)
我最终发现解决方案正在训练更大的迭代。在原始代码中,我仅使用ex2()函数中的50次迭代来处理MNIST数据,而对于更复杂的网络和CIFAR10数据而言,这还不够。调整一些超参数(例如学习率和激活功能)也有帮助。