我有一个BasicLSTMCell
被送入static_rnn
,展开sent_max_len = 2
次。输入是一批batch_size = 2
的句子。每个单词都有一个embed_size = 6
。输出如下:
e1 e2 e3 e4 e5 e6
[[-0.01236094, -0.00423804, -0.01091367, 0.00286771, -0.00911439, -0.00964547] s1.w1
[-0.0316297 , 0.00904978, -0.02972977, -0.00720989, 0.00432076, 0.00946513]] s2.w1
t1
[[-0.03660333, 0.00613474, -0.03758694, -0.0070029 , -0.00036427, 0.00386676], s1.w2
[-0.04840172, 0.01757939, -0.05444464, -0.01508901, 0.01216465, 0.01938537]] s2.w2
t2
s1 = sentence 1
w1 = word 1
t1 = timestep 1
e1 = embedding 1
我想连接输出。但问题是当我沿tf.concat
axis=0
时,它取t1的输出并与t2连接,如下所示:
[[-0.01236094 -0.00423804 -0.01091367 0.00286771 -0.00911439 -0.00964547] s1.w1
[-0.0316297 0.00904978 -0.02972977 -0.00720989 0.00432076 0.00946513] s2.w1
[-0.03660333 0.00613474 -0.03758694 -0.0070029 -0.00036427 0.00386676] s1.w2
[-0.04840172 0.01757939 -0.05444464 -0.01508901 0.01216465 0.01938537] s2.w2
但我想这样连接:
[[-0.01236094 -0.00423804 -0.01091367 0.00286771 -0.00911439 -0.00964547] s1.w1
[-0.03660333 0.00613474 -0.03758694 -0.0070029 -0.00036427 0.00386676] s1.w2
[-0.0316297 0.00904978 -0.02972977 -0.00720989 0.00432076 0.00946513] s2.w1
[-0.04840172 0.01757939 -0.05444464 -0.01508901 0.01216465 0.01938537] s2.w2
因为此concated_output
遍历其他图层,而我的最终predicted_output
与actual_output
进行比较,actual_output
看起来像这样:
[[s1.w1.actualOutput]
[s1.w2.actualOutput]
[s2.w1.actualOutput]
[s2.w2.actualOutput]
]
显然我无法连接axis=1
,因为即使它会以正确的顺序输出,也会合并嵌入字。
答案 0 :(得分:2)
我会使用tf.gather_nd,你需要提供索引来从给定的张量中收集项目。例如:
data1 = tf.constant(
[
[[1,1,1],[2,2,2]],
[[3,3,3],[4,4,4]]
]
)
indices = tf.constant([
[[0,0], [1,0]],
[[0,1], [1,1]]
])
result = tf.gather_nd(data1, indices)
会给:
[[[1 1 1]
[3 3 3]]
[[2 2 2]
[4 4 4]]]
然后你可以使用concat with axis = 0将张量转换为你想要的格式
答案 1 :(得分:1)
一种方式:
output_sent = tf.stack(output_sent, axis=1)
output_sent = tf.reshape(output_sent, [-1, sent_embed_size])
greeness提到的另一种方式:
concat
沿着axis=1
然后执行reshape
,以便最里面的尺寸尺寸为6。