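Cannot load a Keras model from h5 because of a custom loss function

Time: 2019-11-13 03:00:07

Tags: python tensorflow keras

I have been trying to load a model that was trained with a custom loss function, but I cannot get it to load correctly. It currently fails with this error: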

ValueError: Shape must be rank 2 but is rank 1 for 'MatMul' (op: 'MatMul') with input shapes: [2048], [2048].

I know that I have to follow the approach from this question: Loading model with custom loss + keras
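For reference, the pattern from that question can be sketched roughly as below. This is a minimal sketch, not the asker's code: `custom_objects_for`, `load_with_custom_loss` and `make_loss` are hypothetical names, and it assumes Keras saved the loss under the function's `__name__` (which would explain why the key is `'loss'` for a closure named `loss`). Passing `compile=False` is an alternative when the model is only needed for inference, because the saved loss is then not restored at all.

```python
def custom_objects_for(loss_fn):
    """Map a loss function to the name it is assumed to be saved under."""
    return {loss_fn.__name__: loss_fn}


def load_with_custom_loss(path, loss_fn=None):
    """Load an HDF5 model (hypothetical helper, not part of Keras).

    With loss_fn=None the model is loaded uncompiled, which skips the
    custom loss entirely: enough for predict(), not for further training.
    """
    # Imported lazily so the helper above works without Keras installed.
    from keras.models import load_model

    if loss_fn is None:
        return load_model(path, compile=False)
    return load_model(path, custom_objects=custom_objects_for(loss_fn))


def make_loss(margin=0.3):
    # Stand-in factory: the inner function's name becomes the dictionary key.
    def loss(y_true, y_pred):
        raise NotImplementedError("placeholder for the real triplet loss")
    return loss


print(list(custom_objects_for(make_loss())))  # ['loss']
```

If only the embeddings are needed, `load_with_custom_loss(path)` with no loss argument avoids having to rebuild the loss at load time.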

This is what I am doing right now:

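# Imports assumed from usage; the original post did not show them.
# pairwise_distances and the _get_anchor_*_triplet_mask helpers come from the code linked below.
import numpy as np
import tensorflow as tf
from keras import backend as k
from keras.models import load_model

def batch_hard_triplet_loss(embeddings, labels, margin=0.3, squared=False):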
    # Get the pairwise distance matrix
    print("Lables" , labels.shape)
    print("Embeddings", embeddings.shape)
    pairwise_dist = pairwise_distances(embeddings, squared=squared)
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels)
    mask_anchor_positive = tf.to_float(mask_anchor_positive)
    anchor_positive_dist = tf.multiply(mask_anchor_positive, pairwise_dist)
    hardest_positive_dist = tf.reduce_max(anchor_positive_dist, axis=1, keepdims=True)
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels)
    mask_anchor_negative = tf.to_float(mask_anchor_negative)
    max_anchor_negative_dist = tf.reduce_max(pairwise_dist, axis=1, keepdims=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
    hardest_negative_dist = tf.reduce_min(anchor_negative_dist, axis=1, keepdims=True)

    def loss(y_true, y_pred):
        # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
        # triplet_loss = tf.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)
        # triplet_loss = tf.reduce_mean(triplet_loss)
        triplet_loss = k.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)
        triplet_loss = k.mean(triplet_loss)  # use keras mean
        return triplet_loss

    return loss

# Shapes printed while training the model:
# Lables (?, 1)
# Embeddings (?, 2048)
embeddings = np.zeros(2048,)
labels = np.zeros(1,)
model = load_model('SavedModels/TripleLossGTBOX_ep10_0.400572.h5', custom_objects={'loss': batch_hard_triplet_loss(embeddings, labels, margin=0.3, squared=False)})
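One likely reading of the error, offered as a sketch rather than a confirmed diagnosis: `np.zeros(2048,)` is a rank-1 array, while the shapes printed during training were rank 2 (`(?, 2048)` and `(?, 1)`). If `pairwise_distances` multiplies the embeddings by their transpose, as the linked implementation does, TensorFlow's `MatMul` rejects rank-1 inputs. NumPy alone is enough to show the shape difference; the batch size of 4 is arbitrary.

```python
import numpy as np

# Placeholder used in the question: a single vector, rank 1.
flat = np.zeros(2048,)
print(flat.shape, flat.ndim)      # (2048,) 1

# Shape seen during training: (batch, 2048), rank 2.
batch = np.zeros((4, 2048))
print(batch.shape, batch.ndim)    # (4, 2048) 2

# Dot products between all pairs of rows need a rank-2 input.
gram = batch @ batch.T
print(gram.shape)                 # (4, 4)

# NumPy quietly reduces the rank-1 case to a scalar; tf.matmul raises instead.
print(np.ndim(flat @ flat.T))     # 0
```

Rank-2 placeholders such as `np.zeros((1, 2048))` and `np.zeros((1, 1))` would match the training shapes, although a loss built from fixed placeholders would be constant and so of little use for further training.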

The full code for the triplet loss I am using can be found here; I made some modifications to batch_hard_triplet_loss, which is the version shown above. https://github.com/omoindrot/tensorflow-triplet-loss/blob/fc698369bb6c9acdc9f0e9e1ea00de0ddf782f12/model/triplet_loss.py
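Since Keras calls a loss as `loss(y_true, y_pred)`, the distance computations would normally need to happen inside that inner function, with `y_true` as the labels and `y_pred` as the embeddings, rather than in the outer factory. The NumPy sketch below mirrors the batch-hard steps of the code above to show that data flow. It is a reference illustration under assumptions (Euclidean distance, the hypothetical name `batch_hard_triplet_loss_np`), not a drop-in Keras loss and not the linked repository's code.

```python
import numpy as np


def batch_hard_triplet_loss_np(labels, embeddings, margin=0.3):
    """Batch-hard triplet loss on NumPy arrays (Euclidean distances)."""
    labels = np.asarray(labels).reshape(-1)          # accepts (batch,) or (batch, 1)
    embeddings = np.asarray(embeddings, dtype=float)

    # Pairwise Euclidean distance matrix, shape (batch, batch).
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    same = labels[:, None] == labels[None, :]
    pos_mask = same & ~np.eye(len(labels), dtype=bool)  # same label, different index
    neg_mask = ~same                                    # different label

    # Hardest positive: largest distance to a sample with the same label.
    hardest_pos = (dist * pos_mask).max(axis=1)

    # Hardest negative: smallest distance to a different label. Entries that
    # are not negatives are raised by the row maximum so they cannot win the min.
    row_max = dist.max(axis=1, keepdims=True)
    hardest_neg = (dist + row_max * (~neg_mask)).min(axis=1)

    return np.maximum(hardest_pos - hardest_neg + margin, 0.0).mean()


# Four 1-D embeddings, two classes.
print(batch_hard_triplet_loss_np([0, 0, 1, 1], [[0], [2], [3], [6]], margin=0.5))  # 1.0
```

In the example, the per-anchor terms are 0, 1.5, 2.5 and 0, so the mean is 1.0. Everything is derived from the two arguments at call time, which is the property a `loss(y_true, y_pred)` closure would need as well.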

0 answers:

No answers yet.