当我尝试计算predict_generator()
之后的精度时,最终得到的精度与evaluate_generator()
计算出的精度不同。
不确定是否相关,但是不确定shuffle = True
类中flow_from_generator()
中的DataGenerator
。
idg_train
和idg_test
是ImageDataGenerator
对象。
# TensorFlow, Keras and NumPy
from tensorflow import keras
from keras.optimizers import Adam
from keras.losses import categorical_crossentropy
import numpy as np
# Own libraries
from DataManipulation import create_dataset, DataGenerator
from ModelZoo import variable_conv_layers
# Data Generation
train_gen = DataGenerator(generator = idg_train, subset = 'training', **params)
val_gen = DataGenerator(generator = idg_train, subset = 'validation', **params)
val_gen = DataGenerator(generator = idg_test, **params)
y_true = test_gen.generator.classes
# Model preparation
model = variable_conv_layer(**model_params) # Creates model
model.compile(optimizer = Adam(lr = 1e-4),
loss = categorical_crossentropy,
metrics = ['accuracy'])
# Training
model.fit_generator(train_gen,
epochs = 1,
validation_data = val_gen,
workers = 8,
use_multiprocessing = True,
shuffle = True)
# Prediction
scores = model.predict_generator(test_gen,
workers = 8,
use_multiprocessing = True)
pred = np.argmax(scores, axis = -1)[:len(test_gen.generator.classes)]
acc = np.mean(pred == y_true)
print("%s: %1.3e" % ("Manual accuracy", acc))
print("Evaluated [loss, accuracy]:", model.evaluate_generator(test_gen,
workers = 8,
use_multiprocessing = True)
这将打印以下内容:
Manual accuracy: 1.497e-01
Evaluated [loss, accuracy]: [0.308414297710572, 0.9838169642857143]
很明显,手动计算的精度与evaluate_generator()
中的精度不同。我已经连续看了好几个小时,不知道问题可能在哪里。
谢谢!
编辑:另外,我尝试使用sklearn.metrics.confusion_matrix(y_true, pred)
创建一个混淆矩阵,它产生以下数组:
[[407 0 70 1 8 1 0 57 0]
[413 0 74 15 0 16 1 32 0]
[230 0 40 0 0 4 4 32 0]
[239 0 40 0 0 2 2 36 0]
[282 0 34 0 0 7 1 39 0]
[296 0 37 0 3 4 0 40 0]
[377 0 39 2 8 8 0 42 0]
[183 0 28 4 6 4 0 19 0]
[283 0 46 6 5 6 0 33 0]]
由于某些原因,仅使用np.argmax(scores, axis = -1)
时,似乎可以预测很大的多数为'0'。
答案 0 :(得分:0)
只需在第二次使用前重置test_gen:
test_gen.reset()
print("Evaluated [loss, accuracy]:", model.evaluate_generator(
test_gen,
workers = 8,
use_multiprocessing = True
)