Tensorflow:访问包含数组的变量的索引

时间:2016-10-20 18:32:19

标签: tensorflow

我需要将一些值保存到张量流数组中的特定位置:

import tensorflow as tf
import numpy as np

AVG = tf.Variable([0, 0, 0, 0, 0], name='data')

for i in range(5): 
   data = np.random.randint(1000, size=10000)
   AVG += np.average(data)     

我需要在AVG变量的不同位置对每次迭代进行平均。这可行吗?

1 个答案:

答案 0 :(得分:1)

您可以使用tf.scatter_add。这是一个完整的工作计划:

import tensorflow as tf
import numpy as np

AVG = tf.Variable([0, 0, 0, 0, 0], name='data')

for i in range(5):
  data = np.random.randint(1000, size=10000)
  AVG = tf.scatter_add(AVG, [i], [np.average(data).astype('int')])

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
print(AVG.eval())