Tensorflow:根据偶数和奇数指数合并两个二维张量

时间:2017-07-06 15:15:21

标签: python tensorflow

我想检查批次的偶数和奇数元素,并在需要时交换它们。我设法得到两个我想要交织的张量:

def tf_oplu(x, name=None):   


    even = x[:,::2] #slicing into odd and even parts on the batch
    odd = x[:,1::2]

    even_flatten = tf.reshape(even, [-1]) # flatten tensors 
    #in row-major order to apply function across them

    odd_flatten = tf.reshape(odd, [-1])

    compare = tf.to_float(even_flatten<odd_flatten)
    compare_not = tf.to_float(even_flatten>=odd_flatten)

    #def oplu(x,y): # trivial function  
    #    if x<y : # (x<y)==1
    #       return y, x 
    #    else:
    #       return x, y # (x<y)==0

    even_flatten_new = odd_flatten * compare + even_flatten * compare_not
    odd_flatten_new = odd_flatten * compare_not + even_flatten * compare

    # convolute back 

    even_new = tf.reshape(even_flatten_new,[100,128])
    odd_new = tf.reshape(odd_flatten_new,[100,128])

现在我想收回$ [100,256] $ tensor,偶数和奇数位置已经填满。在numpy我当然会这样做:

y = np.empty((even_new.size + odd_newsize,), dtype=even_new.dtype)
y[:,0::2] = even_new
y[:,1::2] = odd_new

return y   

但是这种事情对于数十年来是不可能的,因为张量是不可修改的。我认为可以使用sparse tensortf.gather_nd,但两者都需要生成索引数组,这对我来说同样是非常重要的任务。 还有一点需要注意:我不想通过tf.py_func使用任何python函数,因为我检查过它们只在CPU上运行。也许lambda和tf.map_fn可能会有所帮助?谢谢!

3 个答案:

答案 0 :(得分:5)

要垂直交错两个矩阵,您不需要大型枪支,例如gathermap_fn。您可以按如下方式简单地交错:

tf.reshape(
  tf.stack([even_new, odd_new], axis=1),
  [-1, tf.shape(even_new)[1]])

修改

水平交错:

tf.reshape(
  tf.concat([even_new[...,tf.newaxis], odd_new[...,tf.newaxis]], axis=-1), 
  [tf.shape(even_new)[0],-1])

这个想法是使用堆栈在内存中交错它们。堆栈发生的维度给出了交错的粒度。如果我们堆叠在axis=0,那么交错发生在每个元素,混合列。如果我们在axis=1处堆叠,则整个输入行保持连续,在行之间发生交错。

答案 1 :(得分:2)

您可以使用tf.dynamic_stitch,该参数将要插入的每个张量的索引张量列表作为第一个参数,而将要插入的张量列表作为第二参数。张量将沿第一维交错,因此我们需要先将它们转置然后再转回。这是代码:

even_new = tf.transpose(even_new,perm=[1,0])
odd_new = tf.transpose(odd_new,perm=[1,0])
even_pos = tf.convert_to_tensor(list(range(0,256,2)),dtype=tf.int32)
odd_pos = tf.convert_to_tensor(list(range(1,256,2)),dtype=tf.int32)
interleaved = tf.dynamic_stitch([even_pos,odd_pos],[even_new,odd_new])
interleaved = tf.transpose(interleaved,perm=[1,0])

答案 2 :(得分:1)

您可以使用assign分配到切片中。

odd_new = tf.constant([1,3,5])
even_new = tf.constant([2,4,6])
y=tf.Variable(tf.zeros(6, dtype=tf.int32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
y[0::2].assign(odd_new).eval()
y[1::2].assign(even_new).eval()