有效地在Tensorflow中馈送变量

时间:2019-02-27 21:50:31

标签: tensorflow

我需要给Tensorflow中的一组变量提供一个字典,并实际更改它们的值。我尝试过:sess.run([],feed_dict = feed_dict),运行结果是使用feed_dict执行的,但变量值未更新。为了更新值,我还尝试对每个变量进行“加载”。可以,但是效率极低。加载约20个20M数据变量需要10秒钟。我需要在1秒以内。 (进纸仅需要100毫秒)是否有更有效的方法?

下面是一个示例:

import tensorflow as tf
import numpy as np

variables = []
for i in range(100):
    variables.append(tf.Variable(np.random.rand(100, 100)))

sess = tf.Session()

feed_dict = dict()
for var in variables:
    feed_dict[var] = np.ones((100, 100))

sess.run(tf.initialize_all_variables())



# this is fast but do not actaully load variables (~ 14ms)
sess.run(variables, feed_dict=feed_dict)

# this loads, but is extremely slow (~8s)
data = np.ones((100, 100))
for var in variables:
    var.load(data, sess)

1 个答案:

答案 0 :(得分:1)

命令

sess.run(variables, feed_dict=feed_dict)

不会更新变量的值,它只是运行张量,就好像变量是从feed_dict馈入值的占位符一样。如果要更改变量的值,可以使用tf.assign:

data = np.ones((100, 100))
assg = [tf.assign(var, data) for var in variables]
sess.run(assg)
print(sess.run(variables)) # arrays of 1s