如何在不将值保存到磁盘的情况下将Tensors恢复为过去的值?

时间:2017-09-24 19:26:10

标签: python machine-learning tensorflow

我正在使用TensorFlow进行一些实验,并遇到了麻烦。我试图使用TF来评估模型中的变化,然后根据损失函数的结果变化保留或恢复模型。我已经找到了困难的部分(条件控制),但我仍然坚持应该相当简单的事情:我似乎无法存储tf.trainable_variables进行迭代,然后在需要时恢复它。

让我们说建立一个操作:

...
store_trainable_vars = []

for v in tf.trainable_variables():

    store_trainable_vars.append(v)
...

然后,我想将tf.trainable_variables恢复为上次运行此操作时的值。我想做类似的事情:

 def reject_move():

    revert_state = []

    for (v, s) in zip(tf.trainable_variables(), store_trainable_vars):

        revert_state.append(tf.assign(v, s, name="revert_state"))

    return(revert_state)

显然,这会重新评估store_trainable_vars,而tf.trainable_variables()又会链接到revert_state的当前值,从而避免使用... store_trainable_vars = [] for v in tf.trainable_variables(): store_trainable_vars.append(v.value_right_now()) ... Op。我需要一些方法来存储和检索Tensors的值,而不回调那些Tensors的现值。像

这样的东西
v.value_right_now()

其中window.intercomSettings = { app_id: "xxxxxx", "utm_source" : {{UTM Source}}, "utm_medium" : {{UTM Medium}}, "utm_campaign" : {{UTM Campaign}} }; 返回一个不会改变的常量,直到被覆盖。

我知道我可以使用Saver,但该解决方案会写入磁盘,这对于此应用程序是不可接受的,因为它将在训练循环内运行。

我可能遗漏了一些明显的东西 - 任何指导都会受到赞赏。

2 个答案:

答案 0 :(得分:5)

要手动恢复图形状态,您需要使用tf.tupletf.group操作,这将修改批量更改的流程:

  

这会创建一个具有与张量相同值的张量元组   参数,除了每个张量的值只在之后返回   已经计算了所有张量的值。

[更新] 以下是我的方法:

import numpy as np
import tensorflow as tf

x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x')
W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W')
b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b')
y = tf.add(tf.matmul(x, W), b)

with tf.Session() as session:
  batch = np.ones([2, 5])
  session.run(tf.global_variables_initializer())
  print session.run(y, feed_dict={x: batch})      # prints [2, 5] zeros

  # store the current value
  store = {v.name: v.eval(session) for v in tf.trainable_variables()}
  print store                                     # prints [5, 5] and [5] zeros

  # update
  new = {'W:0': np.ones([5, 5]), 'b:0': np.ones([5])}
  session.run(tf.tuple([tf.assign(var, new[var.name]) for var in tf.trainable_variables()]))
  print session.run(y, feed_dict={x: batch})      # prints [2, 5] sixes

  # restore
  session.run(tf.tuple([tf.assign(var, store[var.name]) for var in tf.trainable_variables()]))
  print session.run(y, feed_dict={x: batch})      # prints [2, 5] zeros again

但我真的认为你应该重新考虑你对Saver的决定,因为它也被设计用于训练循环中。在内部,Saver为您完成所有棘手的工作(特别是,如果需要,它会恢复操作tf.grouptf.control_dependencies),否则可能会成为非常讨厌的错误的来源。此外,磁盘(几乎)总是比GPU和主内存大,所以如果你能负担得起将模型存储在内存中,你也应该能够存储在磁盘上。

以下some parameters有助于控制磁盘上检查点文件的扩散:

  • max_to_keep表示最近的检查点文件的最大数量 保持。创建新文件时,将删除旧文件。如果为None或0,则保留所有检查点文件。默认为5(即最近的5个 保留检查点文件。
  • keep_checkpoint_every_n_hours:除了保持最新状态 max_to_keep检查点文件,您可能希望保留一个检查点文件 每N小时的训练。如果您想稍后这可能很有用 分析模型在长时间培训期间的进展情况。对于 例如,传递keep_checkpoint_every_n_hours=2可确保每2小时训练一次保留一个检查点文件。默认值10,000小时有效地禁用了该功能。

[更新] 正如评论中所阐明的那样,主要关注的是磁盘延迟,如果过于频繁访问,可能会降低培训速度。如果您使用的是Linux,它caches经常使用磁盘页面,Windows does it。但如果您想绝对确定,请考虑使用tmpfs

答案 1 :(得分:1)

我自己原本不打算回答这个问题,但我想出了一种效果相当好的方法。所以,我以为我会分享它。关键见解来自this非常聪明的答案。方法是重用为初始变量赋值创建的赋值节点。下面给出了实现该方法的完整类。

import tensorflow as tf


class TensorFlowState(object):

    def __init__(self):

        # Get the graph.
        graph = tf.get_default_graph()

        # Extract the global varibles from the graph.
        self.gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        # Exract the Assign operations for later use.
        self.assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign")
                           for v in self.gvars]

        # Extract the initial value ops from each Assign op for later use.
        self.init_values = [op.inputs[1] for op in self.assign_ops]

    def start(self, sess):

        self.sess = sess

    def store(self):

        # Record the current state of the TF global varaibles
        self.state = self.sess.run(self.gvars)

    def restore(self):
    # Create a dictionary of the iniailizers and stored state of globals.
    feed_dict = {init_value: val
                 for init_value, val in zip(self.init_values, self.state)}

    # Use the initializer ops for each variable to load the stored values.
    return(self.sess.run(self.assign_ops, feed_dict=feed_dict))

要使用,只需实例化该类,调用start方法传递tf.Session,并在必要的训练循环中根据需要调用storerestore方法。我已经使用这个实现来构建一个优化器,它的运行速度与TensorFlow中包含的梯度下降优化器一样快。