我正在使用Tensorflow JS实现手写识别模型。 我的模型是RNN模型(开始时带有卷积层),并通过自定义损失进行了优化,如果在tensorflow.js库中可用的话,这将是ctc损失。
我的自定义损失相当大,它只是在真实标签和预测标签中都找到分隔符(这是ctc空白类),使用它们使预测和真实对齐,并返回相关部分的交叉熵。损失值似乎也下降得太快。
我的训练数据很少,这确实使模型最终过拟合。
我正在用句子“快速的棕色狐狸跳过懒狗”测试我的模型,从损失函数中我可以打印的内容来看,这有点“有效”:
例如, -q-u-i-c-k
被识别为qu-uu-uk
。
-t-h-e
将被视为thehehehehehehe
。
该模型似乎与我的分隔符不符,结果差异很大。
这是我的模特摘要:
_________________________________________________________________
Layer (type) Output shape Param #
=================================================================
conv1d_Conv1D1 (Conv1D) [null,48,100] 1900
_________________________________________________________________
max_pooling1d_MaxPooling1D1 [null,24,100] 0
_________________________________________________________________
bidirectional_Bidirectional1 [null,64] 84480
_________________________________________________________________
dropout_Dropout1 (Dropout) [null,64] 0
_________________________________________________________________
repeat_vector_RepeatVector1 [null,50,64] 0
_________________________________________________________________
bidirectional_Bidirectional2 [null,64] 66048
_________________________________________________________________
dropout_Dropout2 (Dropout) [null,64] 0
_________________________________________________________________
repeat_vector_RepeatVector2 [null,50,64] 0
_________________________________________________________________
bidirectional_Bidirectional3 [null,50,64] 66048
_________________________________________________________________
dropout_Dropout3 (Dropout) [null,50,64] 0
_________________________________________________________________
time_distributed_TimeDistrib [null,50,27] 1755
_________________________________________________________________
activation_Activation1 (Acti [null,50,27] 0
=================================================================
Total params: 220231
Trainable params: 220231
Non-trainable params: 0
_________________________________________________________________
conv1d层具有ReLU激活,最后一层是Softmax。 双向层是具有glorotNormal作为循环初始化程序的LSTM。 分布的时间是一个密集层。 优化器是亚当。
我正在努力改善结果。
损失函数显然需要一些工作,但是我不知道该去哪里。您对此有何建议?
我的模型还有其他需要改进的地方吗?