tf.one_hot()是否支持SparseTensor作为索引参数?

时间:2017-07-18 06:30:51

标签: tensorflow tflearn

我想问一下tf.one_hot()函数是否支持SparseTensor作为"索引"参数。我想进行多标签分类(每个示例都有多个标签),这需要计算交叉熵损失。

我尝试直接将SparseTensor放入" indices"参数,但它引发以下错误:

TypeError:无法将类型对象转换为Tensor。内容:SparseTensor(indices = Tensor(" read_batch_features / fifo_queue_Dequeue:106",shape =(?,2),dtype = int64,device = / job:worker),values = Tensor(" string_to_index_Lookup :0",shape =(?,),dtype = int64,device = / job:worker),dense_shape = Tensor(" read_batch_features / fifo_queue_Dequeue:108",shape =(2,),dtype = int64,device = / job:worker))。考虑将元素转换为支持的类型。

关于可能原因的任何建议?

感谢。

2 个答案:

答案 0 :(得分:1)

您可以从最初的SparseTensor构建另一个形状为(batch_size, num_classes)的SparseTensor。例如,如果将类保留在单个字符串功能列(以空格分隔)中,则可以使用以下命令:

import tensorflow as tf

all_classes = ["class1", "class2", "class3"]
classes_column = ["class1 class3", "class1 class2", "class2", "class3"]

table = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(all_classes)
)
classes = tf.constant(classes_column)
classes = tf.string_split(classes)
idx = table.lookup(classes) # SparseTensor of shape (4, 2), because each of the 4 rows has at most 2 classes
num_items = tf.cast(tf.shape(idx)[0], tf.int64) # num items in batch
num_entries = tf.shape(idx.indices)[0] # num nonzero entries

y = tf.SparseTensor(
    indices=tf.stack([idx.indices[:, 0], idx.values], axis=1),
    values=tf.ones(shape=(num_entries,), dtype=tf.int32),
    dense_shape=(num_items, len(all_classes)),
)
y = tf.sparse_tensor_to_dense(y, validate_indices=False)

with tf.Session() as sess:
    tf.tables_initializer().run()
    print(sess.run(y))

    # Outputs: 
    # [[1 0 1]
    #  [1 1 0]
    #  [0 1 0]
    #  [0 0 1]]

此处idx是SparseTensor。其索引idx.indices[:, 0]的第一列包含批次的行号,其值idx.values包含相关类ID的索引。我们将这两者结合起来创建新的y.indices

要全面实施多标签分类,请参阅https://stackoverflow.com/a/47671503/507062

的“选项2”

答案 1 :(得分:0)

one_hot不支持SparseTensor作为indices参数。您可以将稀疏张量的索引/值张量作为索引参数传递,这可能会解决您的问题。