如何使用分割模型输出张量?

时间:2019-07-22 19:23:11

标签: ios tensorflow image-segmentation tensorflow-lite semantic-segmentation

我正在尝试在iOS上运行细分模型,但我对如何正确使用输出张量存在一些疑问。

这是我使用的模型上的链接: https://www.tensorflow.org/lite/models/segmentation/overview

运行此模型时,我得到具有尺寸的输出张量: 1 x 257 x 257 x 21。 为什么我将21作为最后一个维度?看起来对于每个像素,我们都得到了班级成绩。我们是否需要在此处找到argmax以获得正确的类值?

但是为什么只有21个班级?我认为它应该包含更多。在哪里可以找到对应于某个类的值的信息。 在ImageClassification示例中,我们有一个包含1001个类的label.txt。

基于ImageClassification示例,我尝试解析张量:首先将其转换为大小为1387029(21 x 257 x 257)的Float数组,然后使用以下代码逐像素创建图像:

RestTemplate restTemplate = new RestTemplate();
   System.out.println(" jsonLinksObj : 4a "+jsonLinksObj);   
   org.json.JSONObject jsonSendMessageObj = jsonLinksObj.getJSONObject("sendMessage");
   String sendMsgFullUrl = poolUrl.concat(jsonSendMessageObj.getString("href")).concat("?OperationContext=4322131"); 

  headers.clear();
  headers.setContentType(MediaType.TEXT_PLAIN);
  headers.set("Authorization", "Bearer " + jwtToken);

  String body = "My Send Message Here";

                HttpEntity<Object> entityJsonSendMsg4 = new HttpEntity<Object>(body, headers);


                 ResponseEntity<Object> sbSendMsgObj  = null;
               try{

                    sbSendMsgObj = restTemplate.exchange(
                        new URI(sendMsgFullUrl), 
                        HttpMethod.POST, 
                        entityJsonSendMsg4, 
                        Object.class);
               } catch(Exception e){
                        e.printStackTrace();
               }

                            WARN : org.springframework.web.client.RestTemplate - POST request for "https://webpoolpnqin102.infra.lync.com/ucwa/oauth/v1/applications/102086376449/communication/conversations/46db8085-8ad0-4186-a4f0-9a521b256b9b/messaging/messages?OperationContext=4322131" resulted in 500 (Internal Server Error);         invoking error handler
                org.springframework.web.client.HttpServerErrorException: 500 Internal Server Error
                                at org.springframework.web.client.DefaultResponseErrorHandler.handleError(DefaultResponseErrorHandler.java:94)
                                                                                         at org.springframework.web.client.RestTemplate.handleResponseError(RestTemplate.java:589)
                                at org.springframework.web.client.RestTemplate.doExecute(RestTemplate.java:547)
                                at org.springframework.web.client.RestTemplate.execute(RestTemplate.java:518)
                                at org.springframework.web.client.RestTemplate.exchange(RestTemplate.java:463)
                                at com.test.example.employee.test.EmployeeTest.sendMessage(EmployeeTest.java:266)
                                  at com.test.example.employee.test.EmployeeTest.main(EmployeeTest.java:337)
                Exception in thread "main" java.lang.NullPointerException
                                  at com.test.example.employee.test.EmployeeTest.sendMessage(EmployeeTest.java:331)
                                  at com.test.example.employee.test.EmployeeTest.main(EmployeeTest.java:337)

这是我得到的结果:

enter image description here

您可以看到质量不是很好。我错过了什么?

CoreML(https://developer.apple.com/machine-learning/models/)的细分模型在同一示例上效果更好:

enter image description here

2 个答案:

答案 0 :(得分:1)

似乎您的模型是根据PASCAL VOC数据进行训练的,该数据具有21个细分类别。
您可以找到类here的列表:

  

背景
  飞机
  自行车
  鸟
  船
  瓶子
  巴士
  汽车
  猫
  椅子
  牛
  餐桌
  狗
  马
  摩托车
  人
  盆栽植物
  羊
  沙发
  火车
  电视监视器

答案 1 :(得分:0)

除了Shai的答案,您还可以使用Netron之类的工具来可视化您的网络,并获得对输入和输出的更多了解,例如,您输入的图像尺寸为257x257x3: enter image description here

您已经知道您的输出大小,对于细分模型,您将得到21,因为这是模型所支持的类数(如Shai所述),然后取所有类的每个像素的argmax,这应该为您提供更多体面的输出,无需调整任何大小,请尝试以下操作(使用伪代码):

output = [rows][cols]
for i in rows:
  for j in cols:
    argmax = -1
    for c in classes:
      if tensor_out[i][j][c] > argmax:
        argmax = tensor_out[i][j][c]
    output[i][j] = c

然后输出将是您的分割图像。