TensorFlow slows down during evaluation

Time: 2019-05-25 12:16:12

Tags: python tensorflow

I am training my model, and every few steps I evaluate it by loading another model and assigning that model's weights to my model. Here is my code:

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

... training my model ....
if step % 1000 == 0:  # start evaluation
    evaluation(args, env, agents, sess, saver)

def evaluation(args, env, agents, sess, saver):
    old_graph = tf.get_default_graph()  # Save the old graph for later iteration

    avg_attack, attacks = 0., []
    # 1. Load Opponent model from other model
    opp_prefix = ['agent_%s/p_func' % i for i in list(range(env.n))[env.n-args.opp_agent_num:]]
    opp_ckpt_path = tf.train.latest_checkpoint(args.opponent_ckpt_dir)
    opp_weights = load_weights(opp_ckpt_path, opp_prefix)
    # 2. Set the weights from other model to the training model
    assign_ops = [tf.assign(tf.get_default_graph().get_tensor_by_name(_name), _value)
                  for _name, _value in opp_weights.items()]
    sess.run(assign_ops)
    # 3. Evaluation
    step, episode = 0, 0
    obs_n = env.reset()
    for some steps:
        # Get action
        do something here.... 

    # 4. Restore the training model from the latest saved ckpt data,
    # because some parts of the training model have been changed.
    # ckpt data is saved periodically during training.
    new_graph = tf.Graph()  # Create an empty graph
    new_graph.as_default()  # Makes the new graph default
    ckpt_path = tf.train.latest_checkpoint(args.save_dir)
    U.load_state(ckpt_path, saver)
    # Clear data
    del opp_weights
    return avg_attack / float(args.battle_episodes)


def load_weights(ckpt_path, prefix_list):
    """
    Load weights from a ckpt file and return them as NumPy arrays.
    Pass prefix_list, i.e. the variable-scope prefixes of the weights to fetch;
        otherwise an empty dict is returned.
    """
    vars_weights = {}
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        for _pref in prefix_list:
            if key.startswith(_pref):
                vars_weights[key+':0'] = reader.get_tensor(key)
    return vars_weights
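The comment on `new_graph.as_default()` in step 4 says it makes the new graph the default. Below is a minimal sketch for checking that claim, assuming TF1-style graph mode (written against `tensorflow.compat.v1` so it also runs on TF 2.x). In TF1, `Graph.as_default()` returns a context manager, so it only takes effect inside a `with` block. `old_graph` and `new_graph` mirror the names in the question; nothing else here comes from the original code.

```python
# Sketch: does calling Graph.as_default() without `with` change the default graph?
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode, as in the question

old_graph = tf.get_default_graph()
new_graph = tf.Graph()

new_graph.as_default()  # only builds a context manager; it is never entered
print(tf.get_default_graph() is new_graph)  # False: the default graph is unchanged

with new_graph.as_default():
    # Only inside the `with` block is new_graph the default graph.
    print(tf.get_default_graph() is new_graph)  # True

print(tf.get_default_graph() is old_graph)  # True: the old default is back
```

In other words, `tf.Graph()` followed by a bare `as_default()` call does not give later code a fresh graph; anything built afterwards still lands in the original default graph.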

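However, the time cost keeps increasing as training goes on. The problem never occurs if I remove the evaluation part. My guess is that assign_ops keeps growing as training proceeds, and that adding operations to the TensorFlow graph increases the time cost.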
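One way to test this suspicion is to count the graph's operations across repeated calls. The sketch below assumes TF1-style graph mode (via `tensorflow.compat.v1`) and uses a single hypothetical variable `w` in place of the opponent's `p_func` weights. It compares two hypothetical helpers: `assign_with_new_ops`, which builds a new `tf.assign` op on every call the way `evaluation()` does, and `assign_with_fixed_ops`, which builds one placeholder-fed assign op up front (a common TF1 pattern, not something taken from the question). `Graph.finalize()` is used as a tripwire: after it, any code path that adds ops raises `RuntimeError`.

```python
# Sketch: does building tf.assign ops on every call grow the graph?
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode, as in the question

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for one of the opponent's p_func weights.
    w = tf.get_variable("w", shape=[4, 4], dtype=tf.float32)

    # Built once, up front: a placeholder plus a single assign op.
    w_value = tf.placeholder(tf.float32, shape=[4, 4])
    assign_once = tf.assign(w, w_value)

    sess = tf.Session(graph=graph)
    sess.run(tf.global_variables_initializer())

    def assign_with_new_ops(value):
        # Mirrors evaluation(): a fresh constant + assign op is added per call.
        sess.run(tf.assign(w, value))

    def assign_with_fixed_ops(value):
        # Reuses the existing op; only the fed value changes.
        sess.run(assign_once, feed_dict={w_value: value})

    counts_new = []
    for _ in range(3):
        assign_with_new_ops(np.ones([4, 4], np.float32))
        counts_new.append(len(graph.get_operations()))

    counts_fixed = []
    for _ in range(3):
        assign_with_fixed_ops(np.zeros([4, 4], np.float32))
        counts_fixed.append(len(graph.get_operations()))

    # Tripwire: after finalize(), adding any op raises RuntimeError.
    graph.finalize()
    assign_with_fixed_ops(np.ones([4, 4], np.float32))  # still fine: no new ops
    try:
        assign_with_new_ops(np.ones([4, 4], np.float32))
        finalize_caught_growth = False
    except RuntimeError:
        finalize_caught_growth = True

sess.close()
print(counts_new)    # increases by the same amount on every call
print(counts_fixed)  # stays the same on every call
print(finalize_caught_growth)  # True
```

Calling `graph.finalize()` right after the training graph is built is a cheap way to locate any remaining place that adds ops during training or evaluation, since each such place will raise instead of silently growing the graph.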

How can I fix this?

0 answers