如何在tensorflow 2中使用未知参数的ctc损失?

时间:2020-08-15 11:57:34

标签: tensorflow tensorflow2.0 tensorflow-datasets loss-function ctc

我在TFRecord中有一个数据集,该数据集分为多个文件。

我的记录如下:

sequence_features = {
    'frames': tf.io.FixedLenSequenceFeature([], dtype=tf.string),
    "label": tf.io.FixedLenSequenceFeature([], dtype=tf.int64),
}

context_features = {
    "frames_count": tf.io.FixedLenFeature([], dtype=tf.int64),
    "num_tokens":  tf.io.FixedLenFeature([], dtype=tf.int64),
}

frameslabel序列的长度不同。 帧的形状为BxTxHxWxC张量,其中B是批处理大小,T是批处理中最大的帧数,H是图像高度(112),{ {1}}是图片宽度(112),W是频道(1)

C的形状为label,其中BxS是批次中最长的标签。

我正在S上建立我的模型,并像这样编译它:

tf.keras.Sequential()

问题model.compile( loss=tf.nn.ctc_loss(), optimizer=tf.keras.Adam(), ) 需要4个参数-tf.nn.ctc_loss()labelslogitslabel_length。但是在模型编译时我还不知道,因为它在批次之间以及每个时期都不同。

在整个数据集中将帧计数或标签填充或裁剪为相同的长度并不是真正的选择,因为最长的序列可能比最短的序列长20倍。

我正在这样读取数据集:

logit_length

在这种情况下我该怎么办?或者如何更改代码以利用ctc损失?

0 个答案:

没有答案