如何在TensorFlow中将字符串标签转换为单热矢量?

时间:2017-04-25 06:01:49

标签: python machine-learning tensorflow

我是TensorFlow的新手,想要读取逗号分隔值(csv)文件,其中包含2列,第1列是索引,第2列是标签字符串。我有以下代码逐行读取csv文件中的行,我能够使用print语句正确获取csv文件中的数据。但是,我想从字符串标签进行单热编码转换,而不是如何在TensorFlow中进行。最终目标是使用tf.train.batch()函数,这样我就可以获得批量的单热标签向量来训练神经网络。

正如您在下面的代码中所看到的,我可以在TensorFlow会话中手动为每个标签条目创建一个热矢量。但是我如何使用tf.train.batch()函数?如果我移动线

label_batch = tf.train.batch([col2], batch_size=5)

进入TensorFlow会话块(用label_one_hot替换col2),程序阻塞什么都不做。我试图在TensorFlow会话之外移动单热矢量转换,但我没能使它正常工作。这样做的正确方法是什么?请帮忙。

label_files = []
label_files.append(LABEL_FILE)
print "label_files: ", label_files

filename_queue = tf.train.string_input_producer(label_files)

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print "key:", key, ", value:", value

record_defaults = [['default_id'], ['default_label']]
col1, col2 = tf.decode_csv(value, record_defaults=record_defaults)

num_lines = sum(1 for line in open(LABEL_FILE))

label_batch = tf.train.batch([col2], batch_size=5)

with tf.Session() as sess:
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coordinator)

    for i in range(100):
        column1, column2 = sess.run([col1, col2])

        index = 0
        if column2 == 'airplane':
            index = 0
        elif column2 == 'automobile':
            index = 1
        elif column2 == 'bird':
            index = 2
        elif column2 == 'cat':
            index = 3
        elif column2 == 'deer':
            index = 4
        elif column2 == 'dog':
            index = 5
        elif column2 == 'frog':
            index = 6
        elif column2 == 'horse':
            index = 7
        elif column2 == 'ship':
            index = 8
        elif column2 == 'truck':
            index = 9

        label_one_hot = tf.one_hot([index], 10)  # depth=10 for 10 categories
        print "column1:", column1, ", column2:", column2
        # print "onehot label:", sess.run([label_one_hot])

    print sess.run(label_batch)

    coordinator.request_stop()
    coordinator.join(threads)

2 个答案:

答案 0 :(得分:2)

您可能想尝试将index变量反馈到占位符,然后通过tf.one_hot将其转换为单热矢量?这些方面的东西:

lbl = tf.placeholder(tf.uint8, [YOUR_BATCH_SIZE])
lbl_one_hot = tf.one_hot(lbl, YOUR_VOCAB_SIZE, 1.0, 0.0)
lb_h = sess.run([lbl_one_hot], feed_dict={lbl: index})

不确定您是否批量做事,所以如果不是你的情况,YOUR_BATCH_SIZE可能无关紧要。您也可以使用numpy.zeros来完成它,但我发现上面更清洁,更容易,特别是批处理。

答案 1 :(得分:0)

问这个问题已有2年多了,但这个答案可能对某些人仍然有用。这是在TF中将字符串标签转换为一键矢量的一种简单方法:

import tensorflow as tf

vocab = ['a', 'b', 'c']

input = tf.placeholder(dtype=tf.string, shape=(None,))
matches = tf.stack([tf.equal(input, s) for s in vocab], axis=-1)
onehot = tf.cast(matches, tf.float32)

with tf.Session() as sess:
    out = sess.run(onehot, feed_dict={input: ['c', 'a']})
    print(out) # prints [[0. 0. 1.]
               #         [1. 0. 0.]]