好吧,所以我想在Tensorflow中复制此keep[count] = i
Pytorch代码,这是在tf.while_loop
和count
都是张量的i
内发生的。
keep = tf.Variable(tf.zeros(tf.size(scores), tf.int64))
count = 0
idx = tf.argsort(scores, axis=0)#scores.sort(0) # sort in ascending order
idx = idx[-top_k:]
...
def loop_body(idx, keep, count, ...)
i = idx[-1] # index of current largest val
keep = tf.scatter_update(keep, count, i)
tf.while_loop(loop_cond, loop_body, loop_vars)
这就是我要尝试的keep = tf.scatter_update(keep, count, i)
,不幸的是,这引发了AttributeError: 'Tensor' object has no attribute '_lazy_read