我有一个我使用keras训练的分类器,效果很好。它使用keras.applications.MobileNetV2
。
该分类器在大约200个类别上受过良好训练,并且具有很高的准确性。
但是,我想将此分类器中的特征提取层用作对象检测模型的一部分。
我一直在使用Tensorflow对象检测API,并研究了SSDLite + MobileNetV2模型。我可以开始进行训练,但是训练非常缓慢,损失的大部分来自分类阶段。
我想做的是将我的keras .h5
模型中的权重分配给Tensorflow中MobileNetV2的特征提取层,但是我不确定做到这一点的最佳方法。
我可以轻松加载h5
文件,并获取层名称列表:
import keras
keras_model = keras.models.load_model("my_classifier.h5")
keras_names = [l.name for l in keras_model.layers]
print(keras_names)
我还可以从对象检测API恢复张量流检查点并导出具有权重的图层:
tf.reset_default_graph()
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('models/model.ckpt.meta')
what = new_saver.restore(sess, 'models/model.ckpt')
tf_names = []
for op in sess.graph.get_operations():
if "MobilenetV2" in op.name and "Assign" in op.name:
tf_names.append(op.name)
print(tf_names)
我似乎无法在keras和tensorflow的图层名称之间找到很好的匹配。即使我不能确定下一步。
如果有人可以给我一些建议,以最好的方式解决这个问题,
更新:
我在下面遵循了Sharky的建议,并做了一些修改:
new_saver = tf.train.import_meta_graph(os.path.join(keras_checkpoint_dir, 'keras_model.ckpt.meta'))
new_saver.restore(sess, os.path.join(keras_checkpoint_dir, tf.train.latest_checkpoint(keras_checkpoint_dir)))
不幸的是,我现在收到此错误:
NotFoundError(请参阅上面的回溯):从检查点还原 失败了这很可能是由于变量名或其他图形键 检查点缺少的内容。请确保您没有 根据检查点更改了期望的图形。原始错误:
键 FeatureExtractor / MobilenetV2 / expanded_conv_6 / project / BatchNorm / gamma 在检查点[[node save / RestoreV2_295(在 :7)= RestoreV2 [dtypes = [DT_FLOAT], _device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”](_ arg_save / Const_0_0, 保存/恢复V2_295 /张量名称, 保存/恢复V2_295 / shape_and_slices)]] [[{{node save / RestoreV2_196 / _393}} = _Recvclient_terminated = false, recv_device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”, send_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”, send_device_incarnation = 1,tensor_name =“ edge_789_save / RestoreV2_196”, tensor_type = DT_FLOAT, _device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”]]
关于如何消除此错误的任何想法?
答案 0 :(得分:1)
您可以使用tf.keras.estimator.model_to_estimator
estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=path)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, os.path.join(path/keras, tf.train.latest_checkpoint(path/keras)))
print(tf.global_variables())
这应该可以完成工作。请注意,它将在原始指定的路径内创建一个子目录。