TensorFlow中的索引操作

时间:2016-03-06 23:32:56

标签: indexing tensorflow

当我为某些数据进行批量标记时,我有一个用于记录所有计算结果的变量:

p_all = tf.Variable(tf.zeros([batch_num, batch_size]), name = "probability");

在计算中,我有一个循环来处理每个批次:

for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size)}
    sess.run(p_each_batch, feed_dict=feed)

如何将p_each_bach的值复制到p_all

为了更清楚,我想要像:

... ...
p_all[batch_index,:] = p_each_batch
for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
    sess.run(p_all, feed_dict=feed)

如何让这些代码真正起作用?

1 个答案:

答案 0 :(得分:1)

由于p_alltf.Variable,您可以使用tf.scatter_update()操作来更新每个批次中的各个行:

# Equivalent to `p_all[batch_index, :] = p_each_batch`
update_op = tf.scatter_update(p_all,
                              tf.expand_dims(batch_index, 0),
                              tf.expand_dims(p_each_batch, 0)) 

for i in range(batch_num):
    feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
    sess.run(update_op, feed_dict=feed)