How can I keep additional data about examples (other than features and labels) through a TensorFlow pipeline?

Posted: 2019-06-12 14:41:39

Tags: tensorflow tensorflow-estimator tf.keras

For an audio classification task, I slice some audio files into fixed-size chunks and serialize the data, labels, start and end times, and the audio filename (see data2serialized below) into TFRecords to build training examples.
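A simplified sketch of the writing step (the output path, the chunks iterable and the loop variables are placeholders, not my actual code):

import tensorflow as tf  # the rest of the code in this post assumes this import as well

# Each chunk is serialized with data2serialized() (listed below) and appended
# to a TFRecord file using the TF 1.x writer API.
with tf.python_io.TFRecordWriter('chunks.tfrecord') as writer:
    for filename, start, end, data, labels in chunks:  # placeholder iterable of chunk tuples
        writer.write(data2serialized(filename, start, end, data, labels))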

From the resulting TFRecords I create Datasets to feed tf.keras.models.Model.fit.

When making predictions, I need the filename value from the serialized data so I can merge the results of all examples coming from a given audio file, but tf.keras.models.Model.predict only takes the features as input, and I don't see how to get both the predictions and the filenames as output.

I have started reading the doc for tf.estimator.Estimator, but I still don't see how to pass extra data that is neither an input nor a target through the prediction pipeline...

Any suggestions?

def data2serialized(filename, start_time, end_time, data, labels):
    """Serialize one audio chunk (features, labels, start/end times, source filename) into a tf.train.Example string."""
    feature = {
        'filename': _bytes_feature([filename.encode()]),
        'times': _float_feature([start_time, end_time]),
        'data': _float_feature(data.flatten()),
        'labels': _bytes_feature(["#".join(str(l) for l in labels).encode()]),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()
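The _bytes_feature and _float_feature helpers are the usual tf.train.Feature wrappers from the TFRecord tutorials; a sketch, not my exact implementation:

def _bytes_feature(value):
    # Wrap a list of byte strings in a tf.train.Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _float_feature(value):
    # Wrap a list (or flattened array) of floats in a tf.train.Feature.
    return tf.train.Feature(float_list=tf.train.FloatList(value=list(value)))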


def serialized2data(serialized_data, feature_shape, class_list, nolabel_warning=True):
    """Generate features and labels.
    Labels are indices of original label in class_list.
    """

    features = {
        'filename': tf.FixedLenFeature([], tf.string),
        'times': tf.FixedLenFeature([2], tf.float32),
        'data': tf.FixedLenFeature(feature_shape, tf.float32),
        'labels': tf.FixedLenFeature([], tf.string),
    }
    example = tf.parse_single_example(serialized_data, features)

    # reshape data to channels_first format
    data = tf.reshape(example['data'], (1, feature_shape[0], feature_shape[1]))

    # one-hot encode labels
    labels = tf.strings.to_number(
        tf.string_split([example['labels']], '#').values,
        out_type=tf.int32
    )

    # get intersection of class_list and labels
    labels = tf.squeeze(
        tf.sparse.to_dense(
            tf.sets.intersection(
                tf.expand_dims(labels, axis=0),
                tf.expand_dims(class_list, axis=0)
            )
        ),
        axis=0
    )

    # sort class_list and get indices of labels in class_list
    class_list = tf.sort(class_list)
    labels = tf.where(
        tf.equal(
            tf.expand_dims(labels, axis=1),
            class_list)
    )[:,1]

    tf.cond(
        tf.math.logical_and(nolabel_warning, tf.equal(tf.size(labels), 0)),
        true_fn=lambda:myprint(tf.strings.format('File {} has no label', example['filename'])),
        false_fn=lambda:1
    )

    one_hot = tf.cond(
        tf.equal(tf.size(labels), 0),
        true_fn=lambda: tf.zeros(tf.size(class_list)),
        false_fn=lambda: tf.reduce_max(tf.one_hot(labels, tf.size(class_list)), 0)
    )

    return (data, one_hot)
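myprint is just a small wrapper around tf.print that returns a dummy value so both branches of the tf.cond above produce matching outputs; a sketch, not the exact code:

def myprint(message):
    # Emit the message when the graph executes, then return a dummy int so the
    # true_fn matches the false_fn (lambda: 1) of the surrounding tf.cond.
    print_op = tf.print(message)
    with tf.control_dependencies([print_op]):
        return tf.identity(tf.constant(1))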

def filelist2dataset(files, example_shape, class_list, training=True, batch_size=32, nolabel_warning=True):
    """Build a batched tf.data.Dataset of (data, one_hot) pairs from a list of TFRecord files."""
    files = tf.convert_to_tensor(files, dtype=tf.string)
    files = tf.data.Dataset.from_tensor_slices(files)
    # dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100), cycle_length=8)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=8)
    dataset = dataset.map(lambda x: serialized2data(x, example_shape, class_list, nolabel_warning))
    if training:
        dataset = dataset.shuffle(10000)
        dataset = dataset.repeat()  # Repeat the input indefinitely.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
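For context, roughly how these datasets feed fit and predict (the model, TFRecord paths, example shape, class list and step counts are simplified placeholders, not my real setup); this is also where the filenames get lost:

example_shape = (128, 431)                      # placeholder: n_mels x n_frames per chunk
class_list = tf.constant([0, 1, 2], tf.int32)   # placeholder class ids
train_tfrecords = ['train_000.tfrecord']        # placeholder TFRecord paths
test_tfrecords = ['test_000.tfrecord']

# Placeholder model: any channels_first Keras model with one output per class.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(1,) + example_shape),
    tf.keras.layers.Dense(3, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

train_ds = filelist2dataset(train_tfrecords, example_shape, class_list, training=True)
model.fit(train_ds, steps_per_epoch=100, epochs=2)

# Only the features can go into predict(), so the 'filename' stored in each
# TFRecord never reaches this point -- which is exactly the missing piece
# this question is asking about.
test_ds = filelist2dataset(test_tfrecords, example_shape, class_list, training=False)
scores = model.predict(test_ds.map(lambda data, labels: data), steps=10)  # placeholder step count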

0 Answers:

No answers yet.