我尝试在 TensorFlow 中使用 SVM，在计算损失函数中的交叉熵时找到了下面这段 SVM 代码。这段代码适用于 SVM 吗？
def calculate_loss(logits_list, labels_list, regularizer_rate):
    """Build the training loss as a weighted squared-hinge (L2-SVM style) loss.

    For every head that has an entry in ``weight_coefficients`` the per-head
    loss is ``mean(logits ** 2) + mean(max(0, 1 - logits * labels) ** 2)``.
    This is a squared hinge loss; it is not softmax cross-entropy. (An earlier
    variant of this function used
    ``tf.nn.softmax_cross_entropy_with_logits_v2`` in place of the hinge term.)

    Args:
        logits_list: Sequence of raw score tensors, one per output head.
        labels_list: Sequence of label tensors aligned with ``logits_list``.
            NOTE(review): the margin ``1 - logits * labels`` presumably expects
            labels encoded as -1/+1. With 0/1 one-hot labels every 0 entry
            contributes a constant 1 and no gradient -- confirm against caller.
        regularizer_rate: When non-zero, the tensors stored in the
            ``tf.GraphKeys.REGULARIZATION_LOSSES`` collection are added.

    Returns:
        The summed loss tensor.
    """
    # One weight per head. zip() stops at the shortest input, so only heads
    # that have a weight listed here contribute (currently just the first).
    weight_coefficients = [1]

    head_losses = []
    for logits, labels, weight in zip(logits_list, labels_list, weight_coefficients):
        # relu(x) == max(0, x). Unlike a tf.zeros([batch, 2]) floor, this works
        # for any batch size and any number of classes.
        hinge = tf.nn.relu(1 - logits * labels)
        squared_hinge = tf.reduce_mean(tf.square(hinge))
        # Penalise this head's own scores, not the whole Python list of heads.
        # NOTE(review): a textbook SVM regularises the weights, not the
        # logits; weight penalties come from REGULARIZATION_LOSSES below.
        logit_penalty = tf.reduce_mean(tf.square(logits))
        head_losses.append(weight * (logit_penalty + squared_hinge))

    loss = tf.add_n(head_losses)
    if regularizer_rate != 0:
        loss = loss + tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    return loss
这个损失函数确实能正常工作，但我不确定它是否只是普通的交叉熵损失，而并没有真正包含 SVM。抱歉，我的数学不太好。