Chainer:无法分类,训练有素的模型(x)抛出错误

时间:2018-04-21 16:55:29

标签: python machine-learning chainer

我正在使用Chainer进行书面数字识别。我在MNIST数据库上训练了一个模型。但是,出于某种原因,我无法对单个示例进行分类。我的猜测是我没有为示例选择正确的格式,但我已经尝试了很多方法,但我仍然无法解决这个问题。抱歉,如果这很明显,我对Python没有经验。

This is what the error looks like

代码:

import chainer
import chainer.functions as F
import chainer.links as L
import os
import sys
from chainer.training import extensions
import numpy as np
from tkinter.filedialog import askopenfilename
from PIL import Image
from chainer import serializers
from chainer.dataset import concat_examples


from chainer.backends import cuda
from chainer import Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList


CONST_RESUME = ''


class Network(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(Network, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_units)
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        h1 = F.sigmoid(self.l1(x))
        h2 = F.sigmoid(self.l2(h1))
        return self.l3(h2)



def query_yes_no(question, default="no"):
    """Ask a yes/no question via raw_input() and return their answer.

    "question" is a string that is presented to the user.
    "default" is the presumed answer if the user just hits <Enter>.
        It must be "yes" (the default), "no" or None (meaning
        an answer is required of the user).

    The "answer" return value is True for "yes" or False for "no".
    """
    valid = {"yes": True, "y": True, "ye": True,
             "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        choice = input().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' "
                             "(or 'y' or 'n').\n")




    return np.argmax(y.data)

def main():
    #file_list = [None]*10
    #for i in range(10):
    #    file_list[i] = open('data{}.txt'.format(i), 'rb')

    print('MNIST digit recognition.')
    usr_in : str



    model = L.Classifier(Network(100, 10))
    chainer.backends.cuda.get_device_from_id(0).use()
    model.to_gpu()

    usr_in = input('Input (t) to train, (l) to load.')
    while len(usr_in) != 1 and (usr_in[0] != 'l' or usr_in[0] != 't'):
        print('Invalid input.')
        usr_in = input()

    if usr_in[0] == 't':



        optimizer = chainer.optimizers.Adam()
        optimizer.setup(model)

        train, test = chainer.datasets.get_mnist()


        train_iter = chainer.iterators.SerialIterator(train, batch_size=100, shuffle=True)
        test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)

        updater = training.updaters.StandardUpdater(train_iter, optimizer, device=0)
        trainer = training.Trainer(updater, (10, 'epoch'), out='out.txt')

        trainer.extend(extensions.Evaluator(test_iter, model, device=0))

        trainer.extend(extensions.dump_graph('main/loss'))

        frequency = 1
        trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

        trainer.extend(extensions.LogReport())

        if extensions.PlotReport.available():
            trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
            trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))

        trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar())

        if CONST_RESUME:
            chainer.serializers.load_npz(CONST_RESUME, trainer)

        trainer.run()

        ans = query_yes_no('Would you like to save this network?')
        if ans:
            usr_in = input('Input filename: ')
            serializers.save_npz('{}.{}'.format(usr_in, 'npz'), model)


    elif usr_in[0] == 'l':
        filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')
        serializers.load_npz(filename, model)

    else:
        return

    while True:
        ans = query_yes_no('Would you like to evaluate an image with the current network?')
        if ans:
            filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')
            file = Image.open(filename)
            bw_file = file.convert('L')
            size = 28, 28
            bw_file.thumbnail(size, Image.ANTIALIAS)
            pix = bw_file.load()
            x = np.empty([28 * 28])
            for i in range(28):
                for j in range(28):
                    x[i * 28 + j] = pix[i, j]

            #gpu_id = 0
            #batch = (, gpu_id)
            x = (x.astype(np.float32))[None, ...]

            y = model(x)
            print('predicted_label:', y.argmax(axis=1)[0])

        else:
            return

main()

1 个答案:

答案 0 :(得分:0)

您能告诉我们错误信息吗?否则很难猜出导致错误的原因。

但我猜输入形状可能不同。 当您使用chainer内置函数获取数据集时, train, test = chainer.datasets.get_mnist() 这些数据集图像形状为(minibatch, channel, height, width)。但似乎你正在构造输入x作为形状(minibatch, height * width =28*28=784),这是不同的形状??

你也可以参考chainer的一些教程,