使用具有动态分配值的形状的tf.scatter_nd(idx,update,shape)

时间:2017-10-08 04:22:48

标签: python tensorflow

我想在某个索引处更新向量中的值。我发现tf.scatter_nd可以做我想做的事。但我正在批处理操作,以便我有一个数组batch_size*5,其中每一行都是我更新的大小为5的向量。我的batch_size是在运行时确定的。

在使用tf.scatter_nd时,形状参数采用张量,张量是要产生的张量的形状。但是,如果在运行时确定第一个维度(例如,如果它是批量大小),那么我会收到错误:

TypeError: Input 'shape' of 'ScatterNd' Op has type int32 that does not match type int64 of argument 'indices'.

此错误实际上是由于形状变量的值为[None,5]。即产生一个大小为batch_size*5的张量,用大小batch_size的更新进行更新并使用大小为'batch_size'的标记。

如何在动态分配的空间上正确使用tf.scatter_nd()?

1 个答案:

答案 0 :(得分:1)

您可以使用动态形状,例如tf.shape()

查看代码:

import tensorflow as tf 
import numpy as np

inputs = tf.placeholder(tf.int32, shape=[None, 5])

new_inputs = tf.scatter_nd(indices=[[0], [2]], 
                           updates=[[1,1,1,1,1], [1,1,1,1,1]],
                           shape=tf.shape(inputs))

with tf.Session() as sess:
    _new_np = sess.run(new_inputs, feed_dict={inputs: np.zeros([4, 5])})
print(_new_np)

这是你想要的吗?