CNN训练的模型似乎不起作用

时间:2020-08-18 07:10:58

标签: deep-learning pytorch

我已经训练了CNN模型,并且希望针对新数据运行训练后的模型。但是,似乎训练后的模型无法像训练中那样正确地预测计数。我感觉该模型未使用PTH文件。有人可以告诉我我在做什么错吗?

import argparse
import datetime
import glob
import os
import random
import shutil
import time
from os.path import join

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch.optim as optim

from convnet3_eval import Convnet
from dataset2_eval import CellsDataset

parser = argparse.ArgumentParser('Predicting hits from pixels')
parser.add_argument('name',type=str,help='Name of experiment')
parser.add_argument('data_dir',type=str,help='Path to data directory containing images and gt.csv')
parser.add_argument('--weight_decay',type=float,default=0.0,help='Weight decay coefficient (something like 10^-5)')
parser.add_argument('--lr',type=float,default=0.0001,help='Learning rate')
args = parser.parse_args()

metadata = pd.read_csv(join(args.data_dir,'gt.csv'))
metadata.set_index('filename', inplace=True)


dataset = CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
dataset = DataLoader(dataset,num_workers=4,pin_memory=True)
model_path = '/base_model.pth'

model = Convnet()
optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

for images, paths in tqdm(dataset):

    targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
    targets = targets.float()

    # code to print training data to a csv file
    filename=CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
    output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
    preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)
    print(preds)
    paths_test = np.array([paths])
    names_preds = np.hstack(paths)
    print(names_preds)                
    df=pd.DataFrame({'Image_Name':names_preds, 'Target':targets.detach(), 'Prediction':preds.detach()})
    print(df) 
    # save image name, targets, and predictions
    df.to_csv(r'model.csv', index=False, mode='a')


model.load_state_dict(torch.load(model_path))
model.eval()

1 个答案:

答案 0 :(得分:1)

将最后两行移动到加载权重的位置

model.load_state_dict(torch.load(model_path))
model.eval()

在下面初始化模型的for循环上方。