多标签分类:如何学习阈值?

时间:2017-05-30 15:11:58

标签: python tensorflow deep-learning

我有一个很深的CNN,适用于多类分类。我想"升级"挑战并在多标签分类问题上进行训练。

为此,我用sigmoid替换了我的softmax并尝试训练我的网络以最小化:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred)

但我最终得到了奇怪的预测:

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646

0.38991812 0.54468459 0.34593087 0.82790571]

Prediction for Im1 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177

0.35400775 0.36479294 0.30002621 0.84438241]

Prediction for Im1 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551

0.43719986 0.54638696 0.20344526 0.88144571]

所以我想我尝试让我的网络学习每个班级的门槛,以确定样本是否属于班级。

所以我把它添加到我的代码中:

initial = tf.truncated_normal([numberOfClasses], stddev=0.1)
W_thresh = tf.Variable(initial)

y_predict_thresh = int(y_predict > W_thresh)

但我有一个错误:

TypeError: int() argument must be a string or a number, not 'Tensor'.

任何人都有任何想法可以帮助我继续前进(如何避免这个错误?,我的数据集是否真的不平衡的事实导致了这种"常数"预测?多标签分类的其他建议? ,...)?

谢谢

编辑:我刚刚意识到做阈值处理对于反向传播可能不是很酷:/

1 个答案:

答案 0 :(得分:2)

不知道你是否还需要它,但你可以使用张量流转换函数tf.to_int32, tf.to_int64。在评估之前,你的表达式是python的一个对象,因此它不能简单地将它转换为int()

这可以满足您的需求:

with tf.Session() as sess:
    check = sess.run([tf.to_int64(W1 > W2)])