在Tensorflow中,我可以使用tf.gather()进行部分连接吗?

时间:2016-04-19 16:35:14

标签: python tensorflow deep-learning

我正在尝试在图层之间实现部分连接。比方说,我想只使用一些特征图,例如第一个和第三个。

  • 为此目的使用tf.gather()是否正确?
  • 我可以使用索引operator []代替tf.gather(),如下所示吗?
  • 收集指数是否会在反向传播方面发挥作用?我很难想象Tensorflow将如何在内部知道内部连接来自内部反向支持过程中的第一个和第三个(信息是硬编码的)。功能tf.gather是否记得连接?

代码:

# let say, L1 is layer1 output of shape [batch_size x image_size x image_size x depth1]
partL1 = L1[:, :, :, [0,2]]
# W2 is a tf variable of shape [5, 5, 2, depth2]
conv2 = tf.nn.conv2d(partL1, W2)

1 个答案:

答案 0 :(得分:3)

是的,不,是的。 :-) (a)是的,您可以使用聚集来挑选图层的子集以传播到下一层,如您所建议的那样。

(b)不,遗憾的是,您无法使用索引运算符。您需要明确调用tf.gather()

(c)是的,TensorFlow将存储用于收集的索引的副本并将其保存为backprop。如果您对如何查看op的输入并使用这些输入进行传播,您可以看到the implementation of Gather's Gradient