keras模型的steps_per_epoch.fit与数据增强

时间:2019-03-27 21:42:06

标签: tensorflow keras data-augmentation

我正在使用tensorflow + keras。使用数据扩充时,我不确定model.fit中的steps_per_epoch参数。我的数据扩充是使用tfrecord中的map函数而不是keras中的图像生成器完成的。这样,如果我将数据增加4倍,我的step_per_epoch也会增加4倍吗?

#load data and augment
def _parse_function(proto):
    keys_to_features = {'img_raw': tf.FixedLenFeature([],tf.string),
                        'mask_raw': tf.FixedLenFeature([],tf.string)}

    parsed_features = tf.parse_single_example(proto,keys_to_features)

    parsed_features['img_raw'] = tf.decode_raw(parsed_features['img_raw'],tf.float32)
    parsed_features['mask_raw'] = tf.decode_raw(parsed_features['mask_raw'],tf.float32)

    return parsed_features['img_raw'], parsed_features['mask_raw']

def preprocess_flip(img_raw,mask_raw):
    img_raw = tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL])
    mask_raw = tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])
    img_flip = tf.image.flip_left_right(img_raw)
    mask_flip = tf.image.flip_left_right(mask_raw)

    return img_flip, mask_flip

def preprocess_rotate(img_raw,mask_raw):
    img_raw = tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL])
    mask_raw = tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])
    angles = 30/180*math.pi
    img_rotate = tf.contrib.image.rotate(img_raw,angles)
    mask_rotate = tf.contrib.image.rotate(mask_raw,angles)

    return img_rotate, mask_rotate

def preprocess_translate(img_raw,mask_raw):
    img_raw = tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL])
    mask_raw = tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])
    dx = 5
    dy = 5
    img_ts = tf.contrib.image.translate(img_raw,[dx,dy])
    mask_ts = tf.contrib.image.translate(mask_raw,[dx,dy])

    return img_ts, mask_ts




def create_dataset(filepath,trainFLAG):
    dataset = tf.data.TFRecordDataset(filepath)

    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    if(trainFLAG>0):
        dataset = dataset.map(preprocess_flip,num_parallel_calls=8)
        dataset = dataset.map(preprocess_rotate,num_parallel_calls=8)
        dataset = dataset.map(preprocess_translate,num_parallel_calls=8)

    dataset = dataset.repeat()
    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.batch(BATCH_SIZE)

    iterator = dataset.make_one_shot_iterator()
    img_raw,mask_raw = iterator.get_next()

    img_raw = tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL])
    mask_raw = tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])


    return img_raw,mask_raw

# load dataset

filenames_train = tf.data.Dataset.list_files(directory)
img_raw,mask_raw = create_dataset(filenames_train,1)

model_input = keras.layers.Input(tensor=tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL]))
model_output = cnn_layers(model_input)

# create model
with tf.device('/cpu:0'):
    model2 = Model(inputs=model_input, outputs=model_output)

# compile model  
model = multi_gpu_model(model2, gpus=2)
model.compile(optimizer='adam', loss='mse', metrics=['mae'], target_tensors=[tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])]) 
model.summary()

results = model.fit(epochs=1, steps_per_epoch=STEPS_PER_EPOCH)

0 个答案:

没有答案