如何避免在Tensorflow 2中为CTC损失模型定义目标张量?

时间:2020-04-26 13:43:09

标签: python tensorflow keras tensorflow2.0 ctc

我正在尝试使用tf.distribute.MirroredStrategy()在Tensorflow 2中针对具有CTC损失的模型进行多GPU训练。

问题是模型需要定义target_tensor才能进行编译。 这可能是什么原因? 有没有定义target_tensors的解决方法和编译模型吗?

如果我不通过目标,则会得到以下信息:

TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int32, int64

使用类似Keras的功能API定义模型:

model = Model(name ='Joined_Model_2',inputs=self.inp, outputs=[self.network.outp, self.network.outp_stt])

模型必须编译为:

self.model_joined.compile(optimizer=optimizer_stt,
            loss=losses,
            loss_weights= lossWeights,
            target_tensors=[target1, target2]                      
            )

该模型有2个输出,但是第二个使用的CTC损失导致了问题。

1 个答案:

答案 0 :(得分:0)

这是通过使用tf-nightly版本解决的。

Tf-nightly不允许在急切的执行模式下使用target_tensors。 在夜间版本中,我的模型可以在没有目标张量的情况下成功编译(实现没有任何变化),因此问题得以解决。