如何在Tensorflow中为CTC丢失生成/读取稀疏序列标签?

时间:2017-03-03 12:08:57

标签: python tensorflow recurrent-neural-network

从包含转录的单词图片列表中,我尝试使用@Id @GeneratedValue(strategy = GenerationType.IDENTITY) @Column(name="TXN_ID") private int txnId; public TxnCustomer() { } public int getTxnId() { return this.txnId; } public void setTxnId(int txnId) { this.txnId = txnId; } 创建和阅读稀疏序列标签(tf.nn.ctc_loss),避免

  1. 将预打包的培训数据序列化到磁盘中 tf.train.slice_input_producer格式

  2. TFRecord

  3. 的明显局限性
  4. 任何不必要或过早的填充,

  5. 将整个数据集读取到RAM。

  6. 主要问题似乎是将字符串转换为tf.py_func所需的标签序列(SparseTensor)。

    例如,对于(有序)范围tf.nn.ctc_loss中的字符集,我希望将文本标签字符串[A-Z]转换为序列标签类列表"BAD"

    我想要阅读的每个示例图像都包含文本作为文件名的一部分,因此可以直接提取并直接进行python转换。 (如果有办法在TensorFlow计算中做到这一点,我还没有找到它。)

    之前的几个问题都是关注这些问题,但我还没有能够成功地整合它们。例如,

    有没有办法整合这些方法?

    另一个例子(问题#38012743)显示了如何延迟从字符串到列表的转换,直到将文件名出列用于解码,但它依赖于tf.train.Example,这有一些警告。 (我应该担心他们吗?)

    我认识到" SparseTensors不能很好地排队等等#34; (根据tf文档),所以可能有必要在批处理之前对结果(序列化?)做一些伏都教,甚至在计算发生的地方进行返工;我对此持开放态度。

    遵循MarvMind的大纲,这是一个基本框架,包含我想要的计算(迭代包含示例文件名的行,提取每个标签字符串并转换为序列),但我还没有成功确定" Tensorflow"这样做的方法。

    感谢您提供正确的"调整",这是针对我的目标的更合适的策略,或者表明tf.py_func不会破坏培训效率或其他下游(例如,加载)训练有素的模型以供将来使用)。

    编辑(+7小时)我发现了缺少操作的补丁。虽然仍然需要验证这与下游的CTC_Loss有关,但我已检查下面编辑的版本是否正确批量并读入图像和稀疏张量。

    tf.py_func

1 个答案:

答案 0 :(得分:1)

关键想法似乎是从所需数据中创建SparseTensorValue,将其传递给tf.convert_to_tensor_or_sparse_tensor,然后(如果要批量处理数据)将其序列化为tf.serialize_sparse。批处理后,您可以使用tf.deserialize_many_sparse恢复值。

这是大纲。创建稀疏值,转换为张量和序列化:

indices = [[i] for i in range(0,len(text))]
values = [out_charset.index(c) for c in list(text)]
shape = [len(text)]
label = tf.SparseTensorValue(indices,values,shape)
label = tf.convert_to_tensor_or_sparse_tensor(label)
label = tf.serialize_sparse(label) # needed for batching

然后,您可以进行批处理和反序列化:

image,label = tf.train.batch([image,label],dynamic_pad=True)
label = tf.deserialize_many_sparse(label,tf.int32)