使用tf.estimator时如何通过import_meta_graph训练/微调模型构建图

时间:2019-08-29 09:09:30

标签: python tensorflow multi-gpu

我正在微调tf.estimator训练的模型,但是通过import_meta_graph而不是代码来构建图形。 如何使用多个GPU对其进行微调?

我想要什么:

  • 1-我有一个受tf.estimator训练的模型。
def model_fn_original(feature, labels, mode, params):
  input = tf.layers.xxx
  input = tf.layers.xxx
  logits = tf.layers.dense(input)
  if mode == tf.estimator.ModeKeys.TRAIN:
    return tf.estimator.EstimatorSpec()
  • 2-使用export_meta_graph保存前向图
def export_meta_graph():
  input = tf.layers.xxx
  input = tf.layers.xxx
  logits = tf.layers.dense(input)
  tf.export_meta_graph(meta_file_name)

  • 3-使用import_meta_graph方式替换model_fn中的构建图形代码以微调具有多个GPU的模型
def model_fn_imported(feature, labels, mode, params):
  tf.import_meta_graph(meta_file_name)
  logits = tf.get_defalt_graph().get_tensor_by_name("dense/BiasAdd:0")
  if mode == tf.estimator.ModeKeys.TRAIN:
    return tf.estimator.EstimatorSpec()

estimator = tf.estimator.Estiator(model_fn=model_fn_imported, ...)
estimator.train()

我已经测试过首先通过import_meta_graph导出meta_graph(mode = eval)构建图。 成功进行评估,准确性与原始准确性相同。

微调时,如果我只使用一个GPU,一切都会正常进行。

但是当我使用多个GPU时:

  • 1-使用tf.contrib.estimator.replicate_model_fn

    发生一些错误: 在检查点中找不到tower_1 / model / xxx / xxx / xxx

在线调查后,这可能是由于无法设置import_meta_graph中的变量而导致的。因此,在运行replicate_model_fn并运行model_fn many次后,每次都创建一个新变量。

  • 2-使用tf.contrib.distribute.MirroredStrategy:

    会发生一些错误:目标位置必须是DistributedValues对象之一tf.Variable对象一个设备字符串以及一个设备字符串列表

所以问题是,当使用import_meta_graph构建图形时,是否可以使用多GPU来微调/训练模型。

理论上可以吗?

0 个答案:

没有答案
相关问题