Triplet网络(受“Siamese网络”启发)由同一前馈网络的3个实例(带有共享参数)组成。当馈送3个样本时,网络输出2个中间值 - 来自其两个输入的嵌入表示之间的L2(欧几里德)距离 第三个代表。
我正在使用三对图像来馈送网络( x =锚图像,标准图像, x + =正图像,包含相同对象的图像作为x - 实际上, x +与x 相同, x- =负图像,图像的类别不同于x 。
我正在使用here所述的三联体损失成本函数。
如何确定网络的准确性?
答案 0 :(得分:1)
我假设您正在为图像检索或类似任务工作。
您应该首先生成一些三元组,无论是随机还是使用一些硬(半硬)负挖掘方法。然后将三元组分成训练和验证集。
如果你这样做,那么你可以将验证准确度定义为三元组数量的比例,其中锚点和正面之间的特征距离小于验证三元组中锚点和负面之间的特征距离。你可以看到用PyTorch编写的an example here。
另一种方法是,您可以直接根据最终测试指标进行衡量。例如,对于图像检索,我们通常使用mean average precision来测量模型在测试集上的性能。如果您使用此指标,则应首先在验证集上定义一些查询及其相应的地面实况图像。
以上两个指标中的任何一个都没问题。选择您认为适合您的情况。
答案 1 :(得分:0)
所以我正在执行使用 Triplet loss 进行分类的类似任务。这是我如何将新颖的损失方法与分类器一起使用。 首先,使用标准的三元组损失函数训练你的模型 N 个时期。一旦您确定模型(我们将其称为嵌入生成器)经过训练,请保存权重,因为我们将在前面使用这些权重。 假设您的嵌入生成器定义为:
class EmbeddingNetwork(nn.Module):
def __init__(self):
super(EmbeddingNetwork, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 64, (7,7), stride=(2,2), padding=(3,3)),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.001),
nn.MaxPool2d((3, 3), 2, padding=(1,1))
)
self.conv2 = nn.Sequential(
nn.Conv2d(64,64,(1,1), stride=(1,1)),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.001),
nn.Conv2d(64,192, (3,3), stride=(1,1), padding=(1,1)),
nn.BatchNorm2d(192),
nn.LeakyReLU(0.001),
nn.MaxPool2d((3,3),2, padding=(1,1))
)
self.fullyConnected = nn.Sequential(
nn.Linear(7*7*256,32*128),
nn.BatchNorm1d(32*128),
nn.LeakyReLU(0.001),
nn.Linear(32*128,128)
)
def forward(self,x):
x = self.conv1(x)
x = self.conv2(x)
x = self.fullyConnected(x)
return torch.nn.functional.normalize(x, p=2, dim=-1)
现在我们将使用这个嵌入生成器来创建另一个分类器,将我们之前保存的权重拟合到网络的这一部分,然后冻结这部分,这样我们的分类器训练器就不会干扰三元组模型。可以这样做:
class classifierNet(nn.Module):
def __init__(self, EmbeddingNet):
super(classifierNet, self).__init__()
self.embeddingLayer = EmbeddingNet
self.classifierLayer = nn.Linear(128,62)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.dropout(self.embeddingLayer(x))
x = self.classifierLayer(x)
return F.log_softmax(x, dim=1)
现在我们将加载我们之前保存的权重并使用以下方法冻结它们:
embeddingNetwork = EmbeddingNetwork().to(device)
embeddingNetwork.load_state_dict(torch.load('embeddingNetwork.pt'))
classifierNetwork = classifierNet(embeddingNetwork)
现在使用标准分类损失(如 BinaryCrossEntropy 或 CrossEntropy)训练这个分类器网络。