我有一个包含多个模块的张量流图,我想重用其中一个以前的网络状态(在参数更新之前)来评估下一个状态的输入 (参数更新后)。
示例
考虑我希望在时间步长network_B
基本上复制t
的玩具示例,以便在下一个培训步骤t+1
中使用:
def network_A(x):
A1 = tf.matmul(x, A_W1) + A_b1
return tf.nn.relu(A1)
def network_B(x):
B1 = tf.matmul(x, B_W1) + B_b1
Z1 = tf.nn.relu(B1)
B2 = tf.matmul(Z1, B_W2) + B_b2
return B2
x = tf.placeholder(tf.float32, shape=[None, x_dim])
x_2 = network_A(x)
# Evaluate input x_2 with current state of network
y_hatB_current = network_B(x)
# Evaluate same input x_2 with past state of network
y_hatB_past = network_B_past(x) #
# Get some loss
loss = ...
然后,一旦评估了两者,将当前网络状态保存为新的过去状态,并仅优化当前状态:
# Save state of parameters
network_B_past = network_B # (How do I do this efficiently?)
# Optimize the current state
train = tf.train.AdamOptimizer().minimize(loss, var_list=current_vars)
详情
因此,在每个培训步骤中,应该有两个network_B
版本可用于评估输入:
network_B
在时间步长t-1
(过去的州)network_B
at timetep t
(当前状态)在两个训练步骤之间有一个参数更新,所以两者之间的权重应该略有不同,但它们应该相同。然后,在评估新输入之后,当前状态替换过去状态,并且更新网络发生另一个训练步骤。
我知道我可以在tensorflow中保存和重新加载检查点,但这对我的用例来说似乎效率太低,因为它需要在每个训练步骤中发生。实现此网络克隆步骤的有效方法是什么,以便我维护一个跨州的副本?
Tensorflow版本:1.5
答案 0 :(得分:1)
我会在不同的变量范围下使用函数create_graph
创建网络两次:一个用于当前,一个用于备份。请注意,这会使内存消耗增加一倍。
然后您需要的只是自定义sync_op
。 MWE是
import tensorflow as tf
def copy_vars(src_scope, dst_scope):
src_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope)
dst_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=dst_scope)
update_op = []
for src_var in src_vars:
for dst_var in dst_vars:
if src_var.name.replace('%s' % src_scope, '') == dst_var.name.replace('%s' % dst_scope, ''):
assert dst_var.shape == src_var.shape
print(" copy: add assign {} -> {}".format(src_var.name, dst_var.name))
update_op.append(dst_var.assign(src_var))
return tf.group(update_op)
def create_graph(name, x, use_c=False, uses_gradient_updates=True):
var_setter = lambda x: x # noqa
if uses_gradient_updates:
var_setter = lambda x: tf.stop_gradient(x) # noqa
with tf.variable_scope(name, custom_getter=var_setter):
a = tf.Variable([1], dtype=tf.float32)
b = tf.Variable([1], dtype=tf.float32)
result = x + a + b
if use_c:
# create dummy variable just to show both graphs do not need to be exactly the same
c = tf.Variable([1], dtype=tf.float32)
return result, a, b
x = tf.placeholder(tf.float32)
c1, a1, b1 = create_graph('original', x, use_c=True)
c2, a2, b2 = create_graph('backup', x, use_c=False)
sync_op = copy_vars('original', 'backup')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run([c1, c2], {x: 5})) # in sync
sess.run(a1.assign([3])) # update your graph either by tf.train.Adam or by:
print(sess.run([c1, c2], {x: 5})) # out of sync
sess.run(sync_op) # do syncing
print(sess.run([c1, c2], {x: 5})) # in sync
custom_getter
有助于防止渐变更新。