如何使用 `MonitoredTrainingSession` / `Scaffold` 微调模型

时间:2017-07-14 15:50:00

标签: python tensorflow

我想恢复 VGG_19 的模型参数,把 VGG_19 用作特征提取器,其后接上新初始化的附加计算图,并在分布式环境下训练整个模型。

如果我使用 slim.learning.train,一切正常,但换成 tf.train.MonitoredTrainingSession 及其所需的 Scaffold 后就无法正常工作。如果我把 restore_fn(按照文档用 tf.contrib.framework.assign_from_checkpoint_fn 创建)作为 init_fn 传给 Scaffold,就会得到 TypeError: callback() takes 1 positional argument but 2 were given

我曾尝试通过传递 lambda scaffold, sess: restore_fn(sess) 来"修复"这个问题。

如果我尝试创建一个恢复操作(使用 tf.contrib.slim.assign_from_checkpoint 创建)并将其作为 init_op 传入,则会得到如下错误:

INFO:tensorflow:Create CheckpointSaverHook.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    267         self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 268             fetch, allow_tensor=True, allow_operation=True))
    269       except TypeError as e:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operatio
n)
   2608     if self._finalized:
-> 2609       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   2610

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_
operation)
   2700       raise TypeError("Can not convert a %s into a %s."
-> 2701                       % (type(obj).__name__, types_str))
   2702

TypeError: Can not convert a ndarray into a Tensor or Operation.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/ScanAvoidanceML/ScanAvoidanceML/datasets/project_daphnis/train.py in <module>()
    129         )
    130         FLAGS, unparsed = parser.parse_known_args()
--> 131         tf.app.run(main=train, argv=[sys.argv[0]] + unparsed)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/platform/app.py in run(main, argv)
     46   # Call the main function, passing through any arguments
     47   # to the final program.
---> 48   _sys.exit(main(_sys.argv[:1] + flags_passthrough))
     49
     50

/ScanAvoidanceML/ScanAvoidanceML/datasets/project_daphnis/train.py in train(_)
     83                 scaffold=tf.train.Scaffold(
     84                     init_op=restore_op,
---> 85                     summary_op=tf.summary.merge_all())) as mon_sess:
     86             while not mon_sess.should_stop():
     87                 # Run a training step asynchronously.

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in MonitoredTrainingSession(master, is_chief,
checkpoint_dir, scaffold, hooks, chief_only_hooks, save_checkpoint_secs, save_summaries_steps, save_summaries_secs, config, stop_grac
e_period_secs, log_step_count_steps)
    351     all_hooks.extend(hooks)
    352   return MonitoredSession(session_creator=session_creator, hooks=all_hooks,
--> 353                           stop_grace_period_secs=stop_grace_period_secs)
    354
    355

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, session_creator, hooks, stop
_grace_period_secs)
    654     super(MonitoredSession, self).__init__(
    655         session_creator, hooks, should_recover=True,
--> 656         stop_grace_period_secs=stop_grace_period_secs)
    657
    658

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, session_creator, hooks, shou
ld_recover, stop_grace_period_secs)
    476         stop_grace_period_secs=stop_grace_period_secs)
    477     if should_recover:
--> 478       self._sess = _RecoverableSession(self._coordinated_creator)
    479     else:
    480       self._sess = self._coordinated_creator.create_session()


/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in __init__(self, sess_creator)
    828     """
    829     self._sess_creator = sess_creator
--> 830     _WrappedSession.__init__(self, self._create_session())
    831
    832   def _create_session(self):

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in _create_session(self)
    833     while True:
    834       try:
--> 835         return self._sess_creator.create_session()
    836       except _PREEMPTION_ERRORS as e:
    837         logging.info('An error was raised while a session was being created. '

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in create_session(self)
    537       """Creates a coordinated session."""
    538       # Keep the tf_sess for unit testing.
--> 539       self.tf_sess = self._session_creator.create_session()
    540       # We don't want coordinator to suppress any exception.
    541       self.coord = coordinator.Coordinator(clean_stop_exception_types=[])

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/monitored_session.py in create_session(self)
    411         init_op=self._scaffold.init_op,
    412         init_feed_dict=self._scaffold.init_feed_dict,
--> 413         init_fn=self._scaffold.init_fn)
    414
    415

/opt/conda/lib/python3.6/site-packages/tensorflow/python/training/session_manager.py in prepare_session(self, master, init_op, saver,
 checkpoint_dir, checkpoint_filename_with_path, wait_for_checkpoint, max_wait_secs, config, init_feed_dict, init_fn)
    277                            "init_fn or local_init_op was given")
    278       if init_op is not None:
--> 279         sess.run(init_op, feed_dict=init_feed_dict)
    280       if init_fn:
    281         init_fn(sess)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    894     try:
    895       result = self._run(None, fetches, feed_dict, options_ptr,
--> 896                          run_metadata_ptr)
    897       if run_metadata:
    898         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_met
adata)
   1107     # Create a fetch handler to take care of the structure of fetches.
   1108     fetch_handler = _FetchHandler(
-> 1109         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1110
   1111     # Run request and get response.

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    409     """
    410     with graph.as_default():
