Data generator for triplet loss

Time: 2020-04-21 03:28:10

Tags: python pandas dataframe keras triplet

I am currently using a data generator to produce triplets to feed into my triplet network. However, I am running into an index-out-of-bounds error.

The data generator code is:

import random

import numpy as np


def data_generator(df, batch_size=32):
  first_val = df['label'].iloc[0]
  last_val = df['label'].iloc[-1]

  n_classes = (last_val - first_val) + 1
  print(n_classes)           
  VECTOR_SIZE = 128
  height = 218
  width = 178
  x = df['image'].values   # x_train (name of the image)
  y = df['label'].values   # y_train (label of the image)

  class_ids = [None] * n_classes
  for n in range(n_classes):
    class_ids[n] = np.where(y == n)[0]

  while True:
    X_batch_0 = np.empty((batch_size, height, width, 3), dtype=np.float32)
    X_batch_1 = np.empty((batch_size, height, width, 3), dtype=np.float32)
    X_batch_2 = np.empty((batch_size, height, width, 3), dtype=np.float32)

    for i in range(batch_size):
      tgt_class = random.randint(0, n_classes - 1)
      neg_class = random.randint(0, n_classes - 1)

      # #DEBUG
      # print("tgt: ", tgt_class)
      # print("neg: ", neg_class)

      if neg_class == tgt_class:
        neg_class = n_classes - 1

      # a, p and n should be between 0 and 10: make sure each is bigger than 0 and less than 10
      end_range = class_ids[tgt_class].size - 1
      if end_range == 0:
        a = random.randint(0, end_range + 1)
        p = random.randint(0, end_range + 1)
        while a == p:
          a = random.randint(0, end_range + 1)
          p = random.randint(0, end_range + 1)

      else:
        a = random.randint(0, end_range)
        p = random.randint(0, end_range)
        while a == p:
          a = random.randint(0, end_range)
          p = random.randint(0, end_range)
      # re-calculate a and p if same value (meaning same image)
      # print("class_ids[tgt_class].size: ", class_ids[tgt_class].size)
      end_range1 = class_ids[neg_class].size - 1
      if end_range1 == 0:
        n = random.randint(0, end_range1+1)
      else:
        n = random.randint(0, end_range1)

      print("a", a)
      print("p", p)
      print("n", n)

      idx_a = class_ids[tgt_class][a]
      idx_p = class_ids[tgt_class][p]
      idx_n = class_ids[neg_class][n]

      print("Anchor ID: ", idx_a)
      print("Positive ID: ", idx_p)
      print("Negative ID: ", idx_n)

      # Load each image once and divide the pixel values by 255.
      X_batch_0[i] = load_image(x[idx_a]) / 255
      X_batch_1[i] = load_image(x[idx_p]) / 255
      X_batch_2[i] = load_image(x[idx_n]) / 255

    yield [X_batch_0, X_batch_1, X_batch_2], np.empty((batch_size, VECTOR_SIZE * 3))

The error message is as follows:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-194-c02cc779a136> in <module>()
      3     steps_per_epoch=1000, validation_steps=200,
      4     epochs=30,
----> 5     callbacks=[checkpoint, csv_logger])
      6 
      7 # Save model

9 frames
<ipython-input-191-de6ea9976cc9> in data_generator(df, batch_size)
     84       idx_a = class_ids[tgt_class][a]
     85       idx_p = class_ids[tgt_class][p]
---> 86       idx_n = class_ids[neg_class][n]
     87 
     88       print("Anchor ID: ", idx_a)

IndexError: index 1 is out of bounds for axis 0 with size 1
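
A likely cause, judging from the traceback: `random.randint(a, b)` includes both endpoints. When a class holds a single image, `end_range1` is 0 and `random.randint(0, end_range1 + 1)` can return 1, which is out of bounds for an array of size 1. The anchor/positive branch has the same problem: with one image, `a != p` can only hold if one of them is 1. A minimal reproduction, using a hypothetical one-element array in place of `class_ids[neg_class]`:

```python
import random

import numpy as np

# Hypothetical stand-in for class_ids[neg_class]: a class with exactly one image.
single_class = np.array([42])

end_range1 = single_class.size - 1  # 0 for a single-image class

# random.randint(a, b) is inclusive of b, so both 0 and 1 can be drawn here.
draws = {random.randint(0, end_range1 + 1) for _ in range(1000)}


def try_index(arr, i):
    """Return arr[i], or the IndexError message if i is out of bounds."""
    try:
        return arr[i]
    except IndexError as exc:
        return str(exc)


# Index 1 does not exist in a size-1 array.
message = try_index(single_class, 1)
print(draws, message)
```

If this is indeed the cause, drawing with `random.randint(0, size - 1)` or `random.randrange(size)` keeps the index in range, although a single-image class still cannot supply a distinct anchor and positive.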

The `df` argument of data_generator is a DataFrame containing image names and labels, structured as follows (example):

image,label
00567.jpg,27
00567.jpg,27
00568.jpg,28
00568.jpg,28
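
Note that the labels in this sample start at 27, while the generator looks classes up with `np.where(y == n)` for `n` in `range(n_classes)`, i.e. starting at 0. If the real labels are not exactly 0 to `n_classes - 1`, some `class_ids` entries may be empty or point to the wrong class. One possible way to index by the labels actually present is sketched below; `build_class_index` is a hypothetical helper name, and the DataFrame simply repeats the sample rows above.

```python
import numpy as np
import pandas as pd

# The sample rows from the question, with the column names used by the generator.
df = pd.DataFrame({
    "image": ["00567.jpg", "00567.jpg", "00568.jpg", "00568.jpg"],
    "label": [27, 27, 28, 28],
})


def build_class_index(df):
    """Map each label that occurs in df to the row positions carrying it."""
    y = df["label"].to_numpy()
    return {int(label): np.flatnonzero(y == label) for label in np.unique(y)}


class_ids = build_class_index(df)
for label, rows in class_ids.items():
    print(label, rows.tolist())
```

Keying by the real label avoids assuming that labels are contiguous or zero-based.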

Any help would be greatly appreciated. Thank you.
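
For reference, here is a rough sketch of one way the sampling step could be written so that every index stays in range. It is not a tested replacement for the generator. It assumes `class_ids` is a dict mapping each label to an array of row positions (a hypothetical structure; the code in the question uses a list), and `sample_triplet` is an illustrative name.

```python
import random

import numpy as np


def sample_triplet(class_ids, rng=random):
    """Pick (anchor, positive, negative) row positions from class_ids."""
    # Only classes with at least two rows can supply an anchor/positive pair.
    eligible = [c for c, rows in class_ids.items() if len(rows) >= 2]
    if not eligible or len(class_ids) < 2:
        raise ValueError("need a class with >= 2 rows and at least 2 classes")
    tgt_class = rng.choice(eligible)
    # The negative class is drawn from every class except the target.
    neg_class = rng.choice([c for c in class_ids if c != tgt_class])
    # random.sample draws without replacement, so anchor != positive.
    idx_a, idx_p = rng.sample(list(class_ids[tgt_class]), 2)
    idx_n = rng.choice(list(class_ids[neg_class]))
    return int(idx_a), int(idx_p), int(idx_n)


# Toy index: label 27 has two rows, label 28 has one.
class_ids = {27: np.array([0, 1]), 28: np.array([2])}
triplets = [sample_triplet(class_ids) for _ in range(100)]
print(triplets[:3])
```

Excluding the target class before drawing the negative also avoids the case where the fallback `neg_class = n_classes - 1` happens to equal `tgt_class`.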

0 answers:

No answers