Tensorflow:使用tf.expand_dims时?

时间:2016-08-18 01:57:22

标签: tensorflow

Tensorflow教程包括使用tf.expand_dims添加"批量维度"到一个张量。我已经阅读了这个函数的文档,但它对我来说仍然是相当神秘的。有谁知道在什么情况下必须使用它?

我的代码如下。我的意图是根据预测箱和实际箱之间的距离计算损失。 (例如predictedBin = 10truthBin = 7然后binDistanceLoss = 3)。

batch_size = tf.size(truthValues_placeholder)
labels = tf.expand_dims(truthValues_placeholder, 1)
predictedBin = tf.argmax(logits)
binDistanceLoss = tf.abs(tf.sub(labels, logits))

在这种情况下,我是否需要将tf.expand_dims应用于predictedBinbinDistanceLoss?提前谢谢。

2 个答案:

答案 0 :(得分:40)

expand_dims不会在张量中添加或减少元素,只是通过向维度添加1来更改形状。例如,具有10个元素的向量可以被视为10x1矩阵。

我遇到使用expand_dims的情况是我尝试构建ConvNet以对灰度图像进行分类。灰度图像将作为大小为[320, 320]的矩阵加载。但是,tf.nn.conv2d要求输入为[batch, in_height, in_width, in_channels],我的数据中缺少in_channels维度,在这种情况下应为1。所以我使用expand_dims添加了一个维度。

在您的情况下,我认为您不需要expand_dims

答案 1 :(得分:13)

要添加Da Tong的答案,您可能希望同时扩展多个维度。例如,如果您对等级为1的向量执行TensorFlow的class ViewController2: UIViewController { var navigation = UINavigationController() let navigationRoot = ViewController3() override func loadView() { setView() addNavigation() } func setView() { view = UIView() view.frame = UIScreen.main.bounds } func addNavigation() { self.addChildViewController(navigationRoot) navigationRoot.didMove(toParentViewController: self) navigation = UINavigationController(rootViewController: navigationRoot) view.addSubview(navigation.view) } } 操作,则需要为它们提供等级3。

多次执行conv1d是可读的,但可能会在计算图中引入一些开销。您可以使用expand_dims

在单行中获得相同的功能
reshape

注意:如果您收到错误import tensorflow as tf # having some tensor of rank 1, it could be an audio signal, a word vector... tensor = tf.ones(100) print(tensor.get_shape()) # => (100,) # expand its dimensionality to fit into conv2d tensor_expand = tf.expand_dims(tensor, 0) tensor_expand = tf.expand_dims(tensor_expand, 0) tensor_expand = tf.expand_dims(tensor_expand, -1) print(tensor_expand.get_shape()) # => (1, 1, 100, 1) # do the same in one line with reshape tensor_reshape = tf.reshape(tensor, [1, 1, tensor.get_shape().as_list()[0],1]) print(tensor_reshape.get_shape()) # => (1, 1, 100, 1) ,请按照建议here尝试传递TypeError: Failed to convert object of type <type 'list'> to Tensor.而不是tf.shape(x)[0]

希望它有所帮助! 干杯,
安德烈