如何使用特定索引将数组张量转换为one_hot?

时间:2017-10-19 06:29:30

标签: tensorflow

a = tf.constant([20, 1, 5, 3, 123, 4])

我希望将其转换为tensor([0,0,0,1,0,0,0])(index = 3)

我怎样才能轻松做到这一点?

我真正想做的是这样的:有一个深度神经网络,有5个输出节点(用于分类)。假设一个前馈传播的输出是[5,22,3,4,11](类型tensor)。在这个前馈中,标签是1.所以,我需要打开这个索引的值并关闭其他像这样:[5,0,0,0,0]。最后,需要将值更改为1:[1,0,0,0,0]并在网络中反向传播(渐变)此张量。

2 个答案:

答案 0 :(得分:0)

您正在寻找的不是单热编码。也许这就是你想要实现的目标:

a = tf.constant([20, 1, 5, 3, 123, 4])
c = tf.cast(tf.equal(a, 3), tf.int32)    # 3 is your matching element
with tf.Session() as sess:
    print(c.eval())

# [0 0 0 1 0 0]

修改

如果您已经了解索引,可以通过多种方式完成此操作。 如果张量中的值有可能重复,则可以执行以下操作:

a = tf.constant([20, 1, 5, 3, 123, 4, 3])
c = tf.cast(tf.equal(a, a[3]), tf.int32)
with tf.Session() as sess:
    print(c.eval())
# [0 0 0 1 0 0 1]

但是如果你确定不重复这些值,你可以在这样的numpy数组的帮助下构造这个张量:

import numpy as np

c = np.zeros((7), np.int32)
c[3] = 1
c_tensor = tf.constant(c)
with tf.Session() as sess:
    print(c_tensor.eval())
# [0 0 0 1 0 0 0]

编辑2

基于新编辑的问题,为了进行分类任务,并且因为在我看来你没有进行自定义反向传播,让我给你一个你正在寻找的部分的骨架代码。

tf.reset_default_graph()

X = tf.placeholder(tf.float32, (None, 224, 224, 3))
y = tf.placeholder(tf.int32, (None))
one_hot_y = tf.one_hot(y, n_outputs)   # Generate one-hot vector

logits = My_Network(X)   # This function returns your network.
cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, one_hot_y)) 
  # This function will compute softmax and get the loss function which you
  # would like to minimize.

optimizer = tf.train.AdamOptimizer(learning_rate = 0.01).minimize(cross_entropy)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for **each epoch**:
       for **generate batches of your data**:
            sess.run(optimizer, feed_dict = {X: batch_x, y: batch_y})

请花点时间了解代码。我还建议您按照分类任务的一些教程进行操作,因为它们非常容易获得。我建议你CNN by TensorFlow

答案 1 :(得分:0)

这段代码应该这样做。它使用Numpy:

import numpy as np
def one_hot(y):
  y = y.reshape(len(y))
  n_values = int(np.max(y)) + 1
  return tf.convert_to_tensor(np.eye(n_values)[np.array(y, dtype=np.int32)])

我不确定这是否是你需要的,但我希望它有所帮助。 例如:

>>> print(one_hot(np.array([2,3,4])))
>>> [[ 0.  0.  1.  0.  0.]
     [ 0.  0.  0.  1.  0.]
     [ 0.  0.  0.  0.  1.]]