我对Keras / Tensorflow完全陌生。下面是我的fit_generator
train_dataset=train_fn_inputs(batch_size, None)
val_data=validation_fn_inputs(batch_size, None)
total_records = 44712
val_records = 11178
steps_per_epoch=int(total_records // batch_size)
hist=model.fit_generator(#aug.flow(X_def, y_def, batch_size=batch_size),
#get_batches(X_def, y_def, batch_size),
train_dataset,
steps_per_epoch=steps_per_epoch, #(training_df.shape[0])//batchsize,
epochs=5,
verbose = 1,
#callbacks=[early_stopping],
#validation_data=val_data,
#validation_steps=val_records//batch_size,
workers=0
)
函数定义为:
def train_fn_inputs(bs, aug=None):
train_files, total_records = get_training_data_old()
steps_per_epoch = int(total_records / batch_size)
raw_dataset = tf.data.TFRecordDataset(train_files) #.repeat()
parsed_image_dataset = raw_dataset.map(_parse_image_function).repeat().shuffle(buffer_size=buf_size).batch(batch_size).make_initializable_iterator()
image, label = parsed_image_dataset.get_next()
image = tf.reshape(image, [3, IMG_WIDTH, IMG_HEIGHT, bs])
#label = tf.reshape(label, [bs, 75, 25])
while True:
yield (np.array(image), np.array(label))
但是出现此错误:
文件“ ... \ Anaconda3 \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training_generator.py”,行184,在model_iteration中 batch_size = int(nest.flatten(batch_data)[0] .shape [0])
IndexError:元组索引超出范围