将张量切片与其他张量一起使用以分配值

时间:2019-05-17 00:09:41

标签: python tensorflow

我有一个张量nextq,它是某个问题集的概率分布。我对synthetic_answers中的每个问题都可以得到0或1的潜在答案。我想通过以下方式更新名为cur_qinput的张量:

  1. 为批次中的每个向量查找nextq中具有最大值的索引

  2. 如果该索引处的synthetic_answers为1,则将该索引处的cur_qinput的第三个特征设置为1,否则将第二个特征设置为

这是一些非功能性代码,在for循环中它是非功能性的,因为我不知道其他张量/赋值具有适当的切片张量,为了清楚起见,我只是尝试用python语法编写它意向。

#nextq shape =  batch_size x q_size
#nextq_index shape =  batch_size
nextq_index = tf.argmax(nextq,axis=1)


#synthetic_answers shape =  batch_size x q_size
#cur_qinput shape = batch_size x q_size x 3

#"iterate over batch", doesn't actually work and I guess needs to be done entirely differently
for k in tf.range(tf.shape(nextq_index)[0]):
    cur_qinput[k,nextq_index[k],1+synthetic_answers[k,nextq_index[k]]]=1

1 个答案:

答案 0 :(得分:1)

让我假设您的数据如下,因为问题中没有示例。

import tensorflow as tf

nextq = tf.constant([[1,5,4],[6,8,10]],dtype=tf.float32)
synthetic_answers = tf.constant([[0,1,1],[1,1,0]],dtype=tf.int32)
cur_qinput = tf.random_normal(shape=(tf.shape(nextq)[0],tf.shape(nextq)[1],3))

首先,您可以使用tf.one_hot构建mask来描述该索引处的synthetic_answers是否等于1

nextq_index = tf.argmax(nextq,axis=1)
# [1 2]
nextq_index_hot = tf.one_hot(nextq_index,depth=nextq.shape[1],dtype=tf.int32)
# [[0 1 0]
#  [0 0 1]]
mask = tf.logical_and(tf.equal(nextq_index_hot,synthetic_answers),tf.equal(nextq_index_hot,1))
# [[False  True False]
#  [False False False]]

然后将mask展开为与cur_qinput相同的形状。

mask = tf.one_hot(tf.cast(mask,dtype=tf.int32)+1,depth=3)
# [[[0. 1. 0.]
#   [0. 0. 1.]
#   [0. 1. 0.]]
#
#  [[0. 1. 0.]
#   [0. 1. 0.]
#   [0. 1. 0.]]]

最后,您可以tf.where1分配给cur_qinput

scatter = tf.where(tf.equal(mask,1),tf.ones_like(cur_qinput),cur_qinput)

with tf.Session() as sess:
    cur_qinput_val,scatter_val = sess.run([cur_qinput,scatter])
    print(cur_qinput_val)
    print(scatter_val)
[[[ 1.3651905  -0.96688586  0.74061954]
  [-1.1236337  -0.6730857  -0.8439895 ]
  [-0.52024084  1.1968751   0.79242617]]

 [[ 1.4969068  -0.12403865  0.06582119]
  [ 0.79385823 -0.7952771  -0.8562217 ]
  [-0.05428046  1.4613343   0.2726114 ]]]
[[[ 1.3651905   1.          0.74061954]
  [-1.1236337  -0.6730857   1.        ]
  [-0.52024084  1.          0.79242617]]

 [[ 1.4969068   1.          0.06582119]
  [ 0.79385823  1.         -0.8562217 ]
  [-0.05428046  1.          0.2726114 ]]]