我是Tensorflow的新手,并且正在使用 IRMAS 数据集与CNN合作解决仪器识别问题。
IRMAS 数据集具有清晰的 training 集合,仅基于单个乐器,因此每个乐器(吉他,小提琴等)都有自己的目录,其中包含多个样本( 多分类)。 testing (测试)集中包含和弦音乐,因此同一首歌中可能会出现多种乐器( multi-classification , multi-label )。
这是预测和损失计算的最后一部分,我使用的是Estimator API
。
### previous part of network is omitted
# logits layer
logits = tf.layers.dense(inputs=dropout, units=labels.shape[1])
print('Shape Logits:', logits.shape)
predictions = {
# "classes": tf.equal(tf.argmax(tf.cast(logits, tf.float32), 1), tf.argmax(tf.cast(labels, tf.float32), 1)),
"classes": tf.argmax(input=logits, axis=1),
"probabilities": tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(labels, tf.float32), logits=tf.cast(logits, tf.float32), name="sigmoid_tensor")
}
# correct_prediction = tf.equal(tf.argmax(tf.cast(logits, tf.float32), 1), tf.argmax(tf.cast(labels, tf.float32), 1))
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
#optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.96)
train_op = optimizer.minimize(
loss=loss,
global_step=tf.train.get_global_step())
logging_hook = tf.train.LoggingTensorHook({"loss" : loss}, every_n_iter=10)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks = [logging_hook])
eval_metric_ops = {
"accuracy": tf.metrics.accuracy(
labels=tf.argmax(input=labels, axis=1),
predictions=predictions["classes"]),
"recall": tf.metrics.recall(
labels=tf.argmax(input=labels, axis=1),
predictions=predictions["classes"]),
"precision": tf.metrics.precision(
labels=tf.argmax(input=labels, axis=1),
predictions=predictions["classes"])
}
经过几个时期之后,训练过程似乎开始了,因为损失的下降非常平稳。但是,评估模型所需的张量流指标似乎不一致。
这些是我在评估集中(训练集中的15%)的结果:
{'recall': 0.9989496, 'accuracy': 0.40656063, 'global_step': 1428, 'precision': 0.95004994, 'loss': 0.2432195}
这些是从测试集中获得的结果:
{'recall': 0.9981618, 'accuracy': 0.26530612, 'global_step': 1428, 'precision': 0.96533334, 'loss': 0.46097097}
在我看来,精度值似乎是唯一的对策,因为测试集的减少是由和弦音乐而非单一乐器组成的。我不明白为什么要得到这些结果以提高精确度和召回率,它们应该比我得到的准确度低。
我尝试显示FP,FN,TN和TP,它们似乎符合精度和召回率。实际上,通过计算它们的准确度,我得到了0.97左右(对于该分类问题而言过高)。
我不知道自己在做什么错,非常感谢您的帮助。
提前谢谢