哪个API可以在tensorflow中实现张量扩展?

时间:2019-04-27 08:54:15

标签: python tensorflow

如果我有一个(30,40,50)的张量,并且我想将其扩展到一阶,那么我得到一个二阶张量(30,2000),并且我不知道是否tensorflow有一个实现它的API。

import tensorflow as tf
import numpy as np
data1=tf.constant([
    [[2,5,7,8],[6,4,9,10],[14,16,86,54]],
    [[16,43,65,76],[43,65,7,24],[15,75,23,75]]])

data5=tf.reshape(data1,[3,8])

data2,data3,data4=tf.split(data1,3,1)
data6=tf.reshape(data2,[1,8])
data7=tf.reshape(data3,[1,8])
data8=tf.reshape(data4,[1,8])
data9=tf.concat([data6,data7,data8],0)

with tf.Session() as sess:
    print(sess.run(data5))
    print(sess.run(data))

这给出了:

data5
[[ 2 5 7 8 6 4 9 10]
[14 16 86 54 16 43 65 76]
[43 65 7 24 15 75 23 75]]

data9
[[ 2 5 7 8 16 43 65 76]
[ 6 4 9 10 43 65 7 24]
[14 16 86 54 15 75 23 75]]

如何直接获取data9?

2 个答案:

答案 0 :(得分:0)

好像您正在尝试获取横跨轴0的子张量(data1[0]data1[1],...)并沿轴2串联它们。

在重塑之前移调应该可以解决问题:

tf.reshape(tf.transpose(data1, [1,0,2]), [data1.shape[1], data1.shape[0] * data1.shape[2]])

答案 1 :(得分:0)

您可以尝试:

data9 = tf.layers.flatten(tf.transpose(data1, perm=[1, 0, 2]))

输出:

array([[ 2,  5,  7,  8, 16, 43, 65, 76],
       [ 6,  4,  9, 10, 43, 65,  7, 24],
       [14, 16, 86, 54, 15, 75, 23, 75]], dtype=int32)