我正在使用Chainer进行书面数字识别。我在MNIST数据库上训练了一个模型。但是,出于某种原因,我无法对单个示例进行分类。我的猜测是我没有为示例选择正确的格式,但我已经尝试了很多方法,但我仍然无法解决这个问题。抱歉,如果这很明显,我对Python没有经验。
This is what the error looks like
代码:
import chainer
import chainer.functions as F
import chainer.links as L
import os
import sys
from chainer.training import extensions
import numpy as np
from tkinter.filedialog import askopenfilename
from PIL import Image
from chainer import serializers
from chainer.dataset import concat_examples
from chainer.backends import cuda
from chainer import Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
CONST_RESUME = ''
class Network(chainer.Chain):
def __init__(self, n_units, n_out):
super(Network, self).__init__()
with self.init_scope():
self.l1 = L.Linear(None, n_units)
self.l2 = L.Linear(None, n_units)
self.l3 = L.Linear(None, n_out)
def __call__(self, x):
h1 = F.sigmoid(self.l1(x))
h2 = F.sigmoid(self.l2(h1))
return self.l3(h2)
def query_yes_no(question, default="no"):
"""Ask a yes/no question via raw_input() and return their answer.
"question" is a string that is presented to the user.
"default" is the presumed answer if the user just hits <Enter>.
It must be "yes" (the default), "no" or None (meaning
an answer is required of the user).
The "answer" return value is True for "yes" or False for "no".
"""
valid = {"yes": True, "y": True, "ye": True,
"no": False, "n": False}
if default is None:
prompt = " [y/n] "
elif default == "yes":
prompt = " [Y/n] "
elif default == "no":
prompt = " [y/N] "
else:
raise ValueError("invalid default answer: '%s'" % default)
while True:
sys.stdout.write(question + prompt)
choice = input().lower()
if default is not None and choice == '':
return valid[default]
elif choice in valid:
return valid[choice]
else:
sys.stdout.write("Please respond with 'yes' or 'no' "
"(or 'y' or 'n').\n")
return np.argmax(y.data)
def main():
#file_list = [None]*10
#for i in range(10):
# file_list[i] = open('data{}.txt'.format(i), 'rb')
print('MNIST digit recognition.')
usr_in : str
model = L.Classifier(Network(100, 10))
chainer.backends.cuda.get_device_from_id(0).use()
model.to_gpu()
usr_in = input('Input (t) to train, (l) to load.')
while len(usr_in) != 1 and (usr_in[0] != 'l' or usr_in[0] != 't'):
print('Invalid input.')
usr_in = input()
if usr_in[0] == 't':
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
train, test = chainer.datasets.get_mnist()
train_iter = chainer.iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)
updater = training.updaters.StandardUpdater(train_iter, optimizer, device=0)
trainer = training.Trainer(updater, (10, 'epoch'), out='out.txt')
trainer.extend(extensions.Evaluator(test_iter, model, device=0))
trainer.extend(extensions.dump_graph('main/loss'))
frequency = 1
trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))
trainer.extend(extensions.LogReport())
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], 'epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], 'epoch', file_name='accuracy.png'))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.ProgressBar())
if CONST_RESUME:
chainer.serializers.load_npz(CONST_RESUME, trainer)
trainer.run()
ans = query_yes_no('Would you like to save this network?')
if ans:
usr_in = input('Input filename: ')
serializers.save_npz('{}.{}'.format(usr_in, 'npz'), model)
elif usr_in[0] == 'l':
filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')
serializers.load_npz(filename, model)
else:
return
while True:
ans = query_yes_no('Would you like to evaluate an image with the current network?')
if ans:
filename = askopenfilename(initialdir=os.getcwd(), title='Choose a file')
file = Image.open(filename)
bw_file = file.convert('L')
size = 28, 28
bw_file.thumbnail(size, Image.ANTIALIAS)
pix = bw_file.load()
x = np.empty([28 * 28])
for i in range(28):
for j in range(28):
x[i * 28 + j] = pix[i, j]
#gpu_id = 0
#batch = (, gpu_id)
x = (x.astype(np.float32))[None, ...]
y = model(x)
print('predicted_label:', y.argmax(axis=1)[0])
else:
return
main()
答案 0 :(得分:0)
您能告诉我们错误信息吗?否则很难猜出导致错误的原因。
但我猜输入形状可能不同。
当您使用chainer内置函数获取数据集时,
train, test = chainer.datasets.get_mnist()
这些数据集图像形状为(minibatch, channel, height, width)
。但似乎你正在构造输入x
作为形状(minibatch, height * width =28*28=784)
,这是不同的形状??
你也可以参考chainer的一些教程,