在Keras模型中使用tf.contrib.seq2seq.sequence_loss

时间:2018-06-18 18:28:24

标签: python tensorflow keras

我想在我的Keras模型中使用sequence_loss函数。但是我在创建重量函数时遇到了麻烦。

def tensorflow_loss(y_pred,y_true):
    weights = [tf.ones(BATCH_SIZE,tf.float32) for _ in y_pred]
    return tf.contrib.seq2seq.sequence_loss(y_pred,y_true,weights) 

这是我目前的尝试,我从另一个堆栈溢出答案,但我得到一个" Tensor对象不可迭代"错误。我该怎么做呢?

0 个答案:

没有答案