How do I build an SVM in TensorFlow?

Asked: 2019-08-18 09:01:30

Tags: tensorflow, svm

I am trying to use an SVM in TensorFlow, and while working on the cross-entropy part of my loss function I found this SVM code. Does this code actually work for an SVM?

import tensorflow as tf          # written against the TF 1.x API (tf.GraphKeys, collections)
from functools import reduce     # reduce() must be imported explicitly in Python 3

def calculate_loss(logits_list, labels_list, regularizer_rate):
    weight_coefficients = [1]

    # Earlier version, using ordinary softmax cross-entropy instead:
    # xent = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)
    # xent_mean = tf.reduce_mean(xent)

    xent_mean_list = []
    for i, (logits, labels, weight) in enumerate(zip(logits_list, labels_list, weight_coefficients)):
        if i == 0:
            # Squared hinge term: max(0, 1 - y * f(x))^2, assuming labels in {-1, +1}
            cross_entropy = tf.square(tf.maximum(tf.zeros_like(logits), 1 - logits * labels))
            # Note: this penalizes the logits themselves, not the weight vector
            xent_mean = tf.reduce_mean(tf.square(logits)) + tf.reduce_mean(cross_entropy)
            xent_mean_list.append(weight * xent_mean)
    if regularizer_rate != 0:
        loss = reduce(tf.add, xent_mean_list) + tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    else:
        loss = reduce(tf.add, xent_mean_list)
    return loss

The loss function does work, but I don't know whether this is really an SVM loss or just a common cross-entropy loss without the SVM part. Sorry, my math is not good.
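For comparison, the standard linear L2-SVM objective is the squared hinge loss plus an L2 penalty on the weight vector (not on the logits). A minimal NumPy sketch, with illustrative names and a toy dataset chosen so the hinge term vanishes:

```python
import numpy as np

def l2_svm_loss(w, b, X, y, C=1.0):
    """L2-SVM objective: 0.5 * ||w||^2 + C * mean(max(0, 1 - y * f(x))^2).
    Labels y must be in {-1, +1}."""
    scores = X @ w + b                               # decision function f(x) = w.x + b
    hinge = np.maximum(0.0, 1.0 - y * scores)        # hinge is 0 for points beyond the margin
    return 0.5 * np.dot(w, w) + C * np.mean(hinge ** 2)

# Two perfectly separated points: the hinge term is zero,
# so only the 0.5 * ||w||^2 = 0.5 margin term remains.
X = np.array([[2.0], [-2.0]])
y = np.array([1.0, -1.0])
w = np.array([1.0])
print(l2_svm_loss(w, 0.0, X, y))  # → 0.5
```

The key difference from the code above is which quantity the quadratic penalty is applied to: the SVM's margin-maximization term is the norm of the weights, which in TF 1.x would normally come from the REGULARIZATION_LOSSES collection rather than from squaring the logits.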

0 Answers:

There are no answers yet.