如何获得单热矢量的密集表示

时间:2016-10-10 18:20:59

标签: tensorflow

假设Tensor包含:

[[0 0 1]
 [0 1 0]
 [1 0 0]]

如何以本机方式获取密集表示(不使用numpy或迭代)?

[2,1,0]

tf.one_hot()做反向,还有tf.sparse_to_dense()似乎这样做但我无法弄清楚如何使用它。

3 个答案:

答案 0 :(得分:12)

tf.argmax(x, axis=1)应该完成这项工作。

答案 1 :(得分:9)

vec = tf.constant([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
locations = tf.where(tf.equal(vec, 1))
# This gives array of locations of "1" indices below
# => [[0, 2], [1, 1], [2, 0]])

# strip first column
indices = locations[:,1]
sess = tf.Session()
print(sess.run(indices))
# => [2 1 0]

答案 2 :(得分:2)

TensorFlow没有本地密集到稀疏的转换函数/帮助器。鉴于输入数组是一个密集张量,例如你提供的那个,你可以定义一个函数来将一个密集张量转换为一个稀疏张量。

def dense_to_sparse(dense_tensor):
    where_dense_non_zero = tf.where(tf.not_equal(dense_tensor, 0))
    indices = where_dense_non_zero
    values = tf.gather_nd(dense_tensor, where_dense_non_zero)
    shape = dense_tensor.get_shape()

    return tf.SparseTensor(
        indices=indices,
        values=values,
        shape=shape
    )

此辅助函数查找Tensor非零的索引和值,并输出具有这些索引和值的稀疏张量。此外,形状被有效复制。

您不想使用tf.sparse_to_dense,因为这会给您相反的表示。如果您希望输出为[2, 1, 0],则需要索引索引。首先,您需要数组不是0的索引

indices = tf.where(tf.not_equal(dense_tensor, 0))

然后,您需要使用切片/指示来访问张量:

output = indices[:, 1]

您可能会注意到上面切片中的1等于张量的维度 - 1.因此,要使这些值具有通用性,您可以执行以下操作:

output = indices[:, len(dense_tensor.get_shape()) - 1]

虽然我不确定您对这些值(值所在的列的值)的处理方式。希望这有帮助!

编辑:如果您正在寻找输入张量(如果为1)的索引/位置,Yaroslav的答案会更好;如果需要,它不会对具有非1/0值的张量进行扩展。