--> 411       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    412     self._fetches = []
    413     self._targets = []

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    229     elif isinstance(fetch, (list, tuple)):
    230       # NOTE(touts): This is also the code path for namedtuples.
--> 231       return _ListFetchMapper(fetch)
    232     elif isinstance(fetch, dict):
    233       return _DictFetchMapper(fetch)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
    336     """
    337     self._fetch_type = type(fetches)
--> 338     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    339     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    340

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
    336     """
    337     self._fetch_type = type(fetches)
--> 338     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    339     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    340

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    231       return _ListFetchMapper(fetch)
    232     elif isinstance(fetch, dict):
--> 233       return _DictFetchMapper(fetch)
    234     else:
    235       # Look for a handler in the registered expansions.

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
    369     self._keys = fetches.keys()
    370     self._mappers = [_FetchMapper.for_fetch(fetch)
--> 371                      for fetch in fetches.values()]
    372     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    373

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    237         if isinstance(fetch, tensor_type):
    238           fetches, contraction_fn = fetch_fn(fetch)
--> 239           return _ElementFetchMapper(fetches, contraction_fn)
    240     # Did not find anything.
    241     raise TypeError('Fetch argument %r has invalid type %r' %

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
    369     self._keys = fetches.keys()
    370     self._mappers = [_FetchMapper.for_fetch(fetch)
--> 371                      for fetch in fetches.values()]
    372     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    373

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
    369     self._keys = fetches.keys()
    370     self._mappers = [_FetchMapper.for_fetch(fetch)
--> 371                      for fetch in fetches.values()]
    372     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    373

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    237         if isinstance(fetch, tensor_type):
    238           fetches, contraction_fn = fetch_fn(fetch)
--> 239           return _ElementFetchMapper(fetches, contraction_fn)
    240     # Did not find anything.
    241     raise TypeError('Fetch argument %r has invalid type %r' %

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    270         raise TypeError('Fetch argument %r has invalid type %r, '
    271                         'must be a string or Tensor. (%s)'
--> 272                         % (fetch, type(fetch), str(e)))
    273       except ValueError as e:
    274         raise ValueError('Fetch argument %r cannot be interpreted as a '

TypeError: Fetch argument array([[[[ 0.39416704, -0.08419707, -0.03631314, ..., -0.10720515,
          -0.03804016,  0.04690642],
         [ 0.46418372,  0.03355668,  0.10245045, ..., -0.06945956,
          -0.04020201,  0.04048637],
         [ 0.34119523,  0.09563112,  0.0177449 , ..., -0.11436455,
          -0.05099866, -0.00299793]],

        [[ 0.37740308, -0.07876257, -0.04775979, ..., -0.11827433,
          -0.19008617, -0.01889699],
         [ 0.41810837,  0.05260524,  0.09755926, ..., -0.09385028,
          -0.20492788, -0.0573062 ],
         [ 0.33999205,  0.13363543,  0.02129423, ..., -0.13025227,
          -0.16508926, -0.06969624]],

        [[-0.04594866, -0.11583115, -0.14462094, ..., -0.12290562,
          -0.35782176, -0.27979308],
         [-0.04806903, -0.00658076, -0.02234544, ..., -0.0878844 ,
          -0.3915486 , -0.34632796],
         [-0.04484424,  0.06471398, -0.07631404, ..., -0.12629718,
          -0.29905206, -0.28253639]]],


       [[[ 0.2671299 , -0.07969447,  0.05988706, ..., -0.09225675,
           0.31764674,  0.42209673],
         [ 0.30511212,  0.05677647,  0.21688674, ..., -0.06828708,
           0.3440761 ,  0.44033417],
         [ 0.23215917,  0.13365699,  0.12134422, ..., -0.1063385 ,
           0.28406844,  0.35949969]],

        [[ 0.09986369, -0.06240906,  0.07442063, ..., -0.02214639,
           0.25912452,  0.42349899],
         [ 0.10385381,  0.08851637,  0.2392226 , ..., -0.01210995,
           0.27064082,  0.40848857],
         [ 0.08978214,  0.18505956,  0.15264879, ..., -0.04266965,
           0.25779948,  0.35873157]],

        [[-0.34100872, -0.13399366, -0.11510294, ..., -0.11911335,
          -0.23109646, -0.19202407],
         [-0.37314063, -0.00698938,  0.02153259, ..., -0.09827439,
          -0.2535741 , -0.25541356],
         [-0.30331427,  0.08002605, -0.03926321, ..., -0.12958746,
          -0.19778992, -0.21510386]]],


       [[[-0.07573577, -0.07806503, -0.03540679, ..., -0.1208065 ,
           0.20088433,  0.09790061],
         [-0.07646758,  0.03879711,  0.09974211, ..., -0.08732687,
           0.2247974 ,  0.10158388],
         [-0.07260918,  0.10084777,  0.01313597, ..., -0.12594968,
           0.14647409,  0.05009392]],

        [[-0.28034249, -0.07094654, -0.0387974 , ..., -0.08843154,
           0.18996507,  0.07766484],
         [-0.31070709,  0.06031388,  0.10412455, ..., -0.06832542,
           0.20279962,  0.05222717],
         [-0.246675  ,  0.1414054 ,  0.02605635, ..., -0.10128672,
           0.16340195,  0.02832468]],

        [[-0.41602272, -0.11491341, -0.14672887, ..., -0.13079506,
          -0.1379628 , -0.26588449],
         [-0.46453714, -0.00576723, -0.02660675, ..., -0.10017379,
          -0.15603794, -0.32566148],
         [-0.33683276,  0.06601517, -0.08144748, ..., -0.13460518,
          -0.1342358 , -0.27096185]]]], dtype=float32) has invalid type <class 'numpy.ndarray'>, must be a string or Tensor. (Can not
 convert a ndarray into a Tensor or Operation.)

我也尝试过使用 local_init_op,但同样不起作用。我的代码:

import sys
import tensorflow as tf
slim = tf.contrib.slim
import argparse
import model as M
import decoder as D


# Parsed command-line flags. Assigned from ``parser.parse_known_args()`` in the
# ``__main__`` block below, before ``train`` is invoked.
FLAGS = None


def train(_):
    """Run this process's role in a distributed fine-tuning job.

    Depending on ``FLAGS.job_name`` the process either joins the cluster as a
    parameter server or builds the model and runs the training loop as a
    worker. Workers warm-start the ``vgg_19`` variables from a pretrained
    checkpoint through the ``Scaffold``'s ``init_fn``.

    Args:
        _: Unused. ``tf.app.run`` calls ``main`` with the remaining argv.
    """
    vgg_19_ckpt_path = '/media/data/projects/project_daphnis/pretrained_models/vgg_19.ckpt'
    train_log_dir = "/media/data/projects/project_daphnis/train_log_dir"

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":
        if not tf.gfile.Exists(train_log_dir):
            tf.gfile.MakeDirs(train_log_dir)

        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):

            # Set up the data loading:
            image, c, p, s = \
                D.get_training_dataset_data_provider()

            image, c, p, s = \
                tf.train.batch([image, c, p, s],
                               batch_size=16)

            # Define the model:
            predictions, loss, end_points = M.model_as_in_paper(
                image, c, p, s
            )

            # Restores only the pretrained VGG-19 variables; the two excluded
            # scopes are not loaded from the checkpoint.
            restore_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                vgg_19_ckpt_path,
                var_list=slim.get_variables_to_restore(include=["vgg_19"],
                                                       exclude=[
                                                           'vgg_19/conv4_3_X',
                                                           'vgg_19/conv4_4_X']
                                                       )
            )

            def scaffold_init_fn(scaffold, sess):
                """Adapt ``restore_fn(sess)`` to ``Scaffold``'s two-argument ``init_fn``.

                ``Scaffold`` calls ``init_fn(scaffold, sess)``, whereas the
                callback returned by ``assign_from_checkpoint_fn`` accepts only
                the session. Passing ``restore_fn`` directly raises
                ``TypeError: callback() takes 1 positional argument but 2
                were given``.
                """
                del scaffold  # Only the session is needed to restore weights.
                restore_fn(sess)

            # Specify the optimization scheme:
            optimizer = tf.train.AdamOptimizer(learning_rate=.00001)

            # create_train_op that ensures that when we evaluate it to get the loss,
            # the update_ops are done and the gradient updates are computed.
            train_op = slim.learning.create_train_op(loss, optimizer)
        tf.summary.scalar("losses/total_loss", loss)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(last_step=1000000)]

        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir=train_log_dir,
                hooks=hooks,
                scaffold=tf.train.Scaffold(
                    init_fn=scaffold_init_fn,
                    summary_op=tf.summary.merge_all())) as mon_sess:
            while not mon_sess.should_stop():
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details on how to
                # perform *synchronous* training.
                # mon_sess.run handles AbortedError in case of preempted PS.
                mon_sess.run(train_op)
        #
        # # Actually runs training.
        # slim.learning.train(train_tensor,
        #                     train_log_dir,
        #                     init_fn=restore_fn,
        #                     summary_op=tf.summary.merge_all(),
        #                     is_chief=False)

if __name__ == "__main__":
    # Script entry point: parse the cluster/task flags into the module-level
    # FLAGS, then hand control to tf.app.run with any unrecognized arguments.
    parser = argparse.ArgumentParser()
    # Registers a "bool" type converter; none of the flags below use it.
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=train, argv=[sys.argv[0]] + unparsed)

1 个答案:

答案 0 :(得分:0)

答案是使用 Saver 来恢复参数,并把 saver.restore 函数包装一下,使其可以用作 Scaffold 的 init_fn。这个包装函数必须接受两个参数:scaffold 和 sess,其中 sess 用于恢复参数,scaffold 则被丢弃。

完整代码:

scaffold