我有一个像下面的张量
x = tf.Variable(tf.truncated_normal([batch, input]), stddev=0.1))
假设批次= 99,输入= 5,我想分成一个小张量。
如果x低于:
[[1.0, 2.0, 3.0, 4.0, 5.0]
[2.0, 3.0, 4.0, 5.0, 6.0]
[3.0, 4.0, 5.0, 6.0, 7.0]
[4.0, 5.0, 6.0, 7.0, 8.0]
.........................
.........................
.........................
[44.0, 55.0, 66.0, 77.0, 88.0]
[55.0, 66.0, 77.0, 88.0, 99.0]]
我想分成两个张量
[[1.0, 2.0, 3.0, 4.0, 5.0]
[2.0, 3.0, 4.0, 5.0, 6.0]
[3.0, 4.0, 5.0, 6.0, 7.0]]
和
[4.0, 5.0, 6.0, 7.0, 8.0]
.........................
.........................
[44.0, 55.0, 66.0, 77.0, 88.0]
[55.0, 66.0, 77.0, 88.0, 99.0]]
我不知道如何使用tf.split
拆分行。
答案 0 :(得分:1)
一种权宜之计是两次拨打tf.slice
。