自身损失函数,tf.edit_distance,modle.compile错误

时间:2019-06-14 12:51:49

标签: python tensorflow keras loss-function loss

我正在尝试为端到端文字处理网络(Keras +张量流)实现我自己的丢失功能。目的是通过WER(字错误率)来表示损失。 第一步是确定编辑距离并将其用作损失。为此,经过长时间的搜索,我找到了tf.edit_distance(我没有得到一个自我开发的代码来运行[python或keras或tf都没有正常的Python])。 经过一番尝试和错误,以及一些研究之后,我没有直接从损失函数中得到任何错误消息,但是现在有了一个我无法执行任何操作的新消息。

我在python 2 GPU运行时上,在Google Colab上将Keras 2.0.4与tf 2.0a一起使用。

这是我的损失函数的定义。

def create_sparse(dense):
  zero = K.tf.constant(0, dtype=K.tf.float32)
  where = K.tf.not_equal(dense, zero)
  indices = K.tf.where(where)
  values = K.tf.gather_nd(dense, indices)
  shape=K.tf.cast(K.tf.shape(dense), K.tf.int64)
  sparse = K.tf.SparseTensor(indices, values, shape)
  return sparse

def ed(y_true,y_pred):
  loss = K.tf.edit_distance(create_sparse(y_pred), create_sparse(y_true), normalize=True)
  # print(loss)
  return loss

def editDistanceLoss():
  def myHelperLoss(y_true, y_pred):
    return ed(y_true, y_pred)
  return myHelperLoss

我这样使用它:

loss_editDistance = editDistanceLoss()
...
model.compile(loss=loss_editDistance, optimizer=optimizer, metrics=['accuracy'])

当我开始编译时,出现以下错误消息:

TypeErrorTraceback (most recent call last)
<ipython-input-98-5787fc39817d> in <module>()
      2 # opt = Adam(lr=LR)  # keep calm and reduce learning rate
      3 # model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
----> 4 model.compile(loss=loss_editDistance, optimizer=optimizer, metrics=['accuracy'])
      5 # model.compile(loss=loss_WER, optimizer=optimizer, metrics=['accuracy'])

1 frames
/usr/local/lib/python2.7/dist-packages/keras/engine/training_utils.pyc in weighted(y_true, y_pred, weights, mask)
    418             weight_ndim = K.ndim(weights)
    419             score_array = K.mean(score_array,
--> 420                                  axis=list(range(weight_ndim, ndim)))
    421             score_array *= weights
    422             score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))

TypeError: range() integer end argument expected, got NoneType.

如果我在功能ed中添加print(loss),则会得到以下输出Tensor("loss_21/activation_1_loss/edit_distance:0", dtype=float32)

我不明白为什么会收到此错误消息。 也许有人可以给我提示我做错了什么。

0 个答案:

没有答案