在训练诸如street2shop之类的问题模型时,用于三重态损失的train_data的结构将是什么?

时间:2019-06-26 12:40:07

标签: python tensorflow machine-learning keras deep-learning

我想为street2shop或使用三重损失的任何问题训练模型。因此,为了进行训练,我列出了三胞胎的列表,每个三胞胎包含3个数组(图像的numpy数组)的集合,但是当我通过train_data传递适合模型时,该模型将三胞胎视为一个数组。

所以我的问题是如何将三胞胎列表作为train_data传递给模型?

我共享了创建train_data的代码。

train_data=[]
with open('/home/easofy/Downloads/structured_images/triplets.csv', 'r') as fp:
    next(fp)
    reader = csv.reader(fp)
    for line in reader:
       a=np.array(cv2.imread(line[0],cv2.IMREAD_UNCHANGED))

       b=np.array(cv2.resize(cv2.imread(line[1],cv2.IMREAD_UNCHANGED),(256,256)))

       c=np.array(cv2.resize(cv2.imread(line[2],cv2.IMREAD_UNCHANGED),(256,256)))

       train_data.append([a, b, c])

       if len(train_data)==6:
           break

ValueError:检查模型输入时出错:传递给模型的Numpy数组列表不是模型预期的大小。预计会看到3个数组,但获得了以下6个数组的列表:

[array([[[[162, 196, 225],
          [161, 195, 224],
          [158, 193, 219],
          ...,
          [171, 204, 224],
          [152, 185, 205],
          [143, 176, 196]],

          [[155, 189, 218],
 ...

0 个答案:

没有答案