Numba可以用来编译与Tensorflow接口的Python代码吗?即Tensorflow宇宙之外的计算将与Numba一起运行以获得速度。我还没有找到任何有关如何执行此操作的资源。
答案 0 :(得分:1)
我知道这并不能直接回答您的问题,但这可能是一个不错的选择。 Numba正在使用即时(JIT)编译。因此,您可以按照TensorFlow官方文档here上有关如何在TensorFlow中使用JIT(但在Numba生态系统中不使用)的说明进行操作。
答案 1 :(得分:1)
您可以使用tf.numpy_function或tf.py_func包装python函数并将其用作TensorFlow op。这是我使用的示例:
@jit
def dice_coeff_nb(y_true, y_pred):
"Calculates dice coefficient"
smooth = np.float32(1)
y_true_f = np.reshape(y_true, [-1])
y_pred_f = np.reshape(y_pred, [-1])
intersection = np.sum(y_true_f * y_pred_f)
score = (2. * intersection + smooth) / (np.sum(y_true_f) +
np.sum(y_pred_f) + smooth)
return score
@jit
def dice_loss_nb(y_true, y_pred):
"Calculates dice loss"
loss = 1 - dice_coeff_nb(y_true, y_pred)
return loss
def bce_dice_loss_nb(y_true, y_pred):
"Adds dice_loss to categorical_crossentropy"
loss = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \
tf.keras.losses.categorical_crossentropy(y_true, y_pred)
return loss
然后我在训练tf.keras模型时使用了此损失函数:
...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)