TL; DR:如何将每个实例2个标签的2D二进制张量拆分为每个实例仅1个标签的2个张量,如下图所示:
作为自定义损失函数的一部分,我试图将每个实例2个标签的多标签y张量拆分为每个实例1个标签的2个y张量。 当我在1D y张量上执行此代码时,此代码非常有用:
y_true = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 0.])
label_cls = tf.where(tf.equal(y_true, 1.))
idx1, idx2 = tf.split(label_cls,2)
raplace = tf.constant([1.])
y_true_1 = tf.scatter_nd(tf.cast(idx1, dtype=tf.int32), raplace, [tf.size(y_true)])
y_true_2 = tf.scatter_nd(tf.cast(idx2, dtype=tf.int32), raplace, [tf.size(y_true)])
with tf.Session() as sess:
print(sess.run([y_true_1,y_true_2]))
我得到:
[array([1., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), array([0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)]
但是当我在训练中使用批处理时,会出现此错误:
Invalid argument: Outer dimensions of indices and update must match.
由于我的“ y张量”是2D而不是1D,因此在这种情况下idx1, idx2
(索引)不正确,replace
的形状(更新)也不正确。
据我了解,tf.scatter_nd
只能更新变量的第一个维度,那么如何解决呢?以及如何获取所需的索引?
答案 0 :(得分:0)
我觉得你走的是曲折的道路。这是我的解决方案。感觉比尝试的方法更简单(尝试tf 1.14)。
import tensorflow as tf
y_true = tf.constant([[1, 0, 1, 0],[0, 1, 1, 0]])
_, label_inds = tf.math.top_k(y_true, k=2)
idx1, idx2 = tf.split(label_inds,2, axis=1)
y_true_1 = tf.one_hot(idx1, depth=4)
y_true_2 = tf.one_hot(idx2, depth=4)
with tf.Session() as sess:
print(sess.run([y_true_1, y_true_2]))
因此,您的想法是获取每行的前2个标签的索引。然后使用tf.split
将其分为2列。然后使用one_hot
将这些索引转换回onehot向量。