在张量流中索引1D张量

时间:2016-10-31 13:14:19

标签: python indexing tensorflow

是否有可能从给定索引处的张量中获取元素以获得标量?例如,给定一个图像,我可以使用shape = tf.shape(image)检索其形状,但我如何检索其高度,宽度和深度?

我找到的唯一方法如下:

height = tf.reshape(tf.slice(shape, [0], [1]), [])
width = tf.reshape(tf.slice(shape, [1], [1]), [])
depth = tf.reshape(tf.slice(shape, [2], [1]), [])

还有其他办法吗?

2 个答案:

答案 0 :(得分:1)

切片语法(即使用[]运算符)基于NumPy切片,并提供了一种从shape张量获得高度,宽度和深度的更简洁方法:

shape = tf.shape(image)
height = shape[0]  # returns a scalar
width = shape[1]   # returns a scalar
depth = shape[2]   # returns a scalar
如果张量具有静态确定的形状,

Nessuno's answer也将很好地工作。但是,使用None时,可变大小的图像(例如tf.image.decode_jpeg()的结果)通常会为高度和宽度尺寸提供get_shape(),因为这些图像可能因图像而异

答案 1 :(得分:0)

使用tf.Variable.get_shape(variable)

image = tf.Variable([ [ [1,2,3],[3,4,4],[1,1,1] ], [[1,2,3],[3,4,4],[1,1,1] ]])
shape = image.get_shape()
print( [shape[i].value for i in range(len(shape))] )

输出:

[2,3,3]