如何“一个热编码” Tensorflow数据集?

时间:2018-11-30 15:36:42

标签: python tensorflow one-hot-encoding

Newby在这里...我按如下方式加载了TF数据集:

dataset = tf.data.TFRecordDataset(files)
dataset.map(extract_fn)

数据集包含带有某些值的“字符串列”,我想对其进行“一次热”编码。如果我有索引和深度(到目前为止,我只有一个String值),则可以按记录在extract_fn记录中执行此操作。但是,有没有TF函数可以帮我做到这一点?即

  • 计算不同值的数量
  • 将每个值映射到索引
  • 为此创建一个热编码列

1 个答案:

答案 0 :(得分:0)

我认为这可以满足您的要求

import tensorflow as tf

def one_hot_any(a):
    # Save original shape
    s = tf.shape(a)
    # Find unique values
    values, idx = tf.unique(tf.reshape(a, [-1]))
    # One-hot encoding
    n = tf.size(values)
    a_1h_flat = tf.one_hot(idx, n)
    # Reshape to original shape
    a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
    return a_1h, values

# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
with tf.Session() as sess:
    print(*sess.run([x_1h, x_vals]), sep='\n')

输出:

[[[1. 0. 0. 0.]
  [0. 1. 0. 0.]]

 [[1. 0. 0. 0.]
  [0. 0. 1. 0.]]

 [[0. 0. 0. 1.]
  [0. 0. 1. 0.]]

 [[0. 1. 0. 0.]
  [0. 0. 1. 0.]]]
[b'a' b'b' b'd' b'c']

但是,问题在于,不同的输入将产生不一致的输出,具有不同的价值顺序甚至是不同的“一火”深度,因此我不确定它是否真的有用。