如何(有效地)在TensorFlow中应用通道方式的完全连接层

时间:2017-11-29 15:12:29

标签: python tensorflow deep-learning autoencoder

我再来找你,抓住我可以开始工作的东西,但真的很慢。我希望你能帮助我优化它。

我试图在TensorFlow中实现一个卷积自动编码器,在编码器和解码器之间有一个很大的潜在空间。通常,人们会将编码器连接到具有完全连接层的解码器,但是因为这个潜在空间具有高维度,所以这样做会产生太多的特征,使其在计算上是可行的。

我在this paper找到了一个很好的解决方案。他们将其称为“通道方式完全连接的层”。它基本上是每个通道的完全连接层。

我正在进行实施,我让它工作,但图表的生成需要很长时间。到目前为止,这是我的代码:

def _network(self, dataset, isTraining):
        encoded = self._encoder(dataset, isTraining)
        with tf.variable_scope("fully_connected_channel_wise"):
            shape = encoded.get_shape().as_list()
            print(shape)
            channel_wise = tf.TensorArray(dtype=tf.float32, size=(shape[-1]))
            for i in range(shape[-1]):  # last index in shape should be the output channels of the last conv
                channel_wise = channel_wise.write(i, self._linearLayer(encoded[:,:,i], shape[1], shape[1]*4, 
                                  name='Channel-wise' + str(i), isTraining=isTraining))
            channel_wise = channel_wise.concat()
            reshape = tf.reshape(channel_wise, [shape[0], shape[1]*4, shape[-1]])
        reconstructed = self._decoder(reshape, isTraining)
        return reconstructed

那么,关于为什么这么长时间的任何想法?实际上这是一个范围(2048),但所有的线性层都非常小(4x16)。我是以错误的方式接近这个吗?

谢谢!

1 个答案:

答案 0 :(得分:3)

您可以在Tensorflow中查看该论文的实施情况。 以下是他们实现的“通道完全连接层”#。

    try
    {

        PDDocument documentSrc = PDDocument.load(new File(SRC));
        PDAcroForm acroFormSrc = documentSrc.getDocumentCatalog().getAcroForm();

        PDDocument documentDest = PDDocument.load(new File(DEST));
        PDAcroForm acroFormDest = new PDAcroForm(documentDest);

        acroFormDest.setCacheFields(true);
        acroFormDest.setFields(acroFormSrc.getFields());
        documentDest.getDocumentCatalog().setAcroForm(acroFormDest);

        int pageIndex = 0;
        for(PDPage page: documentSrc.getPages()){
            documentDest.getPage(pageIndex).setAnnotations(page.getAnnotations());
            documentDest.getPage(pageIndex).setResources(page.getResources());
            pageIndex++;
        }

        documentDest.save(DEST_MERGED);
        documentDest.close();
        documentSrc.close();
    }
    catch (IOException e)
    {
        // TODO Auto-generated catch block
        e.printStackTrace();
    }
}

https://github.com/jazzsaxmafia/Inpainting/blob/8c7735ec85393e0a1d40f05c11fa1686f9bd530f/src/model.py#L60

主要思想是使用tf.batch_matmul函数。

但是,在最新版本的Tensorflow中删除了tf.batch_matmul,您可以使用tf.matmul替换它。