将张量中某个维度的值用作另一个张量中的二维索引

时间:2019-07-17 06:27:05

标签: tensorflow

脾气暴躁的情况:

p是形状为(5000,2)的np数组,因此我可以将维度2中的两列用作形状为m的另一个np数组(100,100,1)的索引。这是示例代码:

    m = np.zeros(shape=(100,100,1))
    for i,j in zip(p[:,0], p[:,1]):    # i and j are between 0 and 99
        m[int(i),int(j),:] = 155

现在我想将其转换为张量流版本。令人困惑的是,还有一个维度是批量的,因此张量p变成(B,5000,2),目标张量m变成(B,100,100,1)。当前代码是:

p = tf.cast(p, dtype=tf.int32)
m = tf.zeros([batch_size, 100, 100, 1])
m[:, p[:, :, 0], p[:, :, 1], :].assign(155)   # here is wrong

但是很明显,直接使用p作为索引时会发生错误。我相信应该有一些替代方法来实现这一目标,例如tf.gather_ndtf.stack_nd,但我只是不知道细节。请帮忙。预先感谢!

0 个答案:

没有答案