如何通过tensorflow中的两个dims连接张量

时间:2017-02-05 08:06:54

标签: tensorflow

我想通过tensorflow中的两个dims来连接张量。

例如,有四个具有4维的张量。所有张量都与张量流中的图像类似,因此每个维度表示以下内容:[batch_size,image_width_size,image_height_size,image_channel_size]。

import tensorflow as tf

image_tensor_1 = 1*tf.ones([60, 2, 2, 3])
image_tensor_2 = 2*tf.ones([60, 2, 2, 3])
image_tensor_3 = 3*tf.ones([60, 2, 2, 3])
image_tensor_4 = 4*tf.ones([60, 2, 2, 3])

image_result_wanted = ... # Some operations here

sess = tf.Session()
print(sess.run([image_result_wanted])

不考虑批量大小和通道尺寸(我的意思是,只考虑图像宽度和图像高度),我想解决以下问题:

[[1, 1, 2, 2],
 [1, 1, 2, 2],
 [3, 3, 4, 4],
 [3, 3, 4, 4]]

因此,image_result_wanted的形状应为(60, 4, 4, 3)

我应该如何处理此操作?

2 个答案:

答案 0 :(得分:3)

您可以使用tf.concat沿所需的轴连接张量。

下面:

import tensorflow as tf

image_tensor_1 = 1*tf.ones([60, 2, 2, 3])
image_tensor_2 = 2*tf.ones([60, 2, 2, 3])
image_tensor_3 = 3*tf.ones([60, 2, 2, 3])
image_tensor_4 = 4*tf.ones([60, 2, 2, 3])

try:
    temp_1 = tf.concat_v2([image_tensor_1, image_tensor_2], 2)
    temp_2 = tf.concat_v2([image_tensor_3, image_tensor_4], 2)
    result = tf.concat_v2([temp_1, temp_2], 1)
except AttributeError:
    temp_1 = tf.concat(2, [image_tensor_1, image_tensor_2])
    temp_2 = tf.concat(2, [image_tensor_3, image_tensor_4])
    result = tf.concat(1, [temp_1, temp_2])


sess = tf.Session()
print sess.run([result[0,:,:,0]])

答案 1 :(得分:2)

真的不知道如何在一行中做到这一点,所以我想出了以下内容:

import tensorflow as tf

image_tensor_1 = 1 * tf.ones([60, 2, 2, 3])
image_tensor_2 = 2 * tf.ones([60, 2, 2, 3])
image_tensor_3 = 3 * tf.ones([60, 2, 2, 3])
image_tensor_4 = 4 * tf.ones([60, 2, 2, 3])

# make two tensors with shapes of [60, 2, 4, 3]
concat1 = tf.concat(2, [image_tensor_1, image_tensor_2])
concat2 = tf.concat(2, [image_tensor_3, image_tensor_4])
# stack two tensors together to obtain desired result with shape [60, 4, 4, 3]
result = tf.concat(1, [concat1, concat2])

以下代码:

sess = tf.Session()
print(sess.run(result[0, :, :, 0]))

的结果

[[ 1.  1.  2.  2.]
 [ 1.  1.  2.  2.]
 [ 3.  3.  4.  4.]
 [ 3.  3.  4.  4.]]

根据需要。

太晚了,哈哈:)