从caffe最后一层提取特征

时间:2018-03-09 18:47:35

标签: python deep-learning caffe pycaffe

我正在尝试从汽车分类微调的GoogleNet caffe模型的最后一层中提取功能。这是deploy.prototxt。我尝试了几件事:

  1. 我从' loss3_classifier_model' 图层中获取了不正确的图片。

  2. 现在,我正在从原型文件中提供的模型中的' pool5' 图层中提取要素。

  3. 我不确定它是否正确,因为我为不同的汽车提取的功能似乎没什么区别。换句话说,我无法使用最后一层功能来区分汽车,我在功能上使用欧几里德距离(它是否正确?)。我没有使用softmax,因为我不想对它们进行分类,我只想要功能,然后使用欧氏距离重新检查它们。

    这是我遵循的步骤:

    ## load the model 
    net = caffe.Net('deploy.prototxt',
                    caffe.TEST,
                    weights ='googlenet_finetune_web_car_iter_10000.caffemodel') 
    
    # resize the input size as I have only one image in my batch.
    net.blobs["data"].reshape(1, 3, 224, 224)
    
    # I read my image of size (x,y,3)
    frame = cv2.imread(frame_path) 
    
    bbox = frame[int(x1):int(x2), int(y1):int(y2)] # getting the car, # I have stored x1,x2,x3,x4 seperatly.
    # resized my image to 224,224,3, network input size.
    bbox = cv2.resize(bbox, (224, 224)) 
    
    # to align my input to the input of the model 
    bbox_input = bbox.swapaxes(1,2).reshape(3,224,224) 
    
    # fed input image to the model.
    net.blobs['data'].data[0] = bbox_input 
    net.forward()
    
    # features from pool5 layer or the last layer.
    temp = net.blobs["pool5"].data[0] 
    

    现在,我想确认这些步骤是否正确?我是caffe的新手,我不确定我上面写的步骤。

1 个答案:

答案 0 :(得分:0)

这两个选项都有效。 离网络末端越远,功能对问题/培训集的专业性就越低,同时仍会捕获可能应用于类似任务的相关信息。当您移动到网络末尾时,功能将更适合您的任务。

请注意,您正在处理两个类似的问题/任务。网络经过微调,适用于汽车分类("这款车是哪款车型?")现在您要验证两辆车是否属于同一车型。

考虑到网络使用大型且具有代表性的训练集进行微调,从中获得的特征功能强大且具有很强的表示能力(即,它们捕获了许多复杂的基础模式,这些模式是他们训练的任务)对您的验证任务有用。

考虑到这一点,您可以尝试多种方法来比较两个特征向量:

  • 欧几里德距离太简单了。我会尝试它只是因为它易于/快速实施;
  • 余弦相似性 [1]也可能是一个简单但良好的起点;
  • 分类即可。我在类似问题中做过的另一种可能性是在两个特征的组合之上训练分类器(SVM,Logistic回归)。分类器的输入可以是它们并排连接。
  • 将验证任务合并到您的网络。您可以更改GoogleNet架构,以接收两张汽车照片和输出(如果它们属于或不属于同一型号)。您可以将网络从分类问题转换/微调到验证任务。检查 siamese networks [2]

修改:调整框架大小时出错可能是导致问题的原因!

# I read my image of size (x,y,3)
frame = cv2.imread(frame_path) 

# resized my image to 224,224,3, network input size.
bbox = cv2.resize(frame, (224, 224))

您应该已在frame方法中输入cv2.resize()作为输入。您可能正在向网络提供垃圾输入,这就是为什么输出最终总是看起来相似。