如何乘以张量的维数?

时间:2017-05-31 03:39:01

标签: python tensorflow

在我打印conv_out.get_shape()时的以下代码中,它为我提供了输出(1,14,14,1)。我想将第二个第三维和第四维(14*14*1)相乘。我怎么能这样做?

input = tf.Variable(tf.random_normal([1,28,28,1]))
filter = tf.Variable(tf.random_normal([5,5,1,1]))

def conv2d(input,filter):
    return tf.nn.conv2d(input,filter,strides=[1,2,2,1],padding='SAME')

conv_out = conv2d(input,filter)
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

print conv_out.get_shape()
print conv_out.get_shape().as_list()[2]

1 个答案:

答案 0 :(得分:2)

类似

import numpy as np
np.asarray(conv_out.get_shape().as_list()[1:]).prod()

应该做的工作。

或者,如果你想在内部使用张量流图,那就像:

tf_shape = tf.shape(conv_out)
tf_shape_prod = tf.reduce_prod(tf_shape[1:])