我试图用WebAssembly二进制文件中抽象的自定义函数覆盖keras / tf2.0损失函数。这是相关的代码。
@tf.function
def custom_loss(y_true, y_pred):
return tf.constant(instance.exports.loss(y_true, y_pred))
我正在用这种方式
model.compile(optimizer_obj, loss=custom_loss)
# and then model.fit(...)
我不确定tf2.0渴望执行的工作方式,因此对此有任何见解会很有用。
我认为instance.exports.loss函数与该错误无关,但是,如果您确定其他一切都没问题,请告诉我,我将分享其他详细信息。
这是一个堆栈跟踪和实际错误: https://pastebin.com/6YtH75ay
答案 0 :(得分:1)
首先,您不需要使用@tf.function
来定义自定义损失。
我们可以很高兴(如果毫无意义)做这样的事情:
def custom_loss(y_true, y_pred):
return tf.reduce_mean(y_pred)
model.compile(optimizer='adam',
loss=custom_loss,
metrics=['accuracy'])
只要我们在custom_loss
内部使用的所有操作都可以通过张量流区分
因此您可以删除@tf.function
装饰器,但是我怀疑您随后会遇到类似以下的错误消息:
[Some trace info] - Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
因为张量流无法在webassembly二进制文件中找到函数的梯度。损失函数中的所有内容都必须是张量流可以理解和计算的梯度,否则它将无法针对较低的损失值进行优化。
也许最好的方法是使用tensorflow可以计算梯度的操作而不是直接引用它来复制instance.exports.loss
内部的功能?