像素RNN Pytorch实施

时间:2018-07-16 14:35:58

标签: python machine-learning lstm pytorch rnn

我正在尝试在pytorch中实现Pixel RNN,但似乎找不到任何文档。 Pixel RNN的主要部分是Row LSTM和BiDiagonal LSTM,因此我正在寻找这些算法的一些代码,以更好地了解它们的作用。具体地说,我对这些算法分别一次计算一行和对角线感到困惑。任何帮助将不胜感激。

1 个答案:

答案 0 :(得分:1)

摘要

这是一个进行中的部分实现:

https://github.com/carpedm20/pixel-rnn-tensorflow

以下是Google deepmind对Row LSTM和BiDiagonal LSTM的描述:

https://towardsdatascience.com/summary-of-pixelrnn-by-google-deepmind-7-min-read-938d9871d6d9


行LSTM

来自链接的deepmind博客:

一个像素的隐藏状态(在下面的图像中为红色)基于其前面三个三角形的“内存”。因为它们在“行”中,所以我们可以并行计算,从而加快了计算速度。我们牺牲了一些上下文信息(使用更多的历史记录或内存)来进行并行计算并加快训练速度。

enter image description here

实际的实现依赖于其他几个优化,并且涉及很多。来自original paper

  

计算过程如下。 LSTM层具有   输入到状态组件和循环状态到状态组件,   共同确定LSTM内核内部的四个门。加强   在行LSTM中并行化,输入到状态组件是第一个   为整个二维输入图计算;为此a k×1   卷积用于遵循LSTM的行方向   本身。卷积被屏蔽为仅包含有效上下文   (请参阅第3.4节),并产生大小为4h×n×n的张量,   代表输入图中每个位置的四个门矢量,   其中h是输出要素图的数量。计算一个   LSTM层的状态到状态组件,给定一个   先前的隐藏状态和单元格状态hi-1和ci-1,大小分别为h×n×1。   新的隐藏状态和单元状态hi,ci如下获得:

enter image description here

  

其中,大小为h×n×1的xi是输入映射的第i行,〜表示卷积运算和元素级   乘法。权重Kss和Kis是   状态到状态和输入到状态组件,其中后者是   如上所述预先计算。对于输出,忘记并   输入门oi,fi和ii,激活σ是逻辑S形   函数,而对于内容门gi,σ是tanh函数。   每个步骤都会立即计算整行的新状态   输入地图

对角线BLSTM

对角BLSTM的开发是为了利用并行化的速度而不会牺牲尽可能多的上下文信息。 DBLSTM中的一个节点从其左侧和上方看。由于这些节点也向左上方看,因此在某种意义上,给定节点的条件概率取决于其所有祖先。否则,架构非常相似。来自Deepmind博客:

enter image description here