脾气暴躁的情况:
p
是形状为(5000,2)
的np数组,因此我可以将维度2中的两列用作形状为m
的另一个np数组(100,100,1)
的索引。这是示例代码:
m = np.zeros(shape=(100,100,1))
for i,j in zip(p[:,0], p[:,1]): # i and j are between 0 and 99
m[int(i),int(j),:] = 155
现在我想将其转换为张量流版本。令人困惑的是,还有一个维度是批量的,因此张量p
变成(B,5000,2)
,目标张量m
变成(B,100,100,1)
。当前代码是:
p = tf.cast(p, dtype=tf.int32)
m = tf.zeros([batch_size, 100, 100, 1])
m[:, p[:, :, 0], p[:, :, 1], :].assign(155) # here is wrong
但是很明显,直接使用p
作为索引时会发生错误。我相信应该有一些替代方法来实现这一目标,例如tf.gather_nd
或tf.stack_nd
,但我只是不知道细节。请帮忙。预先感谢!