I'm learning to use TFRecord. My code tries to load TFRecord data in batches and train a simple UNet with a validation split of 0.1. I get this error:

  File "/DATA/usr/local/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 984, in fit
    if hasattr(x[0], 'shape'):
IndexError: list index out of range
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import Input, Conv2D
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.utils import multi_gpu_model

# DEFINE SOME CONSTANTS HERE
SHUFFLE_BUFFER = 100
SUM_OF_ALL_DATASAMPLES = 126
BATCH_SIZE = 126
HEIGHT = 512
WIDTH = 512
CHANNEL = 2
ALPHA = 0.1
# my code for loading data
def _parse_function(proto):
    # Each record stores the image and its mask as raw byte strings.
    keys_to_features = {'img_raw': tf.FixedLenFeature([], tf.string),
                        'mask_raw': tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(proto, keys_to_features)
    # The bytes were written as float32, so decode them back to float32.
    parsed_features['img_raw'] = tf.decode_raw(parsed_features['img_raw'], tf.float32)
    parsed_features['mask_raw'] = tf.decode_raw(parsed_features['mask_raw'], tf.float32)
    return parsed_features['img_raw'], parsed_features['mask_raw']
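
For context, both arrays are stored as raw float32 bytes under the keys img_raw and mask_raw. This is a simplified sketch of the writer side (the helper name write_example and the file path are just for illustration):

import numpy as np

def write_example(writer, img, mask):
    # img: (HEIGHT, WIDTH, CHANNEL) float32, mask: (HEIGHT, WIDTH, 1) float32
    feature = {
        'img_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[img.astype(np.float32).tostring()])),
        'mask_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[mask.astype(np.float32).tostring()])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

# usage sketch:
# with tf.python_io.TFRecordWriter('train.tfrecords') as writer:
#     write_example(writer, img, mask)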
def create_dataset(filepath):
    dataset = tf.data.TFRecordDataset(filepath)
    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.batch(BATCH_SIZE)
    iterator = dataset.make_one_shot_iterator()
    img_raw, mask_raw = iterator.get_next()
    # Restore the spatial shapes that decode_raw flattened away.
    img_raw = tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL])
    mask_raw = tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])
    return img_raw, mask_raw
STEPS_PER_EPOCH = SUM_OF_ALL_DATASAMPLES // BATCH_SIZE  # integer division so Keras gets an int
# load dataset
# directory is the glob pattern pointing at my TFRecord files
filenames_train = tf.data.Dataset.list_files(directory)
img_raw, mask_raw = create_dataset(filenames_train)
print('The shape is ' + str(img_raw.shape) + '.')
imageSize = img_raw.shape[1:]
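
As a quick sanity check (sketch), one batch can be pulled from the iterator in a plain session to inspect the decoded shapes:

with tf.Session() as sess:
    img_batch, mask_batch = sess.run([img_raw, mask_raw])
    # expected: (126, 512, 512, 2) and (126, 512, 512, 1)
    print(img_batch.shape, mask_batch.shape)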
# take input
model_input = keras.layers.Input(tensor=tf.reshape(img_raw, [-1, HEIGHT, WIDTH, CHANNEL]))
# define network structure
# abbreviated here
model_output = Conv2D(1, (1, 1))(model_input)
# create model
with tf.device('/cpu:0'):
    model2 = Model(inputs=model_input, outputs=model_output)
# compile model
model = multi_gpu_model(model2, gpus=2)
# the targets come from the mask tensor, so no y is passed to fit()
model.compile(optimizer='adam', loss='mse', metrics=['mae'],
              target_tensors=[tf.reshape(mask_raw, [-1, HEIGHT, WIDTH, 1])])
earlystopper = EarlyStopping(patience=2, verbose=1)
# modelFile is the path where the best weights get saved
checkpointer = ModelCheckpoint(modelFile, verbose=1, save_best_only=True)
# this is the call that raises the IndexError
results = model.fit(validation_split=ALPHA, epochs=3, callbacks=[earlystopper, checkpointer], steps_per_epoch=STEPS_PER_EPOCH)
model.summary()