如何从张量流中的不同对象类中采样n个像素?

时间:2018-04-13 16:47:10

标签: matrix tensorflow image-segmentation sampling

问题

我想从图像中的每个实例类中随机抽取n个像素。

让我们说我的图片是I,宽度为w,高度为h。我还有一个带有标签L的图像,用于描述与I形状相同的实例类。

当前方法

我目前的想法是首先将标签重塑为一个形状为(N_p, 1)的大型矢量。然后我重复N_c次以形成(N_p, N_c)形状。现在,我重复一个向量l,该向量由形状为(1, N_c)的所有唯一标签组成,以形成(N_p, N_c)。等于这两个得到一个矩阵,其中列y和行x中的一个,其中与行x对应的像素属于与列y对应的类。

下一步是将具有增加索引位置的矩阵与前一矩阵连接起来。现在,我可以在行中随机混洗该矩阵。

唯一缺少的步骤是提取该矩阵的n*N_c行,这些行首先为每个类提供一行。然后使用矩阵右边的索引,我可以使用

tf.gather_nd

从原始图像I中获取像素。

问题

  1. 如何在tensorflow中实现缺失的操作?即:获取k * n行,使得它们包含矩阵左侧部分中每一列的前n行中的每一行。

  2. 这些操作是否有效?

  3. 是否有一些更简单的方法?

1 个答案:

答案 0 :(得分:0)

解决方案

对于任何感兴趣的人,这里是我的问题的解决方案与相应的tensorflow代码。我在正确的轨道上,缺少的功能是

tf.nn.top_k

以下是一些示例代码,用于从每个图像的实例类中对k个像素进行采样。

import tensorflow as tf

seed = 42

width = 10
height = 6
embedding_dim = 3

sample_size = 2

image = tf.random_normal([height, width, embedding_dim], mean=0, stddev=4, seed=seed)
labels = tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                      [0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
                      [0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.uint8)


labels = tf.cast(labels, tf.int32)

# First reshape to one vector
image_v = tf.reshape(image, [-1, embedding_dim])
labels_v = tf.reshape(labels, [-1])

# Get classes
classes, indices = tf.unique(labels_v)

# Dimensions
N_c = tf.shape(classes)[0]
N_p = tf.shape(labels_v)[0]

# Helper matrices
I = tf.tile(tf.expand_dims(indices, [-1]), [1, N_c])
C = tf.tile(tf.transpose(tf.expand_dims(tf.range(N_c), [-1])), [N_p, 1])
E = tf.cast(tf.equal(I, C), tf.int32)
P = tf.expand_dims(tf.range(N_p) + 1, [-1])
R = tf.concat([E, P], axis=1)
R_rand = tf.random_shuffle(R, seed = seed)
E_rand, P_rand = tf.split(R_rand, [N_c, 1], axis = 1)
M = tf.transpose(E_rand)
_, topInidices = tf.nn.top_k(M, k = sample_size)
topInidicesFlat = tf.expand_dims(tf.reshape(topInidices, [-1]), [-1])
sampleIndices = tf.gather_nd(P_rand, topInidicesFlat)
samples = tf.gather_nd(image_v, sampleIndices)

sess = tf.Session()
list = [image,
        labels,
        image_v,
        labels_v,
        classes,
        indices,
        N_c,
        N_p,
        I,
        C,
        E,
        P,
        R,
        R_rand,
        E_rand,
        P_rand,
        M,
        topInidices,
        topInidicesFlat,
        sampleIndices,
        samples
        ]
list_ = sess.run(list)
print(list_)