pytorch RNN参数,损失函数

时间:2020-08-11 05:30:21

标签: pytorch recurrent-neural-network

我是Pytorch的初学者。我想将地球物理学中的优化问题应用于RNN。 我没有使用nn.RNN或nn.LSTM,因为输出和可训练参数的关系是由物理公式定义的。

这是带有nn.Module的代码。在前向函数中,u3(输出)是在前一个波场u1,u2处获得的,并且涉及可训练参数vp_inv。

class WaveCell(nn.Module):
    """One time step of 1-D acoustic wave propagation, used as an RNN cell.

    The only trainable parameter is the velocity model ``vp_inv``; the
    transition in :meth:`forward` is a fixed finite-difference update of the
    wave equation rather than a learned mapping.
    """

    def __init__(self, nnz, dz, dt):
        """
        Args:
            nnz: number of interior depth grid points (arrays are padded
                to ``nnz + 4``).
            dz: spatial grid spacing.
            dt: time-step size.
        """
        super(WaveCell, self).__init__()

        self.nnz = nnz
        self.dz = dz
        self.dt = dt

        nz = self.nnz
        npz = self.nnz + 4  # 4 extra padding points at depth

        # Initial velocity model: min_vp down to index 30, a linear ramp
        # up to index 170, then constant max_vp below.
        # NOTE(review): the 30/170 indices are hard-coded, so the model
        # only makes sense when npz >= 170 — confirm against intended nnz.
        min_vp = 2.
        max_vp = 5.
        vp_ini = np.zeros(npz, dtype=np.float32)
        vp_ini[0:30] = min_vp
        for iz in range(30, 170):
            vp_ini[iz] = min_vp + (max_vp - min_vp) * float(iz) / (nz + 1)
        vp_ini[170:npz] = max_vp

        # Smooth the initial guess so the inversion starts from a blurred model.
        vp_ini = gaussian_filter1d(vp_ini, 5, mode='nearest')

        # Registering as a Parameter exposes vp_inv to optimizer updates.
        self.vp_inv = Parameter(torch.tensor(data=vp_ini, dtype=torch.float32))

    def forward(self, u1, u2, src):
        """Advance the wavefield by one time step.

        Args:
            u1: wavefield at time t-1, shape ``(nnz + 4,)``.
            u2: wavefield at time t, shape ``(nnz + 4,)``.
            src: scalar source amplitude injected at grid index 1.

        Returns:
            Tuple ``(u2, u3, u3[1])``: the previous and new wavefields plus
            the sample recorded at the receiver location (index 1).
        """
        nnz = self.nnz
        dz = self.dz
        dt = self.dt
        nz = nnz
        npz = nnz + 4

        # Clamp velocities to the physical range; the deep part of the model
        # is pinned to 5 (the in-place overwrite of the clamped tensor
        # deliberately zeroes its gradient contribution there).
        vp_in = self.vp_inv.clamp(2., 5.)
        vp_in[nz + 1:] = 5.

        # Source spatial distribution: a point source at grid index 1.
        srcf = torch.zeros(npz)
        srcf[1] = 1.

        u3 = torch.zeros(npz)
        # One-way absorbing boundary condition at the bottom grid point.
        u3[nnz] = -vp_in[nz] * dt * (u2[nnz] - u2[nnz - 1]) / dz + u2[nnz]

        # Second-order-in-time, second-order-in-space interior update.
        u3[1:nnz] = 2. * u2[1:nnz] - u1[1:nnz] + (dt) ** 2 * \
                    (vp_in[1:nnz]**2*(u2[2:nnz + 1] - 2. * u2[1:nnz] + u2[0:nnz - 1]) / dz ** 2 + srcf[1:nnz] * src)

        return u2, u3, u3[1]

此代码是优化过程。我在每个时间步都计算损失值,并在时间循环结束后调用 backward 进行反向传播。我的问题是 Q1:它可以作为RNN正常工作吗? Q2:我想通过定义参数直接更新vp_inv。但是,我不确定它的更新是否正确。

# Inversion loop: unroll the wave propagation over nt time steps, accumulate
# the per-step misfit at the receiver, and backpropagate once per epoch
# through the whole unrolled sequence (backprop-through-time).
rnn = WaveCell(nnz, dz, dt)
epochs = 1000
lr = 100.  # NOTE(review): very large for Adam — verify it actually converges
loss_func = nn.MSELoss()

optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)

for i in range(epochs):
    optimizer.zero_grad()
    total_loss = 0

    # Reset the wavefields so each epoch builds a fresh computation graph.
    p1 = torch.zeros(npz)
    p2 = torch.zeros(npz)

    outs = []

    for it in range(1, nt):
        target = true_rec[it]

        p1, p2, p0 = rnn(p1, p2, source[it])
        loss = loss_func(p0, target)

        total_loss += loss
        outs.append(p0)

    # retain_graph is NOT needed here: the graph is rebuilt from scratch at
    # the top of the next epoch, so retaining it only wastes memory.
    total_loss.backward()
    optimizer.step()

0 个答案:

没有答案