I have a Keras model with 1 input and 3 outputs:

```python
model = Model(inputs=x, outputs=[out, aux_2, aux_1])
```

I want to train it with `model.fit` and a `tf.data` dataset:

```python
model.fit(training_set.make_one_shot_iterator(), steps_per_epoch=steps_per_epoch_train, epochs=1, verbose=1)
```
I create the dataset like this:

```python
import glob

import tensorflow as tf


def load_dataset_from_tfrecordfile_ICNET(tfrecords_path, batch_size, num_classes):
    def preprocess_fn(example_proto):
        '''A transformation function to preprocess raw data
        into trainable input. '''
        features = {
            "image_raw": tf.FixedLenFeature((), tf.string),
            "label_raw": tf.FixedLenFeature((), tf.string),
            "width": tf.FixedLenFeature((), tf.int64),
            "height": tf.FixedLenFeature((), tf.int64)
        }
        parsed_features = tf.parse_single_example(example_proto, features)
        image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)
        label = tf.decode_raw(parsed_features['label_raw'], tf.uint8)
        height = tf.cast(parsed_features["height"], tf.int32)
        width = tf.cast(parsed_features["width"], tf.int32)
        image_shape = tf.stack([height, width, 3])
        label_shape = tf.stack([height, width, num_classes])  # OpenCV order: width, height, channel; NumPy order: height, width, channel
        image = tf.reshape(image, image_shape)
        label = tf.reshape(label, label_shape)
        label_low_resolution = tf.image.resize_images(label, [12, 20], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True)
        label_medium_resolution = tf.image.resize_images(label, [24, 40], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True)
        return image, label, label_medium_resolution, label_low_resolution

    tfrecords_filename = glob.glob(tfrecords_path + "*.tfrecord")
    dataset = tf.data.TFRecordDataset(tfrecords_filename)
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        preprocess_fn, batch_size,
        num_parallel_batches=8,  # CPU cores
        drop_remainder=True))
    dataset = dataset.prefetch(2)
    print("Dataset was loaded successfully.")
    return dataset
```
I get this error:

```
    753     if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
    754       raise ValueError('Please provide data as a list or tuple of 2 elements '
--> 755                        ' - input and target pair. Received %s' % next_element)
    756     x, y = next_element
    757

TypeError: not all arguments converted during string formatting
```
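The `TypeError` appears to hide the intended `ValueError`. At line 755 the tuple `next_element` is passed directly to `%`, so Python treats its four items as four separate format arguments for a single `%s`. Below is a minimal pure-Python sketch that reproduces this. Plain strings stand in for the four tensors, and the helper names `format_message` and `format_message_fixed` are hypothetical:

```python
def format_message(next_element):
    # Same pattern as traceback line 755: the tuple itself is the % argument,
    # so each of its items is treated as a separate format argument
    return ('Please provide data as a list or tuple of 2 elements '
            ' - input and target pair. Received %s' % next_element)


def format_message_fixed(next_element):
    # Wrapping the tuple in a 1-tuple makes %s consume it as a single value
    return ('Please provide data as a list or tuple of 2 elements '
            ' - input and target pair. Received %s' % (next_element,))


# Hypothetical placeholders for the four tensors returned by preprocess_fn
element = ("image", "label", "label_medium", "label_low")

try:
    format_message(element)
    error_text = None
except TypeError as err:
    error_text = str(err)

print("TypeError:", error_text)
print(format_message_fixed(element))
```

In both cases the underlying problem is the same: each dataset element has four components rather than an (input, target) pair.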
When I instead return a list of two elements from the dataset, like this:

```python
return [image, [label, label_medium_resolution, label_low_resolution]]
```
I get this error instead:

```
Dimension 0 in both shapes must be equal, but are 24 and 12. Shapes are [24,40,7] and [12,20,7].
From merging shape 1 with other shapes. for 'packed' (op: 'Pack') with input shapes: [?,?,7], [24,40,7], [12,20,7].
```
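The check in the traceback accepts only a two-element (input, target) pair. With three outputs, the three targets therefore presumably have to be grouped into one nested element. The sketch below reproduces that check in pure Python. It uses a hypothetical helper `check_pair` and string placeholders for the tensors. It also assumes, without verifying it here, that `tf.data` keeps a Python tuple returned from the map function as nested structure. The `'Pack'` error above suggests that it instead tried to stack the inner list into a single tensor, which fails because the three labels have different shapes.

```python
def check_pair(next_element):
    """Hypothetical stand-in mirroring traceback lines 753-756."""
    if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
        raise ValueError('Please provide data as a list or tuple of 2 elements '
                         ' - input and target pair. Received %s' % (next_element,))
    x, y = next_element
    return x, y


# Placeholders standing in for the tensors produced by preprocess_fn
image, label, label_medium, label_low = "image", "label", "label_medium", "label_low"

flat = (image, label, label_medium, label_low)      # what the original map function yields
nested = (image, (label, label_medium, label_low))  # one input, a tuple of three targets

try:
    check_pair(flat)
    flat_ok = True
except ValueError:
    flat_ok = False

x, y = check_pair(nested)
print(flat_ok, x, len(y))
```

Under that assumption, the map function would end with `return image, (label, label_medium_resolution, label_low_resolution)`, with the targets in the same order as the model's `outputs`. This has not been tested against the model above.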
Do you know how to fix this?