如何在张量流中自动合并形状?

时间:2018-12-18 16:27:47

标签: python tensorflow

对于高级张量,我不知道如何自动操纵其形状。

例如:

                                #   0  1  2  3   -1
a.shape                         # [?, ?, ?, ?, ..., ?]
merge_dims(a, [0]   ).shape     # [?* ?, ?, ?, ..., ?]
merge_dims(a, [1, 2]).shape     # [?, ?* ?* ?, ..., ?]
                                #   ^  ^  ^  ^    ^

使用merge_dims时,以位置编号标记的逗号应相乘,从而使张量较低。

谢谢:)

1 个答案:

答案 0 :(得分:1)

此功能可以执行以下操作:

import tensorflow as tf

def merge_dims(x, axis, num=1):
    # x: input tensor
    # axis: first dimension to merge
    # num: number of merges
    shape = tf.shape(x)
    new_shape = tf.concat([
        shape[:axis],
        [tf.reduce_prod(shape[axis:axis + num + 1])],
        shape[axis + num + 1:]], axis=0)
    return tf.reshape(x, new_shape)

with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.ones([2, 4, 6, 8, 10])
    print(sess.run(tf.shape(merge_dims(a, 0))))
    # [ 8  6  8 10]
    print(sess.run(tf.shape(merge_dims(a, 1, num=2))))
    # [  2 192  10]