Tensorflow对象检测,掩码类numpy数组?

时间:2018-01-09 05:10:28

标签: python numpy tensorflow object-detection

如果我不想在图片中找到所有类别,我应该如何处理,只有一到三个。例如,狗,猫,人。

我使用https://github.com/tensorflow/models/tree/master/research/object_detection示例。

我有两个数组:

分数:

[[9.98601377e-01 8.95673811e-01 8.53869259e-01 2.66915649e-01   2.12714598e-01 1.59017399e-01 1.13635637e-01 4.40990664e-02   3.05164494e-02 2.98027769e-02 2.18284614e-02 1.26428921e-02   7.69424951e-03 6.79711485e-03 3.39347101e-03 3.07240430e-03   2.98071955e-03 2.93320580e-03 2.82452232e-03 2.74329516e-03   2.70699803e-03 2.39588786e-03 2.26139510e-03 1.87807775e-03   1.84638728e-03 1.76362693e-03 1.69230008e-03 1.60750828e-03   1.38457527e-03 1.06237642e-03 8.92742886e-04 8.06386990e-04   6.60347985e-04 5.93963894e-04 5.72122575e-04 5.68453805e-04   5.54322207e-04 5.31597179e-04 5.15502586e-04 4.24901489e-04   4.00159304e-04 3.46195826e-04 3.33204021e-04 3.16907885e-04   2.75790022e-04 2.73264130e-04 2.66362855e-04 2.65591720e-04   2.62703601e-04 1.95777262e-04 1.95584420e-04 1.94998822e-04   1.93145475e-04 1.81952943e-04 1.78345916e-04 1.73626235e-04   1.65691730e-04 1.48035586e-04 1.46503138e-04 1.43825935e-04   1.41083947e-04 1.34577596e-04 1.28188753e-04 1.23581864e-04   1.21554323e-04 1.13173104e-04 1.12181173e-04 1.11818241e-04   1.04750507e-04 1.02079212e-04 1.00522630e-04 9.83492428e-05   9.67224623e-05 9.42678016e-05 9.03011023e-05 8.86701237e-05   8.70161384e-05 8.66368209e-05 8.65162874e-05 8.31855432e-05   8.28216725e-05 8.13762017e-05 7.97617613e-05 7.90129197e-05   7.67382662e-05 7.49801547e-05 7.47950835e-05 7.29791718e-05   7.24335769e-05 7.03693950e-05 6.93228139e-05 6.86998756e-05   6.83857288e-05 6.74587282e-05 6.73529139e-05 6.72009119e-05   6.60547448e-05 6.49067297e-05 6.21892177e-05 6.07847251e-05]]

[[17。 64. 64. 62. 33. 64. 70. 64. 62. 18. 62. 64. 70. 63. 70. 2. 88. 31。   27. 1. 72. 82. 51. 16. 17. 86. 15. 1. 33. 86. 62. 18. 33. 62. 15. 65。   15. 62. 47. 64. 65. 64. 11. 86. 15. 47. 44. 1. 82. 82. 81. 86. 17. 64。   31. 51. 44. 67. 64. 3. 82. 82. 86. 72. 15. 62. 44. 19. 89. 16. 2. 62。   61. 79. 79. 31. 23. 40. 67. 21. 64. 67. 47. 65. 51. 88. 62. 27. 2. 62。   47. 84. 63. 17. 15. 88. 70. 14. 70. 20。]]

如果我尝试这个,输出数组只有零:

for i in range(0,len(classes)):
      if classes[i] != "1.":
            scores[i] = 0

2 个答案:

答案 0 :(得分:1)

您的类数据是列表列表,但您只遍历外部列表。 您还将浮点值与字符串进行比较,还是您的类数组字符串数据?由于精度有限,您不应该直接比较浮点值......

简单解压缩第一个列表:

scs = scores[0]
cls = classes[0]
idxs = []
for i in range(0,len(cls)):
    if int(float(cls[i]))  != 1:
        scs[i] = 0
    else:
        idxs.append(i)
print("indices:", idxs)
print("scores:", [ scs[idx] for idx in idxs])

或使用正确的索引:

idxs = []
for i in range(0,len(classes[0])):
    if int(float(classes[0][i]))  != 1:
        scores[0][i] = 0
    else:
        idxs.append(i)
print("indices:", idxs)
print("scores:", [ scores[0][idx] for idx in idxs])

或使用例如numpy找到你需要的东西(假设一系列花车)......:

import numpy as np
cls_arr = np.array(classes[0],dtype = np.uint8)
scs_arr = np.array(scores[0],dtype = np.uint8)
idxs = np.where(cls_arr == 1) [0]
print("indices:", idxs)
print("scores:",scs_arr[idxs])

一般情况下,如果您想在输出上使用不同的类,那么您可以将转移学习转移到您真正需要的类。 一个例子显示在: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md

无论如何,这个答案很快而且很脏......

答案 1 :(得分:0)

我还做了一个答案,我的过滤人和猫的解决方案:

  aa = 0
  bb = scores.flatten()
  for i in classes.flatten():
    if (i != 17.0) and (i != 1.0):
       bb[aa] = 0.0
    aa += 1
  np.reshape(bb, (20,5))