tf.SparseTensor和tf.nn.ctc_loss

时间:2017-10-09 20:08:20

标签: tensorflow

我想重新构建tf.nn.ctc_loss中第一个参数的标签。

标签存储在labels的张量shape=[batch_size x max_time]中,由于第二个维度已用零填充,因此标签的真实长度存储在另一个张量labels_lengthshape=[batch_size]

我不清楚labels的{​​{1}}参数应该是什么样子。我正在读它应该是tf.nn.ctc_lossSparseTensorindicesvalues的正确形状和内容是什么?

1 个答案:

答案 0 :(得分:2)

如果我理解正确,您当前的输入labels如下所示:

[[4, 3, 1, 2, 5],
 [2, 3, 4, 1, 0],
 [1, 2, 3, 0, 0],
 [5, 4, 0, 0, 0]]

我们有batch_size=4max_time=5labels_length=[5,4,3,2]

如果是这种情况,您可以使用here所述的相同方法将其转换为SparseTensor。因此,只要您始终用零填充,就不需要使用labeles_length

import tensorflow as tf
labels = tf.Variable([[4, 3, 1, 2, 5],
                      [2, 3, 4, 1, 0],
                      [1, 2, 3, 0, 0],
                      [5, 4, 0, 0, 0]], tf.int32)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  idx = tf.where(tf.not_equal(labels, 0))
  sparse = tf.SparseTensor(idx, tf.gather_nd(labels, idx), labels.get_shape())
  s = sess.run(sparse)
  print s.indices
  print s.values
  print s.dense_shape

> [[0 0]
   [0 1]
   [0 2]
   [0 3]
   [0 4]
   [1 0]
   [1 1]
   [1 2]
   [1 3]
   [2 0]
   [2 1]
   [2 2]
   [3 0]
   [3 1]]
> [4 3 1 2 5 2 3 4 1 1 2 3 5 4]
> [4 5]

为了更好地理解稀疏张量(实际上它不是张量,而是围绕三个张量的包装),请参阅documentation