TensorFlow:从NSynth数据集中提取具有给定功能的数据

时间:2018-11-14 16:58:25

标签: python-3.x tensorflow magenta

我有一个序列化的TensorFlow Example协议缓冲区的TFRecord文件数据集,每个注释有一个Example Proto,从https://magenta.tensorflow.org/datasets/nsynth下载。我正在使用大约1 Gb的测试仪,以防有人要下载它,请检查下面的代码。每个示例都包含许多功能:音高,乐器...

读取此数据的代码是:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

# Reading input data
dataset = tf.data.TFRecordDataset('../data/nsynth-test.tfrecord')

# Convert features into tensors
features = {
"pitch": tf.FixedLenFeature([1], dtype=tf.int64),
"audio": tf.FixedLenFeature([64000], dtype=tf.float32),
"instrument_family": tf.FixedLenFeature([1], dtype=tf.int64)}

parse_function = lambda example_proto: tf.parse_single_example(example_proto,features)
dataset = dataset.map(parse_function)

# Consuming TFRecord data.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size=3)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess.run(batch)

现在,音高范围从21到108。但是我只想考虑给定音高的数据,例如pitch =51。如何从整个数据集中提取“ pitch = 51”子集?或者,如何使迭代器仅遍历此子集?

1 个答案:

答案 0 :(得分:1)

您看起来不错,所缺少的只是一个过滤器功能。

例如,如果您只想提取pitch = 51,则应在地图函数之后添加

dataset = dataset.filter(lambda example: tf.equal(example["pitch"][0], 51))