Tensorflow:如何通过索引获得张量值并分配新值

时间:2018-06-02 10:59:45

标签: tensorflow keras

我正在尝试对Tensorflow中的张量的前K值进行一些操作。基本上,我想要的是首先得到最高K值的指数,做一些操作并分配新值。例如:

A = tf.constant([[1,2,3,4,5],[6,7,8,9,10]])
values, indices = tf.nn.top_k(A, k=3)

对于这里,值将是数组([[5,4,3,[10,9,8]],dtype = int32)

在对值进行一些操作后,比如说prob = tf.nn.softmax(values),我应该如何根据索引将此值赋给A?这类似于numpy A [indices] = prob。似乎无法在tensorflow中找到正确的功能来执行此操作。

1 个答案:

答案 0 :(得分:0)

不幸的是,一旦你想使用张量的索引,Tensorflow就很痛苦,所以为了实现你的想法,你必须使用一些丑陋的解决方法。我的选择是:

import tensorflow as tf

#First you better use Variable as constant is not designed to be updated
A = tf.Variable(initial_value = [[1,2,3,4,5],[6,7,8,9,10]])

#Create a buffer variable which will store tentative updates,
#initialize it with random values
t = tf.Variable(initial_value = tf.cast(tf.random_normal(shape=[5]),dtype=tf.int32))
values, indices = tf.nn.top_k(A, k=3)

#Create a function for manipulation on the values you want
def val_manipulation(v):
  return 2*v+1

#Create a while loop to update each entry of the A one-be-one, 
#as scatter_nd_update can update only by slices, but not individual entries
i = tf.constant(0)
#Stop once updated every slice
c = lambda i,x: tf.less(i, tf.shape(A)[0])
#Each iteration update i and 
#update every slice of A (A[i]) with updated slice 
b = lambda i,x: [i+1,tf.scatter_nd_update(A,[[i]],[tf.scatter_update(tf.assign(t,A[i]),indices[i],val_manipulation(values[i]) )])]
#While loop
r = tf.while_loop(c, b, [i,A])

init = tf.initialize_all_variables()

with tf.Session() as s:
  s.run(init)
  #Test it!
  print s.run(A)
  s.run(r)
  print s.run(A)

所以你基本上做的是:

  1. scatter_update只能使用变量,因此我们从A(作为A [i])获取切片并将这些值存储到缓冲区变量t
  2. 使用所需的值更新缓冲区变量中的值
  3. 更新i - 更新A
  4. t切片
  5. 重复A
  6. 的其余条目

    最终你应该得到以下输出:

    [[ 1  2  3  4  5]  [ 6  7  8  9 10]] 
    [[ 1  2  7  9 11]  [ 6  7 17 19 21]]