将权重从Keras分类器导入TF对象检测API

时间:2019-04-08 17:40:00

标签: tensorflow keras

我有一个我使用keras训练的分类器,效果很好。它使用keras.applications.MobileNetV2

该分类器在大约200个类别上受过良好训练,并且具有很高的准确性。

但是,我想将此分类器中的特征提取层用作对象检测模型的一部分。

我一直在使用Tensorflow对象检测API,并研究了SSDLite + MobileNetV2模型。我可以开始进行训练,但是训练非常缓慢,损失的大部分来自分类阶段。

我想做的是将我的keras .h5模型中的权重分配给Tensorflow中MobileNetV2的特征提取层,但是我不确定做到这一点的最佳方法。

我可以轻松加载h5文件,并获取层名称列表:

import keras

keras_model = keras.models.load_model("my_classifier.h5")

keras_names = [l.name for l in keras_model.layers]

print(keras_names)

我还可以从对象检测API恢复张量流检查点并导出具有权重的图层:

tf.reset_default_graph()

with tf.Session() as sess:

    new_saver = tf.train.import_meta_graph('models/model.ckpt.meta')

    what = new_saver.restore(sess, 'models/model.ckpt')


    tf_names = []
    for op in sess.graph.get_operations():
        if "MobilenetV2" in op.name and "Assign" in op.name:
            tf_names.append(op.name)

    print(tf_names)

我似乎无法在keras和tensorflow的图层名称之间找到很好的匹配。即使我不能确定下一步。

如果有人可以给我一些建议,以最好的方式解决这个问题,

更新:

我在下面遵循了Sharky的建议,并做了一些修改:

new_saver = tf.train.import_meta_graph(os.path.join(keras_checkpoint_dir, 'keras_model.ckpt.meta'))

new_saver.restore(sess, os.path.join(keras_checkpoint_dir, tf.train.latest_checkpoint(keras_checkpoint_dir)))

不幸的是,我现在收到此错误:

  

NotFoundError(请参阅上面的回溯):从检查点还原   失败了这很可能是由于变量名或其他图形键   检查点缺少的内容。请确保您没有   根据检查点更改了期望的图形。原始错误:

     

键   FeatureExtractor / MobilenetV2 / expanded_conv_6 / project / BatchNorm / gamma   在检查点[[node save / RestoreV2_295(在   :7)= RestoreV2 [dtypes = [DT_FLOAT],   _device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”](_ arg_save / Const_0_0,   保存/恢复V2_295 /张量名称,   保存/恢复V2_295 / shape_and_slices)]] [[{{node   save / RestoreV2_196 / _393}} = _Recvclient_terminated = false,   recv_device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”,   send_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”,   send_device_incarnation = 1,tensor_name =“ edge_789_save / RestoreV2_196”,   tensor_type = DT_FLOAT,   _device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”]]

关于如何消除此错误的任何想法?

1 个答案:

答案 0 :(得分:1)

您可以使用tf.keras.estimator.model_to_estimator

estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=path)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, os.path.join(path/keras, tf.train.latest_checkpoint(path/keras)))
    print(tf.global_variables())

这应该可以完成工作。请注意,它将在原始指定的路径内创建一个子目录。