How to use model.fit in tf.keras with a tf.data.Dataset and multiple outputs

Asked: 2018-08-20 08:51:39

Tags: tensorflow keras

I have a Keras model with 1 input and 3 outputs:

model = Model(inputs=x, outputs=[out, aux_2, aux_1])

I want to train it with model.fit and a tf.data.Dataset:

model.fit(training_set.make_one_shot_iterator(), steps_per_epoch=steps_per_epoch_train, epochs=1, verbose=1)

I create the dataset like this:

import glob

import tensorflow as tf

def load_dataset_from_tfrecordfile_ICNET(tfrecords_path, batch_size, num_classes):
    def preprocess_fn(example_proto):
        '''A transformation function to preprocess raw data
        into trainable input. '''    
        features = {
                "image_raw": tf.FixedLenFeature((), tf.string),
                "label_raw": tf.FixedLenFeature((), tf.string),
                "width": tf.FixedLenFeature((), tf.int64),
                "height": tf.FixedLenFeature((), tf.int64)
                }

        parsed_features = tf.parse_single_example(example_proto, features) 

        image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)
        label = tf.decode_raw(parsed_features['label_raw'], tf.uint8)

        height = tf.cast(parsed_features["height"], tf.int32)
        width = tf.cast(parsed_features["width"], tf.int32)

        image_shape = tf.stack([height,width,3])
        label_shape = tf.stack([height,width,num_classes]) # opencv order:  width, height, channel numpy order: height, width, channel 

        image = tf.reshape(image, image_shape)
        label = tf.reshape(label, label_shape)

        label_low_resolution = tf.image.resize_images(label,[12,20], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True)
        label_medium_resolution = tf.image.resize_images(label,[24,40], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True)

        return image, label, label_medium_resolution, label_low_resolution

    tfrecords_filename = glob.glob(tfrecords_path + "*.tfrecord")

    dataset = tf.data.TFRecordDataset(tfrecords_filename)
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        preprocess_fn, batch_size,
        num_parallel_batches=8,  # cpu cores
        drop_remainder=True))
    dataset = dataset.prefetch(2)

    print("Dataset was loaded successfully.")
    return dataset

I get this error:

753       if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
754         raise ValueError('Please provide data as a list or tuple of 2 elements '
--> 755                          ' - input and target pair. Received %s' % next_element)
756       x, y = next_element
757 TypeError: not all arguments converted during string formatting
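
The traceback ends in a TypeError instead of the intended ValueError because the `%` operator treats a tuple on its right-hand side as the full list of format arguments. The following is a minimal sketch in plain Python; the 4-tuple of strings is only an illustrative stand-in for the four tensors returned by preprocess_fn, and `format_error` is a hypothetical helper, not Keras code.

```python
# Stand-in for the dataset element: four items instead of an (input, target) pair.
next_element = ("image", "label", "label_medium", "label_low")

message = ("Please provide data as a list or tuple of 2 elements "
           " - input and target pair. Received %s")


def format_error(element):
    """Mimic the formatting step from the traceback and report what happens."""
    try:
        return message % element
    except TypeError as exc:
        return "TypeError: %s" % exc


# A 4-tuple is unpacked as four format arguments, but the message has only one %s.
result_flat = format_error(next_element)

# Wrapping the element in a 1-tuple formats it as a single value instead.
result_wrapped = message % (next_element,)

# The length check from the traceback rejects the flat 4-tuple
# but accepts a nested (input, (targets...)) pair.
nested = ("image", ("label", "label_medium", "label_low"))
flat_is_pair = isinstance(next_element, (list, tuple)) and len(next_element) == 2
nested_is_pair = isinstance(nested, (list, tuple)) and len(nested) == 2

print(result_flat)
print(flat_is_pair, nested_is_pair)
```

So the underlying complaint is the one in the ValueError text: the iterator yields four elements, while model.fit expects exactly two, an input and a target.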

What I have tried:

When I return a list of two elements from the dataset, like this:
return [image, [label, label_medium_resolution, label_low_resolution]]

then I get this error:

Dimension 0 in both shapes must be equal, but are 24 and 12. Shapes are [24,40,7] and [12,20,7].
From merging shape 1 with other shapes. for 'packed' (op: 'Pack') with input shapes: [?,?,7], [24,40,7], [12,20,7].
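
The 'Pack' op in this message suggests that the inner list was being stacked into a single tensor, which cannot work for three labels of different sizes. tf.data keeps tuples as nested structure, so one thing to try is returning tuples instead of lists. The sketch below assumes TensorFlow 2.x (`tf.image.resize`, `element_spec`) rather than the 1.x API used in the question, and it replaces the TFRecord parsing with synthetic zero tensors of hypothetical size; `to_inputs_and_targets` is an illustrative name.

```python
import tensorflow as tf

NUM_CLASSES = 7  # taken from the [.., .., 7] shapes in the error message


def to_inputs_and_targets(image, label):
    """Return (input, (target_1, target_2, target_3)) using tuples, not lists."""
    label_medium = tf.image.resize(label, [24, 40], method="nearest")
    label_low = tf.image.resize(label, [12, 20], method="nearest")
    # Tuples are kept as nested structure, so differently shaped tensors can coexist.
    return image, (label, label_medium, label_low)


# Synthetic stand-in data (hypothetical sizes) in place of the parsed TFRecords.
images = tf.zeros((4, 48, 80, 3), dtype=tf.uint8)
labels = tf.zeros((4, 48, 80, NUM_CLASSES), dtype=tf.uint8)

dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .map(to_inputs_and_targets)
    .batch(2, drop_remainder=True)
)

# Each element is now a pair: one input batch and a tuple of three target batches.
element_spec = dataset.element_spec
print(element_spec)
```

A dataset of this shape yields exactly two top-level elements, and the three targets would be matched to outputs=[out, aux_2, aux_1] by position. Whether label_medium_resolution belongs to aux_2 and label_low_resolution to aux_1 is an assumption about the model.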

Do you know how to solve this problem?

0 answers:

No answers yet