在Tensorflow中嵌套while循环并进行分散更新

时间:2019-02-26 19:24:43

标签: tensorflow

变量v1=[[0,0],[0,0]] 张量t1=[[-1,0],[1,1]]

我要输出op=[[1,0],[0,2]]

逻辑: 如果t1==-1,则忽略。否则,请使用t1值作为v1的索引,并向该v1值加1。

Python等效项:

for row in range(len(t1)):
    for col in range(len(t1[row])):
        t1_val=t1[row][col];
        if t1_val!=-1:
            v1[row][t1_val]+=1

我浏览了while循环和分散更新方面的许多问题,但不知道如何解决上述问题。

谢谢

1 个答案:

答案 0 :(得分:1)

您可以尝试tf.map_fn

import tensorflow as tf

v1 = tf.Variable([[0,0],[0,0]], dtype=tf.int32)
t1 = tf.constant([[-1,0],[1,1]], dtype=tf.int32)

result = tf.map_fn(lambda x: x[0]+tf.math.bincount(tf.gather_nd(x[1], tf.where(tf.not_equal(x[1],-1))),minlength=x[0].shape[0])
                   , [v1,t1]
                   , dtype=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))

# print
[[1 0]
 [0 2]]