我正在尝试对Tensorflow中的张量的前K值进行一些操作。基本上,我想要的是首先得到最高K值的指数,做一些操作并分配新值。例如:
A = tf.constant([[1,2,3,4,5],[6,7,8,9,10]])
values, indices = tf.nn.top_k(A, k=3)
对于这里,值将是数组([[5,4,3,[10,9,8]],dtype = int32)
在对值进行一些操作后,比如说prob = tf.nn.softmax(values),我应该如何根据索引将此值赋给A?这类似于numpy A [indices] = prob。似乎无法在tensorflow中找到正确的功能来执行此操作。
答案 0 :(得分:0)
不幸的是,一旦你想使用张量的索引,Tensorflow就很痛苦,所以为了实现你的想法,你必须使用一些丑陋的解决方法。我的选择是:
import tensorflow as tf
#First you better use Variable as constant is not designed to be updated
A = tf.Variable(initial_value = [[1,2,3,4,5],[6,7,8,9,10]])
#Create a buffer variable which will store tentative updates,
#initialize it with random values
t = tf.Variable(initial_value = tf.cast(tf.random_normal(shape=[5]),dtype=tf.int32))
values, indices = tf.nn.top_k(A, k=3)
#Create a function for manipulation on the values you want
def val_manipulation(v):
return 2*v+1
#Create a while loop to update each entry of the A one-be-one,
#as scatter_nd_update can update only by slices, but not individual entries
i = tf.constant(0)
#Stop once updated every slice
c = lambda i,x: tf.less(i, tf.shape(A)[0])
#Each iteration update i and
#update every slice of A (A[i]) with updated slice
b = lambda i,x: [i+1,tf.scatter_nd_update(A,[[i]],[tf.scatter_update(tf.assign(t,A[i]),indices[i],val_manipulation(values[i]) )])]
#While loop
r = tf.while_loop(c, b, [i,A])
init = tf.initialize_all_variables()
with tf.Session() as s:
s.run(init)
#Test it!
print s.run(A)
s.run(r)
print s.run(A)
所以你基本上做的是:
scatter_update
只能使用变量,因此我们从A(作为A [i])获取切片并将这些值存储到缓冲区变量t
i
- 更新A
t
切片
A
最终你应该得到以下输出:
[[ 1 2 3 4 5] [ 6 7 8 9 10]]
[[ 1 2 7 9 11] [ 6 7 17 19 21]]