How to merge two input pipelines with tf.data.Dataset when the datasets have different data formats

Time: 2018-11-30 22:04:53

Tags: python tensorflow tensorflow-datasets

I have two different datasets to feed into a model in TensorFlow, so a feedable iterator seemed like a good choice. These are the shapes of the two datasets:

fer+:

  • images: type: float / dim: 64 x 64 (these are grayscale images)

  • labels: type: float / dim: (?, 8)

sewa:

  • images: uint8 / dim: (?, 64, 64)

  • name: string / dim: (?,)

  • valence: float / dim: (?, 1)

  • arousal: float / dim: (?, 1)

  • liking: float / dim: (?, 1)

  • detected: int (this tells me whether a face was detected in a frame)

Here is the code:

import os
import tensorflow as tf


def _parse_fer_example(example_proto):
    features = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([8], tf.float32)
    }

    parsed_features = tf.parse_single_example(example_proto, features)

    # Extract one example from the database.
    image = tf.decode_raw(parsed_features['image_raw'], tf.float32)
    # The height and width could be used to restore the flattened image;
    # here the image size is fixed at 64 x 64, so they go unused.
    height = tf.cast(parsed_features['height'], tf.int32)
    width = tf.cast(parsed_features['width'], tf.int32)

    image = tf.reshape(image, [64, 64])

    labels = parsed_features['labels']

    return image, labels

def _parse_sewa_example(example_proto):
    # The annotation contains the following features: timestamp; arousal; valence; liking
    features = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'name': tf.FixedLenFeature([], tf.string),
        'frame_number': tf.FixedLenFeature([1], tf.int64),
        'time': tf.FixedLenFeature([1], tf.float32),
        'detected': tf.FixedLenFeature([1], tf.int64),
        'arousal': tf.FixedLenFeature([1], tf.float32),
        'valence': tf.FixedLenFeature([1], tf.float32),
        'liking': tf.FixedLenFeature([1], tf.float32),
        'istalking': tf.FixedLenFeature([1], tf.int64)
    }

    parsed_features = tf.parse_single_example(example_proto, features)

    # Extract one example from the database.
    image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)
    height = tf.cast(parsed_features['height'], tf.int32)
    width = tf.cast(parsed_features['width'], tf.int32)

    # The image was flattened when stored in binary format, so it is reshaped
    # here; the stored height and width could restore the original size, but
    # it is fixed at 112 x 112 x 3.
    # Tensor("Reshape:0", shape=(112, 112, 3), dtype=uint8)
    image = tf.reshape(image, [112, 112, 3])

    name = parsed_features['name']
    istalking = parsed_features['istalking']
    detected = parsed_features['detected']
    arousal = parsed_features['arousal']
    valence = parsed_features['valence']
    liking = parsed_features['liking']

    return name, detected, arousal, valence, liking, istalking, image


# This will load fer dataset.
def load_fer_data(file_name, batch_size):
    dataset = tf.data.TFRecordDataset(file_name).map(_parse_fer_example).shuffle(15000).batch(batch_size)
    iterator = dataset.make_initializable_iterator(shared_name='fer_iterator')

    images, labels = iterator.get_next()
    print(images, labels)

    return images, labels, iterator

def load_sewa_data(file_name, batch_size):

    dataset = tf.data.TFRecordDataset(file_name).map(_parse_sewa_example).batch(batch_size)
    iterator = dataset.make_initializable_iterator(shared_name='sewa_iterator')

    with tf.name_scope('sewa_next_batch'):
        next_batch = iterator.get_next()

        names, detected, arousal, valence, liking, istalkings, images = next_batch

        print(names, detected, arousal, valence, liking, istalkings, images)

        return names, detected, arousal, valence, liking, istalkings, images, iterator


if __name__ == '__main__':

    tf_records = os.listdir(sewa_tfrecords_path)

    sewa_train_files = [sewa_tfrecords_path + "/" + f for f in tf_records if 'Train' in f]
    sewa_devel_files = [sewa_tfrecords_path + "/" + f for f in tf_records if 'Devel' in f]
    print(sewa_train_files)

    filenames = tf.placeholder(tf.string, shape=[None])

    names, detected, arousal, valence, liking, istalkings, images_sewa, iterator_sewa = load_sewa_data(filenames, 34)
    images_fer, labels, iterator_fer = load_fer_data(filenames, 64)

    with tf.Session() as sess:

        sess.run(iterator_sewa.initializer, feed_dict={filenames: sewa_train_files})
        total = 0
        # ....

This is what I get from the prints:

Tensor("sewa_next_batch/IteratorGetNext:0", shape=(?,), dtype=string) Tensor("sewa_next_batch/IteratorGetNext:1", shape=(?, 1), dtype=int64) Tensor("sewa_next_batch/IteratorGetNext:2", shape=(?, 1), dtype=float32) Tensor("sewa_next_batch/IteratorGetNext:3", shape=(?, 1), dtype=float32) Tensor("sewa_next_batch/IteratorGetNext:4", shape=(?, 1), dtype=float32) Tensor("sewa_next_batch/IteratorGetNext:5", shape=(?, 1), dtype=int64) Tensor("sewa_next_batch/IteratorGetNext:6", shape=(?, 112, 112, 3), dtype=uint8)

AND:

Tensor("IteratorGetNext:0", shape=(?, 64, 64), dtype=float32) Tensor("IteratorGetNext:1", shape=(?, 8), dtype=float32)

The reason for all this:

I want to train the model on the fer dataset; then remove the classifier and create a different one for the sewa database; then train the whole model again.

Therefore, when training the model on the fer dataset, the classification at the end has 8 outputs, whereas when training on the sewa database it will have 2 outputs; hence I need a different classifier.
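For illustration, the swap-the-head plan could be sketched roughly like this in Keras (a hypothetical architecture; the actual model, its layers and sizes are not described in this post):

```python
import tensorflow as tf

# Rough sketch of the two-stage plan: a shared base network, first topped
# with an 8-way fer head, later reused with a fresh 2-output sewa head.
# The conv/pooling layers here are placeholders, not the real model.
base = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 1)),
    tf.keras.layers.Conv2D(16, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
])

# Stage 1: classify the 8 fer labels.
fer_model = tf.keras.Sequential([base, tf.keras.layers.Dense(8, activation='softmax')])
# ... train fer_model on the fer dataset ...

# Stage 2: reuse the (now pre-trained) base with a new 2-output head,
# e.g. regressing valence and arousal for sewa.
sewa_model = tf.keras.Sequential([base, tf.keras.layers.Dense(2)])
# ... train sewa_model on the sewa dataset ...
```

Both models share the same `base` weights, so training the second model fine-tunes what the first one learned.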

So now I would like to merge the two databases using something like a feedable iterator, but I am not sure how.

Alternatively, if we look at the documentation at https://www.tensorflow.org/guide/datasets, it mentions:

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

So how should this be handled when the datasets have different structures?
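One conceivable workaround (my own sketch, not from the post or the TF docs): map both pipelines to a common (image, labels) element structure first, so that a single iterator structure fits both. Shown here with toy in-memory tensors and eager tf.data for brevity; the label padding to a common length is an assumption:

```python
import tensorflow as tf

# fer-like elements: 64x64 float image, 8 labels.
fer = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([4, 64, 64], tf.float32), tf.zeros([4, 8], tf.float32)))

# sewa-like elements: 112x112x3 uint8 image plus valence/arousal scalars.
sewa = tf.data.Dataset.from_tensor_slices({
    'image': tf.zeros([4, 112, 112, 3], tf.uint8),
    'valence': tf.zeros([4, 1], tf.float32),
    'arousal': tf.zeros([4, 1], tf.float32),
})

def sewa_to_common(ex):
    # Convert the sewa image to a 64x64 grayscale float image.
    img = tf.image.rgb_to_grayscale(tf.cast(ex['image'], tf.float32))
    img = tf.image.resize(img, [64, 64])[..., 0]             # -> [64, 64]
    # Stack the two regression targets, then zero-pad to the common length 8.
    lab = tf.concat([ex['valence'], ex['arousal']], axis=0)  # -> [2]
    lab = tf.pad(lab, [[0, 6]])                              # -> [8]
    return img, lab

sewa_common = sewa.map(sewa_to_common)

# Both pipelines now share one element structure, so one iterator fits both.
assert fer.element_spec == sewa_common.element_spec
```

With identical structures, the `from_string_handle` pattern quoted above applies as-is; the price is that the sewa-specific fields (name, detected, istalking) must be dropped or encoded into the common structure.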

0 Answers:

There are no answers yet.