预训练的densitynet / vgg16 / resnet50 + gp无法在cifar10数据上进行训练

时间:2019-06-25 20:26:35

标签: gpflow

我正在尝试在具有CIFAR10数据的预先训练的CNN(Densenet,VGG和Resnet)的基础上训练GP的混合模型,模仿gpflow文档中的ex2函数。但是测试结果总是在0.1〜0.2之间,这通常意味着随机猜测(Wilson + 2016论文显示CIFAR10数据的混合模型的准确度应为0.7)。谁能给我提示可能有什么问题吗?

我已经用更简单的cnn模型(2个转换层或4个转换层)尝试了相同的代码,并且都具有合理的结果。我尝试使用不同的Keras应用程序:Densenet121,VGG16,ResNet50,都无法使用。我试图冻结仍无法正常工作的预训练模型中的权重。

def cnn_dn(output_dim):
    base_model = DenseNet121(weights='imagenet', include_top=False, input_shape=(32,32,3))
    bout = base_model.output
    fcl = GlobalAveragePooling2D()(bout)
    #for layer in base_model.layers:
    #    layer.trainable = False
    output=Dense(output_dim, activation='relu')(fcl)
    md=Model(inputs=base_model.input, outputs=output)
    return md

#add gp on top, reference:ex2() function in
#https://nbviewer.jupyter.org/github/GPflow/GPflow/blob/develop/doc/source/notebooks/tailor/gp_nn.ipynb
#needs to slightly change build graph part because keras variable #sharing is not the same as tensorflow
#......

## build graph
with tf.variable_scope('cnn'):
    md=cnn_dn(gp_dim)
    f_X = tf.cast(md(X), dtype=float_type)
    f_Xtest = tf.cast(md(Xtest), dtype=float_type)

#......

    ## predict

res=np.argmax(sess.run(my, feed_dict={Xtest:xts}),1).reshape(yts.shape)
correct = res == yts.astype(int)
print(np.average(correct.astype(float)))

1 个答案:

答案 0 :(得分:1)

我最终发现解决方案正在训练更大的迭代。在原始代码中,我仅使用ex2()函数中的50次迭代来处理MNIST数据,而对于更复杂的网络和CIFAR10数据而言,这还不够。调整一些超参数(例如学习率和激活功能)也有帮助。