如何在参数更新后复制内存中的张量流网络状态以进行检索?

时间:2018-02-01 05:32:47

标签: python tensorflow

我有一个包含多个模块的张量流图,我想重用其中一个以前的网络状态(在参数更新之前)来评估下一个状态的输入 (参数更新后)。

示例

考虑我希望在时间步长network_B基本上复制t的玩具示例,以便在下一个培训步骤t+1中使用:

def network_A(x):
    A1 = tf.matmul(x, A_W1) + A_b1
    return tf.nn.relu(A1)

def network_B(x):
    B1 = tf.matmul(x, B_W1) + B_b1
    Z1 = tf.nn.relu(B1)
    B2 = tf.matmul(Z1, B_W2) + B_b2
    return B2

x = tf.placeholder(tf.float32, shape=[None, x_dim])
x_2 = network_A(x)

# Evaluate input x_2 with current state of network
y_hatB_current = network_B(x)

# Evaluate same input x_2 with past state of network
y_hatB_past = network_B_past(x) # 

# Get some loss
loss = ...

然后,一旦评估了两者,将当前网络状态保存为新的过去状态,并仅优化当前状态:

# Save state of parameters
network_B_past = network_B  # (How do I do this efficiently?)

# Optimize the current state
train = tf.train.AdamOptimizer().minimize(loss, var_list=current_vars)

详情

因此,在每个培训步骤中,应该有两个network_B版本可用于评估输入:

  1. network_B在时间步长t-1(过去的州)
  2. network_B at timetep t(当前状态)
  3. 在两个训练步骤之间有一个参数更新,所以两者之间的权重应该略有不同,但它们应该相同。然后,在评估新输入之后,当前状态替换过去状态,并且更新网络发生另一个训练步骤。

    我知道我可以在tensorflow中保存和重新加载检查点,但这对我的用例来说似乎效率太低,因为它需要在每个训练步骤中发生。实现此网络克隆步骤的有效方法是什么,以便我维护一个跨州的副本?

    Tensorflow版本:1.5

1 个答案:

答案 0 :(得分:1)

我会在不同的变量范围下使用函数create_graph创建网络两次:一个用于当前,一个用于备份。请注意,这会使内存消耗增加一倍。

然后您需要的只是自定义sync_op。 MWE是

import tensorflow as tf


def copy_vars(src_scope, dst_scope):
    src_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope)
    dst_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=dst_scope)

    update_op = []

    for src_var in src_vars:
        for dst_var in dst_vars:
            if src_var.name.replace('%s' % src_scope, '') == dst_var.name.replace('%s' % dst_scope, ''):
                assert dst_var.shape == src_var.shape
                print("  copy: add assign {} -> {}".format(src_var.name, dst_var.name))
                update_op.append(dst_var.assign(src_var))
    return tf.group(update_op)


def create_graph(name, x, use_c=False, uses_gradient_updates=True):

    var_setter = lambda x: x  # noqa
    if uses_gradient_updates:
        var_setter = lambda x: tf.stop_gradient(x)  # noqa

    with tf.variable_scope(name, custom_getter=var_setter):
        a = tf.Variable([1], dtype=tf.float32)
        b = tf.Variable([1], dtype=tf.float32)
        result = x + a + b
        if use_c:
            # create dummy variable just to show both graphs do not need to be exactly the same
            c = tf.Variable([1], dtype=tf.float32)
        return result, a, b

x = tf.placeholder(tf.float32)

c1, a1, b1 = create_graph('original', x, use_c=True)
c2, a2, b2 = create_graph('backup', x, use_c=False)

sync_op = copy_vars('original', 'backup')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run([c1, c2], {x: 5}))  # in sync

    sess.run(a1.assign([3]))  # update your graph either by tf.train.Adam or by:
    print(sess.run([c1, c2], {x: 5}))  # out of sync

    sess.run(sync_op)  # do syncing
    print(sess.run([c1, c2], {x: 5}))  # in sync

custom_getter有助于防止渐变更新。