我目前正在使用数据生成器来生成三胞胎,以馈入我的Triplet Network。但是,我面临数组超限的问题。
数据生成器代码为:
def data_generator(df, batch_size=32):
first_val = df['label'].iloc[0]
last_val = df['label'].iloc[-1]
n_classes = (last_val - first_val) + 1
print(n_classes)
VECTOR_SIZE = 128
height = 218
width = 178
x = df['image'].values # x_train (name of the image)
y = df['label'].values # y_train (label of the image)
class_ids = [None] * n_classes
for n in range(n_classes):
class_ids[n] = np.where(y == n)[0]
while True:
X_batch_0 = np.empty((batch_size, height, width, 3), dtype=np.float32)
X_batch_1 = np.empty((batch_size, height, width, 3), dtype=np.float32)
X_batch_2 = np.empty((batch_size, height, width, 3), dtype=np.float32)
for i in range(batch_size):
tgt_class = random.randint(0, n_classes - 1)
neg_class = random.randint(0, n_classes - 1)
# #DEBUG
# print("tgt: ", tgt_class)
# print("neg: ", neg_class)
if neg_class == tgt_class:
neg_class = n_classes - 1
#a, p and n is between 0 and 10 --- make sure is bigger than 0 and less than 10
end_range = class_ids[tgt_class].size - 1
if end_range == 0:
a = random.randint(0, end_range + 1)
p = random.randint(0, end_range + 1)
while a == p:
a = random.randint(0, end_range + 1)
p = random.randint(0, end_range + 1)
else:
a = random.randint(0, end_range)
p = random.randint(0, end_range)
while a == p:
a = random.randint(0, end_range)
p = random.randint(0, end_range)
# re-calculate a and p if same value (meaning same image)
# print("class_ids[tgt_class].size: ", class_ids[tgt_class].size)
end_range1 = class_ids[neg_class].size - 1
if end_range1 == 0:
n = random.randint(0, end_range1+1)
else:
n = random.randint(0, end_range1)
print("a", a)
print("p", p)
print("n", n)
idx_a = class_ids[tgt_class][a]
idx_p = class_ids[tgt_class][p]
idx_n = class_ids[neg_class][n]
print("Anchor ID: ", idx_a)
print("Positive ID: ", idx_p)
print("Negative ID: ", idx_n)
# print("x[idx_a]: ", x[idx_a])
load_image(x[idx_a])
# print("x[idx_p]: ", x[idx_p])
load_image(x[idx_p])
# print("x[idx_n]: ", x[idx_n])
load_image(x[idx_n])
X_batch_0[i] = load_image(x[idx_a]) / 255
X_batch_1[i] = load_image(x[idx_p]) / 255
X_batch_2[i] = load_image(x[idx_n]) / 255
yield [X_batch_0, X_batch_1, X_batch_2], np.empty((batch_size, VECTOR_SIZE * 3))
错误消息如下:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-194-c02cc779a136> in <module>()
3 steps_per_epoch=1000, validation_steps=200,
4 epochs=30,
----> 5 callbacks=[checkpoint, csv_logger])
6
7 # Save model
9 frames
<ipython-input-191-de6ea9976cc9> in data_generator(df, batch_size)
84 idx_a = class_ids[tgt_class][a]
85 idx_p = class_ids[tgt_class][p]
---> 86 idx_n = class_ids[neg_class][n]
87
88 print("Anchor ID: ", idx_a)
IndexError: index 1 is out of bounds for axis 0 with size 1
data_generator函数'df'的参数是一个df,包含以下结构(示例)中的图像名称和标签:
图片,标签 00567.jpg,27 00567.jpg,27 00568.jpg,28 00568.jpg,28
任何帮助将不胜感激。谢谢。