列出索引超出范围,同时通过拆分验证在keras中拟合模型

时间:2019-03-27 14:30:27

标签: tensorflow keras tfrecord

我正在学习使用TFrecord。我的代码尝试批量加载TFrecord数据,并训练一个验证间隔为0.1的简单Unet。我得到这个错误 文件“ /DATA/usr/local/anaconda3/lib/python3.6/site-packages/keras/engine/training.py”,第984行,适合     如果hasattr(x [0],'shape'): IndexError:列表索引超出范围

# DEFINE SOME CONSTANTS HERE
SHUFFLE_BUFFER=100
SUM_OF_ALL_DATASAMPLES=126
BATCH_SIZE = 126
HEIGHT = 512
WIDTH = 512
CHANNEL = 2
ALPHA = 0.1


# my code for loading data

def _parse_function(proto):
    keys_to_features = {'img_raw': tf.FixedLenFeature([],tf.string),
                        'mask_raw': tf.FixedLenFeature([],tf.string)}

    parsed_features = tf.parse_single_example(proto,keys_to_features)

    parsed_features['img_raw'] = tf.decode_raw(parsed_features['img_raw'],tf.float32)
    parsed_features['mask_raw'] = tf.decode_raw(parsed_features['mask_raw'],tf.float32)

    return parsed_features['img_raw'], parsed_features['mask_raw']

def create_dataset(filepath):
    dataset = tf.data.TFRecordDataset(filepath)

    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.batch(BATCH_SIZE)

    iterator = dataset.make_one_shot_iterator()
    img_raw,mask_raw = iterator.get_next()

    img_raw = tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL])
    mask_raw = tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])


    return img_raw,mask_raw


STEPS_PER_EPOCH = SUM_OF_ALL_DATASAMPLES/BATCH_SIZE

# load dataset
filenames_train = tf.data.Dataset.list_files(directory)
img_raw,mask_raw = create_dataset(filenames_train)
print('The shape is ' + str(img_raw.shape) + '.')
imageSize = img_raw.shape[1:]


# take input
model_input = keras.layers.Input(tensor=tf.reshape(img_raw,[-1,HEIGHT,WIDTH,CHANNEL]))

# define network structure
# abbreviated here

model_output = Conv2D(1, (1, 1))(model_input)

# create model

with tf.device('/cpu:0'):
    model2 = Model(inputs=model_input, outputs=model_output)

# compile model  
model = multi_gpu_model(model2, gpus=2)


model.compile(optimizer='adam', loss='mse', metrics=['mae'], target_tensors=[tf.reshape(mask_raw,[-1,HEIGHT,WIDTH,1])])

earlystopper = EarlyStopping(patience=2, verbose=1)
checkpointer = ModelCheckpoint(modelFile, verbose=1, save_best_only=True)


results = model.fit(validation_split=ALPHA, epochs=3, callbacks=[earlystopper, checkpointer], steps_per_epoch=STEPS_PER_EPOCH)

model.summary()

0 个答案:

没有答案