I have the following loss function (taken from this):
import tensorflow as tf
from keras import backend as K


def weightedLoss(originalLossFunc, weightsList):

    def lossFunc(true, pred):

        axis = -1  # if channels last
        # axis = 1  # if channels first

        # argmax returns the index of the element with the greatest value;
        # done along the class axis, it returns the class index
        classSelectors = K.argmax(true, axis=axis)

        # considering weights are ordered by class, for each class:
        # True (1) if the class index is equal to the weight index
        classSelectors = [K.equal(i, classSelectors) for i in range(len(weightsList))]

        # cast booleans to floats for the calculations;
        # each tensor in the list contains 1 where the ground-truth class equals its index,
        # so summing all of them gives a tensor full of ones
        classSelectors = [K.cast(x, K.floatx()) for x in classSelectors]

        # multiply each of the selections above by its respective weight
        weights = [sel * w for sel, w in zip(classSelectors, weightsList)]

        # sum all the selections;
        # the result is a tensor with the respective weight for each element in predictions
        weightMultiplier = weights[0]
        for i in range(1, len(weights)):
            weightMultiplier = weightMultiplier + weights[i]

        # make sure your originalLossFunc only collapses the class axis;
        # you need the other axes intact to multiply the weights tensor
        loss = originalLossFunc(true, pred)
        weightMultiplier = tf.Print(weightMultiplier, [weightMultiplier], "loss weightage")  # location 1
        loss = loss * weightMultiplier
        # weightMultiplier = tf.Print(weightMultiplier, [weightMultiplier], "loss weightage")  # location 2
        return loss

    return lossFunc
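Inside this function I have a print statement (`tf.Print`) that should print the weight vector. In its current location the network prints nothing, even though I thought this would make the network include it in its computation graph. I then moved it one line down (location 2) and tried again, but that didn't work either. What am I doing wrong? I don't get an error at any point.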
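For reference, this is roughly what the printed weight vector should contain once the print does show up. It is a minimal plain-Python sketch (no Keras or TensorFlow) of the selector-and-sum logic in `lossFunc`, assuming `true` is one-hot encoded with the class axis last. The helper names `argmax` and `weight_multiplier` are hypothetical and only mirror the steps in the comments above.

```python
# Plain-Python sketch of what weightMultiplier should hold (no Keras/TensorFlow).
# Assumption: `true` is one-hot encoded, class axis last (channels last).
# The helper names below are hypothetical, for illustration only.

def argmax(values):
    """Index of the largest value (first one on ties), like K.argmax on the class axis."""
    return max(range(len(values)), key=lambda i: values[i])


def weight_multiplier(true, weights_list):
    """Return one weight per element, i.e. weights_list[argmax(one_hot)]."""
    result = []
    for one_hot in true:
        class_index = argmax(one_hot)
        # classSelectors: 1.0 for the ground-truth class, 0.0 for every other class
        selectors = [float(i == class_index) for i in range(len(weights_list))]
        # multiply each selector by its weight and sum, as the loop in lossFunc does
        result.append(sum(sel * w for sel, w in zip(selectors, weights_list)))
    return result


true = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
weights_list = [0.5, 2.0, 10.0]
print(weight_multiplier(true, weights_list))  # [0.5, 10.0, 2.0]
```

Each element simply receives the weight of its ground-truth class, so a printed `weightMultiplier` should contain only values from `weightsList`.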