我需要给Tensorflow中的一组变量提供一个字典,并实际更改它们的值。我尝试过:sess.run([],feed_dict = feed_dict),运行结果是使用feed_dict执行的,但变量值未更新。为了更新值,我还尝试对每个变量进行“加载”。可以,但是效率极低。加载约20个20M数据变量需要10秒钟。我需要在1秒以内。 (进纸仅需要100毫秒)是否有更有效的方法?
下面是一个示例:
import tensorflow as tf
import numpy as np
variables = []
for i in range(100):
variables.append(tf.Variable(np.random.rand(100, 100)))
sess = tf.Session()
feed_dict = dict()
for var in variables:
feed_dict[var] = np.ones((100, 100))
sess.run(tf.initialize_all_variables())
# this is fast but do not actaully load variables (~ 14ms)
sess.run(variables, feed_dict=feed_dict)
# this loads, but is extremely slow (~8s)
data = np.ones((100, 100))
for var in variables:
var.load(data, sess)
答案 0 :(得分:1)
命令
sess.run(variables, feed_dict=feed_dict)
不会更新变量的值,它只是运行张量,就好像变量是从feed_dict馈入值的占位符一样。如果要更改变量的值,可以使用tf.assign:
data = np.ones((100, 100))
assg = [tf.assign(var, data) for var in variables]
sess.run(assg)
print(sess.run(variables)) # arrays of 1s