我遵循了本教程:https://medium.com/@vijayabhaskar96/multi-label-image-classification-tutorial-with-keras-imagedatagenerator-cd541f8eaf24 并编写了一些用于多标签分类的代码。我使用小规模的一键编码进行了工作,但是我不得不转到本文中提到的选项2,因为我有6000个类,因此一键编码是不可行的。我设法训练了网络,它说99%的准确性和83%的f1分数。但是,当我尝试测试网络时,对于每个图像,当有6000个可能的标签时,它仅输出3个标签的某种组合。我想知道测试模型的代码是否不正确。我尝试使用帖子中提到的代码,但它不起作用:
test_generator.reset()
pred = model.predict_generator(test_generator, steps=STEP_SIZE_TEST, verbose=1);
pred_bool = (pred > 0.5)
无序
types: list() > float()
我已经尽力解决了这个问题,但并未弄清楚,而且我找不到任何人在网上做任何类似事情的例子。有谁知道如何使用此代码块来使这个预测部分起作用(我有另外2个选项,并且在打印一个或多个标签时遇到了问题)还是为什么模型可能无法通过这种行为进行训练? / p>
编辑:有关培训问题的更多背景信息,请参阅以下所有培训代码:
import json
input_file = open ('class_names_6000.json')
json_array = json.load(input_file)
#print(str(json_array))
args = parser.parse_args()
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
print('Loading Data...')
df = pd.read_csv('dataset_train.csv')
df["labels"]=df["labels"].apply(lambda x:x.split(","))
datagen=ImageDataGenerator(rescale=1./255.)
test_datagen=ImageDataGenerator(rescale=1./255.)
train_generator=datagen.flow_from_dataframe(
dataframe=df,
directory="",
x_col="Filepaths",
y_col="labels",
batch_size=128,
seed=42,
shuffle=True,
class_mode="categorical",
classes=json_array,
target_size=(100,100))
df = pd.read_csv('dataset_test.csv')
df["labels"]=df["labels"].apply(lambda x:x.split(","))
test_generator=test_datagen.flow_from_dataframe(
dataframe=df,
directory="",
x_col="Filepaths",
y_col="labels",
batch_size=128,
seed=42,
shuffle=True,
class_mode="categorical",
classes=json_array,
target_size=(100,100))
df = pd.read_csv('dataset_validation.csv')
df["labels"]=df["labels"].apply(lambda x:x.split(","))
valid_generator=test_datagen.flow_from_dataframe(
dataframe=df,
directory="",
x_col="Filepaths",
y_col="labels",
batch_size=128,
seed=42,
shuffle=True,
class_mode="categorical",
classes=json_array,
target_size=(100,100))
print('Data Loaded.')
f1_score_callback = ComputeF1()
model = build_model('train', numclasses=len(json_array), model_name = args.model)
ImageFile.LOAD_TRUNCATED_IMAGES = True
还有一个重要的细节,在训练时,它说准确率是99%,f1分数是84%,而有效的f1分数也是84%。