我试图从Tensor沿特定轴提取所有可能的排列。我的输入是[B, S, L]
张量(B批S长度为L的向量),我想提取这些向量中的所有可能的排列(S!置换),即[B, S!, S, L]
Tensor作为输出。
这就是我现在尝试的,但我正在努力获得正确的输出形状。我认为我的错误可能是我创建了一个batch_range,但我也应该创建一个permutation_range。
import tensorflow as tf
import numpy as np
from itertools import permutations
S = 3
B = 5
L = 10
input = tf.constant(np.random.randn(B, S, L))
perms = list(permutations(range(S))) # ex with 3: [0, 1, 2], [0, 2 ,1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
indicies = tf.concat([batch_range, perms], axis=3)
permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indicies) #
# I get a [ B, P, S, S, L] instead of the desired [B, P, S, L]
我发布了一个可能的'解决方案'就在下面,但我认为这个问题仍有问题。我对它进行了测试,如果B> 1,它的进展并不顺利。
答案 0 :(得分:0)
我刚刚找到答案,如果您认为我错了或者有更简单的方法可以解答,请纠正我:
import tensorflow as tf
import numpy as np
from itertools import permutations
S = 3
B = 5
L = 10
input = tf.constant(np.random.randn(B, S, L))
perms = list(permutations(range(S))) # ex with 3: [0, 1, 2], [0, 2 ,1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]
length_perm = len(perms)
perms = tf.reshape(tf.constant(perms), [1, length_perm, S, 1])
perms = tf.tile(perms, [B, 1, 1, 1])
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1, 1]), [1, length_perm, S, 1])
perm_range = tf.tile(tf.reshape(tf.range(length_perm, dtype=tf.int32), shape=[1, length_perm, 1, 1]), [B, 1, S, 1])
indicies = tf.concat([batch_range, perm_range, perms], axis=3)
permutations = tf.gather_nd(tf.tile(tf.reshape(input, [B, 1, S, L]), [1, length_perm, 1, 1]), indicies) #
print permutations