我想检查批次的偶数和奇数元素,并在需要时交换它们。我设法得到两个我想要交织的张量:
def tf_oplu(x, name=None):
even = x[:,::2] #slicing into odd and even parts on the batch
odd = x[:,1::2]
even_flatten = tf.reshape(even, [-1]) # flatten tensors
#in row-major order to apply function across them
odd_flatten = tf.reshape(odd, [-1])
compare = tf.to_float(even_flatten<odd_flatten)
compare_not = tf.to_float(even_flatten>=odd_flatten)
#def oplu(x,y): # trivial function
# if x<y : # (x<y)==1
# return y, x
# else:
# return x, y # (x<y)==0
even_flatten_new = odd_flatten * compare + even_flatten * compare_not
odd_flatten_new = odd_flatten * compare_not + even_flatten * compare
# convolute back
even_new = tf.reshape(even_flatten_new,[100,128])
odd_new = tf.reshape(odd_flatten_new,[100,128])
现在我想收回$ [100,256] $ tensor,偶数和奇数位置已经填满。在numpy我当然会这样做:
y = np.empty((even_new.size + odd_newsize,), dtype=even_new.dtype)
y[:,0::2] = even_new
y[:,1::2] = odd_new
return y
但是这种事情对于数十年来是不可能的,因为张量是不可修改的。我认为可以使用sparse tensor或tf.gather_nd,但两者都需要生成索引数组,这对我来说同样是非常重要的任务。
还有一点需要注意:我不想通过tf.py_func
使用任何python函数,因为我检查过它们只在CPU上运行。也许lambda和tf.map_fn
可能会有所帮助?谢谢!
答案 0 :(得分:5)
要垂直交错两个矩阵,您不需要大型枪支,例如gather
或map_fn
。您可以按如下方式简单地交错:
tf.reshape(
tf.stack([even_new, odd_new], axis=1),
[-1, tf.shape(even_new)[1]])
修改强>
水平交错:
tf.reshape(
tf.concat([even_new[...,tf.newaxis], odd_new[...,tf.newaxis]], axis=-1),
[tf.shape(even_new)[0],-1])
这个想法是使用堆栈在内存中交错它们。堆栈发生的维度给出了交错的粒度。如果我们堆叠在axis=0
,那么交错发生在每个元素,混合列。如果我们在axis=1
处堆叠,则整个输入行保持连续,在行之间发生交错。
答案 1 :(得分:2)
您可以使用tf.dynamic_stitch
,该参数将要插入的每个张量的索引张量列表作为第一个参数,而将要插入的张量列表作为第二参数。张量将沿第一维交错,因此我们需要先将它们转置然后再转回。这是代码:
even_new = tf.transpose(even_new,perm=[1,0])
odd_new = tf.transpose(odd_new,perm=[1,0])
even_pos = tf.convert_to_tensor(list(range(0,256,2)),dtype=tf.int32)
odd_pos = tf.convert_to_tensor(list(range(1,256,2)),dtype=tf.int32)
interleaved = tf.dynamic_stitch([even_pos,odd_pos],[even_new,odd_new])
interleaved = tf.transpose(interleaved,perm=[1,0])
答案 2 :(得分:1)
您可以使用assign
分配到切片中。
odd_new = tf.constant([1,3,5])
even_new = tf.constant([2,4,6])
y=tf.Variable(tf.zeros(6, dtype=tf.int32))
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
y[0::2].assign(odd_new).eval()
y[1::2].assign(even_new).eval()