对于高级张量,我不知道如何自动操纵其形状。
例如:
# 0 1 2 3 -1
a.shape # [?, ?, ?, ?, ..., ?]
merge_dims(a, [0] ).shape # [?* ?, ?, ?, ..., ?]
merge_dims(a, [1, 2]).shape # [?, ?* ?* ?, ..., ?]
# ^ ^ ^ ^ ^
使用merge_dims
时,以位置编号标记的逗号应相乘,从而使张量较低。
谢谢:)
答案 0 :(得分:1)
此功能可以执行以下操作:
import tensorflow as tf
def merge_dims(x, axis, num=1):
# x: input tensor
# axis: first dimension to merge
# num: number of merges
shape = tf.shape(x)
new_shape = tf.concat([
shape[:axis],
[tf.reduce_prod(shape[axis:axis + num + 1])],
shape[axis + num + 1:]], axis=0)
return tf.reshape(x, new_shape)
with tf.Graph().as_default(), tf.Session() as sess:
a = tf.ones([2, 4, 6, 8, 10])
print(sess.run(tf.shape(merge_dims(a, 0))))
# [ 8 6 8 10]
print(sess.run(tf.shape(merge_dims(a, 1, num=2))))
# [ 2 192 10]