我有一个张量nextq
,它是某个问题集的概率分布。我对synthetic_answers
中的每个问题都可以得到0或1的潜在答案。我想通过以下方式更新名为cur_qinput
的张量:
为批次中的每个向量查找nextq
中具有最大值的索引
如果该索引处的synthetic_answers
为1,则将该索引处的cur_qinput
的第三个特征设置为1,否则将第二个特征设置为
这是一些非功能性代码,在for循环中它是非功能性的,因为我不知道其他张量/赋值具有适当的切片张量,为了清楚起见,我只是尝试用python语法编写它意向。
#nextq shape = batch_size x q_size
#nextq_index shape = batch_size
nextq_index = tf.argmax(nextq,axis=1)
#synthetic_answers shape = batch_size x q_size
#cur_qinput shape = batch_size x q_size x 3
#"iterate over batch", doesn't actually work and I guess needs to be done entirely differently
for k in tf.range(tf.shape(nextq_index)[0]):
cur_qinput[k,nextq_index[k],1+synthetic_answers[k,nextq_index[k]]]=1
答案 0 :(得分:1)
让我假设您的数据如下,因为问题中没有示例。
import tensorflow as tf
nextq = tf.constant([[1,5,4],[6,8,10]],dtype=tf.float32)
synthetic_answers = tf.constant([[0,1,1],[1,1,0]],dtype=tf.int32)
cur_qinput = tf.random_normal(shape=(tf.shape(nextq)[0],tf.shape(nextq)[1],3))
首先,您可以使用tf.one_hot
构建mask
来描述该索引处的synthetic_answers
是否等于1
。
nextq_index = tf.argmax(nextq,axis=1)
# [1 2]
nextq_index_hot = tf.one_hot(nextq_index,depth=nextq.shape[1],dtype=tf.int32)
# [[0 1 0]
# [0 0 1]]
mask = tf.logical_and(tf.equal(nextq_index_hot,synthetic_answers),tf.equal(nextq_index_hot,1))
# [[False True False]
# [False False False]]
然后将mask
展开为与cur_qinput
相同的形状。
mask = tf.one_hot(tf.cast(mask,dtype=tf.int32)+1,depth=3)
# [[[0. 1. 0.]
# [0. 0. 1.]
# [0. 1. 0.]]
#
# [[0. 1. 0.]
# [0. 1. 0.]
# [0. 1. 0.]]]
最后,您可以tf.where
将1
分配给cur_qinput
。
scatter = tf.where(tf.equal(mask,1),tf.ones_like(cur_qinput),cur_qinput)
with tf.Session() as sess:
cur_qinput_val,scatter_val = sess.run([cur_qinput,scatter])
print(cur_qinput_val)
print(scatter_val)
[[[ 1.3651905 -0.96688586 0.74061954]
[-1.1236337 -0.6730857 -0.8439895 ]
[-0.52024084 1.1968751 0.79242617]]
[[ 1.4969068 -0.12403865 0.06582119]
[ 0.79385823 -0.7952771 -0.8562217 ]
[-0.05428046 1.4613343 0.2726114 ]]]
[[[ 1.3651905 1. 0.74061954]
[-1.1236337 -0.6730857 1. ]
[-0.52024084 1. 0.79242617]]
[[ 1.4969068 1. 0.06582119]
[ 0.79385823 1. -0.8562217 ]
[-0.05428046 1. 0.2726114 ]]]