pytorch自动编码器模型评估失败

时间:2019-01-13 07:20:48

标签: python deep-learning pytorch

我实际上是PyTorch的初学者。 我训练了一个自动编码器网络,以便可以绘制潜在矢量的分布(编码器的结果)。

这是我用于网络培训的代码。

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset
from PIL import Image
import os
import glob

dir_img_decoded = '/media/dohyeong/HDD/mouth_autoencoder/dc_img_2'
if not os.path.exists(dir_img_decoded):
    os.mkdir(dir_img_decoded)

dir_check_point = '/media/dohyeong/HDD/mouth_autoencoder/ckpt_2'
if not os.path.exists(dir_check_point):
    os.mkdir(dir_check_point)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

num_epochs = 200
batch_size = 150  # up -> GPU memory increase
learning_rate = 1e-3

dir_dataset = '/media/dohyeong/HDD/mouth_autoencoder/mouth_crop/dir_normalized_mouth_cropped_images'
images = glob.glob(os.path.join(dir_dataset, '*.png'))
train_images = images[:-113]
test_images = images[-113:]

train_images.sort()
test_images.sort()





class TrumpMouthDataset(Dataset):
    def __init__(self, images):
        super(TrumpMouthDataset, self).__init__()
        self.images = images

        self.transform = transforms.Compose([
            # transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, index):
        image = Image.open(self.images[index])

        return self.transform(image)

    def __len__(self):
        return len(self.images)


train_dataset = TrumpMouthDataset(train_images)
test_dataset = TrumpMouthDataset(test_images)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(60000, 60),
            nn.ReLU(True),
            nn.Linear(60, 3),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 60),
            nn.ReLU(True),
            nn.Linear(60, 60000),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return encoded, decoded


model = Autoencoder().cuda()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    model.to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             weight_decay=1e-5)

for epoch in range(num_epochs):

    total_loss = 0

    for index, imgs in enumerate(train_dataloader):
        imgs = imgs.to(device)

        # ===================forward=====================
        outputs = model(imgs)

        imgs_flatten = imgs.view(imgs.size(0), -1)
        loss = criterion(outputs, imgs_flatten)

        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        print('{} Epoch, [{}/{}] batch, loss: {:.4f}'.format(epoch, index + 1, len(train_dataloader), loss.item()))

    avg_loss = total_loss / len(train_dataset)
    print('{} Epoch, avg_loss: {:.4f}'.format(epoch, avg_loss))


    if epoch % 10 == 0:
        check_point_file = os.path.join(dir_check_point, str(epoch) + ".pth")
        torch.save(model.state_dict(), check_point_file)

训练后,我尝试使用此代码获取编码值。

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

check_point = '/media/dohyeong/HDD/mouth_autoencoder/290.pth'
model = torch.load(check_point)

for index, imgs in enumerate(train_dataloader):

    imgs = imgs.to(device)

    # ===================evaluate=====================
    encoded, _ = model(imgs)

完成此错误消息。 “ TypeError:'collections.OrderedDict'对象不可调用” 我可以帮忙吗?

1 个答案:

答案 0 :(得分:1)

您好,欢迎来到PyTorch社区:D

TL; DR

model = torch.load(check_point)更改为model.load_state_dict(torch.load(check_point))


唯一的问题是一行:

model = torch.load(check_point)

保存检查点的方式是:

torch.save(model.state_dict(), check_point_file)

也就是说,您将模型的state_dict(只是各种参数的字典,一起描述了模型的当前实例)保存在check_point_file中。

现在,为了将其加载回去,只需逆转该过程即可。 check_point_file仅包含state_dict

它对模型的内部一无所知-它的架构是什么,应该如何工作等等。

因此,将其重新加载:

state_dict = torch.load(check_point)

state_dict现在可以如下复制到您的Model实例:

model.load_state_dict(state_dict)

或更简洁地说,

model.load_state_dict(torch.load(check_point))

您收到错误消息是因为torch.load(check_point)返回了您分配给state_dict的{​​{1}}

随后您调用model时,model(imgs)model对象(不可调用)。

出现错误。

有关更多详细信息,请参见Serialization Semantics Notes

除此之外,您的代码对于初学者来说肯定是完整的。太好了!


P.S。您的设备不可知性非常出色!也许您想看看:

  1. OrderedDict
  2. torch.load()model = Autoencoder().cuda()参数