在TensorFlow中实现复数的排列

时间:2016-05-05 12:21:43

标签: permutation tensorflow

在这篇关联的lstm论文http://arxiv.org/abs/1602.03032中,他们要求置换一个复杂的张量。

他们在这里提供了他们的代码:https://github.com/mohammadpz/Associative_LSTM/blob/master/bricks.py#L79

我试图在tensorflow中复制它。这就是我所做的:

# shape: C x F/2
# output = self.permutations: [num_copies x cell_size]
permutations = []
indices = numpy.arange(self._dim / 2) #[1 ,2 ,3 ...64]
for i in range(self._num_copies):
    numpy.random.shuffle(indices) #[4, 48, 32, ...64]
    permutations.append(numpy.concatenate(
        [indices,
         [ind + self._dim / 2 for ind in indices]]))
    #you're appending a row with two columns -- a permutation in the first column, and the same permutation + dim/2 for imaginary
# C x F (numpy)
self.permutations = tf.constant(numpy.vstack(permutations), dtype = tf.int32) #This is a permutation tensor that has the stored permutations
# output = self.permutations: [num_copies x cell_size]

def permute(complex_tensor): #complex tensor is [batch_size x cell_size]
 gather_tensor = tf.gather_nd(complex_tensor, self.permutations)
 return gather_tensor

基本上,我的问题是:在TensorFlow中如何有效地完成这项工作?无论如何都要将批量大小维度固定为complex tensor

另外,gather_nd是最好的方法吗?或者,使用self.permutations进行for循环并迭代tf.gather中的每一行是否更好?

def permute(self, complex_tensor):
 inputs_permuted = []
 for i in range(self.permutations.get_shape()[0].value):
  inputs_permuted.append(
    tf.gather(complex_tensor, self.permutations[i]))
 return tf.concat(0, inputs_permuted)

我认为gather_nd会更有效率。

1 个答案:

答案 0 :(得分:0)

没关系,我想通了,诀窍就是使用tf转置来使用原始输入张量置换。这将允许您在整个矩阵上执行tf.gather。然后你可以把矩阵连在一起。对不起,如果这浪费了任何人的时间。