tensorflow concat张贴行

时间:2017-12-06 22:03:53

标签: python-3.x tensorflow

我有一个BasicLSTMCell被送入static_rnn,展开sent_max_len = 2次。输入是一批batch_size = 2的句子。每个单词都有一个embed_size = 6。输出如下:

       e1           e2           e3            e4          e5           e6
[[-0.01236094, -0.00423804, -0.01091367,  0.00286771, -0.00911439, -0.00964547]  s1.w1      
[-0.0316297 ,  0.00904978, -0.02972977, -0.00720989,  0.00432076,  0.00946513]] s2.w1
                                       t1

[[-0.03660333,  0.00613474, -0.03758694, -0.0070029 , -0.00036427, 0.00386676], s1.w2        
[-0.04840172,  0.01757939, -0.05444464, -0.01508901,  0.01216465, 0.01938537]] s2.w2
                                        t2
s1 = sentence 1
w1 = word 1
t1 = timestep 1
e1 = embedding 1

我想连接输出。但问题是当我沿tf.concat axis=0时,它取t1的输出并与t2连接,如下所示:

[[-0.01236094 -0.00423804 -0.01091367  0.00286771 -0.00911439 -0.00964547]  s1.w1
 [-0.0316297   0.00904978 -0.02972977 -0.00720989  0.00432076  0.00946513]  s2.w1
 [-0.03660333  0.00613474 -0.03758694 -0.0070029  -0.00036427  0.00386676]  s1.w2
 [-0.04840172  0.01757939 -0.05444464 -0.01508901  0.01216465  0.01938537]  s2.w2

但我想这样连接:

[[-0.01236094 -0.00423804 -0.01091367  0.00286771 -0.00911439 -0.00964547]  s1.w1
 [-0.03660333  0.00613474 -0.03758694 -0.0070029  -0.00036427  0.00386676]  s1.w2
 [-0.0316297   0.00904978 -0.02972977 -0.00720989  0.00432076  0.00946513]  s2.w1
 [-0.04840172  0.01757939 -0.05444464 -0.01508901  0.01216465  0.01938537]  s2.w2

因为此concated_output遍历其他图层,而我的最终predicted_outputactual_output进行比较,actual_output看起来像这样:

[[s1.w1.actualOutput]
 [s1.w2.actualOutput]
 [s2.w1.actualOutput]
 [s2.w2.actualOutput]
]

显然我无法连接axis=1,因为即使它会以正确的顺序输出,也会合并嵌入字。

2 个答案:

答案 0 :(得分:2)

我会使用tf.gather_nd,你需要提供索引来从给定的张量中收集项目。例如:

data1 = tf.constant(
    [
        [[1,1,1],[2,2,2]],
        [[3,3,3],[4,4,4]]

    ]
)
indices = tf.constant([
    [[0,0], [1,0]],
    [[0,1], [1,1]]
])
result = tf.gather_nd(data1, indices)

会给:

[[[1 1 1]
  [3 3 3]]

 [[2 2 2]
 [4 4 4]]]

然后你可以使用concat with axis = 0将张量转换为你想要的格式

答案 1 :(得分:1)

一种方式:

output_sent = tf.stack(output_sent, axis=1) 
output_sent = tf.reshape(output_sent, [-1, sent_embed_size])

greeness提到的另一种方式:

concat沿着axis=1然后执行reshape,以便最里面的尺寸尺寸为6。