如何从Tensorflow检查点将权重加载到Keras模型

时间:2018-10-01 19:18:43

标签: python tensorflow machine-learning keras

我有一些python代码来使用Tensorflow的TFRecords和Dataset API训练网络。我使用tf.Keras.layers构建了网络,可以说这是最简单,最快的方法。方便的函数model_to_estimator()

modelTF = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    custom_objects=None,
    config=run_config,
    model_dir=checkPointDirectory
)

将Keras模型转换为估算器,这使我们可以很好地利用Dataset API,并在训练过程中和训练完成时将检查点自动保存到checkPointDirectory。估算器API提供了一些不可估量的功能,例如使用例如

在多个GPU上自动分配工作负载。
distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)

现在对于大型模型和大量数据,使用某种形式的已保存模型进行训练后执行预测通常非常有用。似乎从Tensorflow 1.10开始(请参阅https://github.com/tensorflow/tensorflow/issues/19295),tf.keras.model对象从Tensorflow检查点支持load_weights()。在Tensorflow文档中简要提到了这一点,但在Keras文档中没有提到,而且我找不到任何显示此示例的人。在一些新的.py中再次定义模型层之后,我尝试了

checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)

但这会产生NotImplementedError:

Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.

2018-10-01 14:24:49.912087:
Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
    model.load_weights(filepath=checkPointPath, by_name=False)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
    checkpointable_utils.streaming_restore(status=status, session=session)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
    "Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.

我想按照警告的建议进行操作,而是使用“基于对象的保护程序”,但是我还没有找到通过传递给estimator.train()的RunConfig来执行此操作的方法。

因此,是否有更好的方法将保存的权重返回到估算器中以用于预测? github线程似乎暗示已经实现了(尽管基于错误,可能与我上面尝试的方式不同)。有没有人在TF检查点上成功使用过load_weights()?我还没有找到有关如何完成此操作的任何教程/示例,因此不胜感激。

2 个答案:

答案 0 :(得分:0)

我不确定,但也许您可以将keras_model.ckpt.index更改为keras_model.ckpt进行测试。

答案 1 :(得分:0)

您可以创建一个单独的图表,正常加载您的检查点,然后将权重转移到您的 Keras 模型:

_graph = tf.Graph()
_sess = tf.Session(graph=_graph)

tf.saved_model.load(_sess, ['serve'], '../tf1_save/')

_weights_all, _bias_all = [], []
with _graph.as_default():
  for idx, t_var in enumerate(tf.trainable_variables()):
    # substitue variable_scope with your scope
    if 'variable_scope/' not in t_var.name: break
    
    print(t_var.name)
    val = _sess.run(t_var)
    _weights_all.append(val) if idx % 2 == 0 else _bias_all.append(val)

for layer, (weight, bias) in enumerate(zip(_weights_all, _bias_all)):
  self.model.layers[layer].set_weights([np.array(weight), np.array(bias)])