我有一个TensorFlow模型,我希望能够选择检索和更新权重(在检查点机制之外)。由于我可能会多次这样做,所以每当我这样做时,我都不想在图表中添加节点,而是让我可以调用一些操作来完成它。我的想法是让tf.assign
ops对更新的变量和值使用相同的变量,并在feed_dict
中提供新的权重值;所以像这样:
weights_assigns = [tf.assign(v, v) for v in tf.trainable_variables()]
weights_update = tf.group(*weights_assigns)
# Now to update the weights
weights = [...] # List of new weight values
feed_dict = {v: w for v, w in zip(tf.trainable_variables(), weights)}
tf.run(weights_update_op, feed_dict=feed_dict)
在我看来,这应该将feed_dict
中传递的值作为变量的当前值,然后通过tf.assign
操作存储它们。但是,这不起作用,并且给出了一些关于意外值类型的奇怪错误。
我目前的替代方案是改为使用一些辅助节点(变量或占位符),并在赋值操作中用作值:
weights_updates = [tf.placeholder(v.dtype, v.get_shape()) for v in tf.trainable_variables()]
weights_assigns = [tf.assign(v, u) for v, u in zip(tf.trainable_variables(), weights_updates)]
weights_update_op = tf.group(*weights_assigns)
# Now to update the weights
weights = [...] # List of new weight values
feed_dict = {u: w for u, w in zip(weights_updates, weights)}
tf.run(weights_update_op, feed_dict=feed_dict)
这真的是唯一的方法吗?或者还有其他一些我没有看到的明显方式吗?
答案 0 :(得分:1)
好的,发现有一个load()方法:
with tf.Graph().as_default():
w = tf.get_variable('weights', shape=[3, 3],
initializer=tf.random_uniform_initializer(dtype=tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print('initial W:')
print(sess.run(w))
new_vals = np.reshape(np.arange(9, dtype=np.float32), (3,3))
w.load(new_vals)
print('updated W:')
print(sess.run(w))
也许这有帮助?