如何将PyTorch recurrent_block转换为Keras equivalent

时间:2019-08-08 13:24:15

标签: keras pytorch

我正在测试几种用于语义分割的体系结构,并遇到了我想尝试的PyTorch中的一个实现。我的问题是我没有使用PyTorch的经验,因此很难将以下代码片段转换为Keras。

class Recurrent_block(nn.Module):
    """Recurrent convolutional block.

    A single Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU stack
    is applied once to the raw input, then re-applied ``t`` times to the
    residual sum ``x + state`` — always with the SAME (shared) weights.
    """

    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        # One shared conv block, reused for every recurrence step.
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        for step in range(self.t):
            if step == 0:
                # First pass seeds the recurrent state from the raw input.
                state = self.conv(x)
            # Every iteration (including step 0) refines the state on x + state.
            state = self.conv(x + state)
        return state

1 个答案:

答案 0(得分:0):

以下代码段是否等效?

from keras.layers import *

def single_conv(filters):
    """Return a closure applying Conv2D(3x3) -> BatchNormalization -> ReLU."""
    def layer(input):
        # 3x3 conv with 'same' padding mirrors PyTorch's
        # Conv2d(kernel_size=3, stride=1, padding=1).
        conv_out = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')(input)
        normalized = BatchNormalization()(conv_out)
        return ReLU()(normalized)
    return layer

def recurrent_block(filters, t=2):
    """Keras equivalent of the PyTorch ``Recurrent_block``.

    A shared Conv2D(3x3) -> BatchNormalization -> ReLU stack is applied once
    to the input, then ``t`` times to ``Add()([input, x1])``, matching the
    PyTorch ``forward`` loop.

    Fixes vs. the original answer:
    - ``single_conv()`` was called without its required ``filters`` argument,
      which raises TypeError as soon as the model graph is built.
    - The PyTorch module reuses ONE ``self.conv`` block (shared weights);
      instantiating a fresh conv stack on every iteration would create
      independent weights, so the sublayers are created once per block here
      and re-applied.
    """
    def layer(input):
        # Instantiate the sublayers once so every application shares weights,
        # exactly like PyTorch's single self.conv.
        conv = Conv2D(filters=filters, kernel_size=3, strides=1, padding='same')
        bn = BatchNormalization()
        relu = ReLU()

        def shared_conv(x):
            # Shared Conv -> BN -> ReLU, reused across recurrence steps.
            return relu(bn(conv(x)))

        for i in range(t):
            if i == 0:
                x1 = shared_conv(input)
            x1 = shared_conv(Add()([input, x1]))
        return x1
    return layer