我正在使用 TensorFlow + Keras。在使用数据扩充时,我不确定 model.fit 中的 steps_per_epoch 参数应该如何设置。我的数据扩充是通过对 TFRecord 数据集调用 tf.data 的 map 函数完成的,而不是使用 Keras 的图像生成器(ImageDataGenerator)。在这种情况下,如果我把数据量扩充为原来的 4 倍,steps_per_epoch 是否也应该相应地增加为 4 倍?
#load data and augment
def _parse_function(proto):
    """Decode one serialized TFRecord example into (image, mask) vectors.

    Each record is expected to hold two byte-string features, 'img_raw' and
    'mask_raw', whose bytes are reinterpreted as float32 values. The results
    are flat (rank-1) tensors; callers reshape them to image layout.
    """
    feature_spec = {
        'img_raw': tf.FixedLenFeature([], tf.string),
        'mask_raw': tf.FixedLenFeature([], tf.string),
    }
    example = tf.parse_single_example(proto, feature_spec)
    image = tf.decode_raw(example['img_raw'], tf.float32)
    mask = tf.decode_raw(example['mask_raw'], tf.float32)
    return image, mask
def preprocess_flip(img_raw, mask_raw):
    """Mirror an image and its mask horizontally (left <-> right).

    The flat inputs are first reshaped to rank-4 tensors:
    ``[-1, HEIGHT, WIDTH, CHANNEL]`` for the image and
    ``[-1, HEIGHT, WIDTH, 1]`` for the mask. The same flip is applied to
    both so they stay pixel-aligned.
    """
    images = tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL])
    masks = tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])
    return tf.image.flip_left_right(images), tf.image.flip_left_right(masks)
def preprocess_rotate(img_raw, mask_raw):
    """Rotate an image and its mask by a fixed 30 degrees.

    The flat inputs are reshaped to ``[-1, HEIGHT, WIDTH, CHANNEL]`` (image)
    and ``[-1, HEIGHT, WIDTH, 1]`` (mask). Both are rotated by the same
    angle so they stay pixel-aligned. Per the ``tf.contrib.image.rotate``
    docs, the angle is in radians and the rotation is counterclockwise.

    Returns:
        ``(img_rotate, mask_rotate)`` rank-4 tensors.
    """
    img_raw = tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL])
    mask_raw = tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])
    # math.radians avoids integer division. Under Python 2, `30/180` is 0,
    # so `30/180*math.pi` produced a zero angle, i.e. no rotation at all.
    angles = math.radians(30)
    img_rotate = tf.contrib.image.rotate(img_raw, angles)
    mask_rotate = tf.contrib.image.rotate(mask_raw, angles)
    return img_rotate, mask_rotate
def preprocess_translate(img_raw, mask_raw):
    """Shift an image and its mask by a fixed offset of 5 px in x and 5 px in y.

    The flat inputs are reshaped to ``[-1, HEIGHT, WIDTH, CHANNEL]`` (image)
    and ``[-1, HEIGHT, WIDTH, 1]`` (mask). The identical ``[dx, dy]`` offset
    is applied to both so they stay pixel-aligned.
    """
    offset = [5, 5]  # [dx, dy] in pixels
    images = tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL])
    masks = tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])
    shifted_images = tf.contrib.image.translate(images, offset)
    shifted_masks = tf.contrib.image.translate(masks, offset)
    return shifted_images, shifted_masks
def create_dataset(filepath, trainFLAG):
    """Build the TFRecord input pipeline and return one batch of (image, mask).

    Args:
        filepath: Filenames accepted by ``tf.data.TFRecordDataset`` (the
            caller passes a ``tf.data.Dataset`` of paths from ``list_files``).
        trainFLAG: When ``> 0``, augmentation is enabled. Every parsed example
            is emitted four times: unchanged, flipped, rotated and translated.
            Each augmentation is applied to the ORIGINAL example, not chained.
            One pass over the records therefore yields 4x as many samples,
            so size ``steps_per_epoch`` as roughly
            ``4 * num_examples / BATCH_SIZE`` to cover every variant once per
            epoch (assuming one image per record -- TODO confirm).

    Returns:
        ``(img_raw, mask_raw)`` batch tensors reshaped to
        ``[-1, HEIGHT, WIDTH, CHANNEL]`` and ``[-1, HEIGHT, WIDTH, 1]``.
        The underlying dataset repeats forever.
    """
    def _stack_variants(img_raw, mask_raw):
        # Original + three independent augmentations, concatenated along the
        # leading axis. Chaining the maps instead (flip -> rotate ->
        # translate) would transform every sample in place: the sample count
        # would not grow and the unaugmented images would never be seen.
        variants = [
            (tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL]),
             tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])),
            preprocess_flip(img_raw, mask_raw),
            preprocess_rotate(img_raw, mask_raw),
            preprocess_translate(img_raw, mask_raw),
        ]
        imgs = tf.concat([img for img, _ in variants], axis=0)
        masks = tf.concat([mask for _, mask in variants], axis=0)
        return imgs, masks

    def _unstack(imgs, masks):
        # Split the stacked variants back into individual dataset elements.
        # The variants of one example stay adjacent, so even a small shuffle
        # buffer mixes them (unlike concatenating whole augmented datasets).
        return tf.data.Dataset.from_tensor_slices((imgs, masks))

    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    if trainFLAG > 0:
        dataset = dataset.map(_stack_variants, num_parallel_calls=8)
        dataset = dataset.flat_map(_unstack)
    # Shuffle BEFORE repeat so each epoch is one complete pass over the data.
    # The other order blends epochs: a sample can reappear before others have
    # been seen once.
    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE)
    # Overlap input preparation with training of the current batch.
    dataset = dataset.prefetch(1)
    iterator = dataset.make_one_shot_iterator()
    img_raw, mask_raw = iterator.get_next()
    # Normalize to rank-4. Without augmentation, the batch holds flat vectors.
    img_raw = tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL])
    mask_raw = tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])
    return img_raw, mask_raw
# load dataset
filenames_train = tf.data.Dataset.list_files(directory)
# create_dataset already returns tensors reshaped to
# [-1, HEIGHT, WIDTH, CHANNEL] / [-1, HEIGHT, WIDTH, 1], so no further
# reshape is needed below.
img_raw, mask_raw = create_dataset(filenames_train, 1)

# create model
# Build the whole template model under the CPU scope, including the layer
# calls in cnn_layers, since that is where the weights are created. The
# Keras multi_gpu_model usage example does this. Wrapping only Model(...)
# left the variables on the default device.
with tf.device('/cpu:0'):
    model_input = keras.layers.Input(tensor=img_raw)
    model_output = cnn_layers(model_input)
    model2 = Model(inputs=model_input, outputs=model_output)

# compile model
model = multi_gpu_model(model2, gpus=2)
model.compile(optimizer='adam', loss='mse', metrics=['mae'],
              target_tensors=[mask_raw])
model.summary()
# The input pipeline repeats forever, so steps_per_epoch (a number of
# batches) alone defines an epoch. Set it to
# (samples emitted by create_dataset per pass) / BATCH_SIZE.
results = model.fit(epochs=1, steps_per_epoch=STEPS_PER_EPOCH)