我正在尝试从头开始在自己的数据集上实现this keras implementation of yolov3。在将代码更改为我的要求并进行了试用培训之后,我试图通过keras中的predict_generator函数来验证预测。但是它会引发错误。
model.compile(optimizer=Adam(lr=1e-4), loss={'yolo_loss': lambda y_true, y_pred: y_pred},metrics=['accuracy']) # recompile to apply the change
print('Unfreeze all of the layers.')
batch_size = 2
print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
model.fit_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),
steps_per_epoch=max(1, num_train//batch_size),
validation_data=data_generator_wrapper(lines[num_train:], batch_size, input_shape, anchors, num_classes),
validation_steps=max(1, num_val//batch_size),
epochs=1,
initial_epoch=0,
callbacks=[logging, checkpoint, reduce_lr, early_stopping])
predict = model.predict_generator(data_generator_wrapper(lines[:num_train], batch_size, input_shape, anchors, num_classes),verbose=1,steps=10)
model.save_weights(log_dir + 'trained_weights_final.h5')
这是错误
_main中的文件“ yolo_train.py”,第83行 model.predict_generator(data_generator_wrapper(lines [:num_train],batch_size,input_shape,anchors,num_classes),详细= 1,步骤= 10) 包装中的文件“ /usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py”,第91行 return func(* args,** kwargs) 1522行中的“ /usr/local/lib/python3.6/dist-packages/keras/engine/training.py”文件,位于预报生成器中 详细=详细) 预测文件生成器中的文件“ /usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py”,行474 返回np.concatenate(all_outs [0]) ValueError:零维数组无法串联
如果形状不匹配,我将无法理解训练效果如何,而且我也没有修改任何输出张量。仅更改了输入图像的大小并设置了类数1。