如何按特定值过滤tf.data.Dataset?

时间:2018-02-16 11:27:08

标签: python tensorflow tensorflow-datasets

我通过读取TFRecords来创建数据集,我映射了值,我想过滤特定值的数据集,但由于结果是带有张量的dict,我无法得到张量的实际值或使用tf.cond() / tf.equal进行检查。我怎么能这样做?

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()

4 个答案:

答案 0 :(得分:3)

我正在回答我自己的问题。我发现了这个问题!

我需要做的是tf.unstack()这样的标签:

label = tf.unstack(features['label'])
label = label[0]

在我将其提交给tf.equal()

之前
result = tf.reshape(tf.equal(label, 'some_label_value'), [])

我认为问题是标签被定义为一个类型为字符串tf.FixedLenFeature([1], tf.string)的元素的数组,所以为了得到第一个和单个元素我必须解压缩它(创建一个列表)和然后得到索引为0的元素,如果我错了,请纠正我。

答案 1 :(得分:0)

我认为你不需要首先将标签制作成一维阵列。

使用:

feature = {'label': tf.FixedLenFeature((), tf.string)}

您不需要在filter_func

中取消堆叠标签

答案 2 :(得分:0)

读取、过滤数据集非常容易,无需拆开任何东西。

读取数据集:

print(my_dataset, '\n\n')
##let us print the first 3 records
for record in my_dataset.take(3):
    ##below could be large in case of image
    print(record)
    ##let us print a specific key
    print(record['key2'])

过滤同样简单:

my_filtereddataset = my_dataset.filter(_filtcond1)

您可以根据需要在何处定义 _filtcond1。假设您的数据集中有一个“真”“假”布尔标志,然后:

@tf.function
def _filtcond1(x):
    return x['key_bool'] == 1

甚至是一个 lambda 函数:

my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)

如果您正在阅读尚未创建的数据集或您不知道密钥(似乎是 OP 案例),您可以使用它首先了解密钥和结构:

import json
from google.protobuf.json_format import MessageToJson

for raw_record in noidea_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    ##print(example) ##if image it will be toooolong
    m = json.loads(MessageToJson(example))
    print(m['features']['feature'].keys())

现在您可以继续过滤

答案 3 :(得分:-3)

您应该尝试使用apply函数 tf.data.TFRecordDataset tensorflow documentation

否则......阅读这篇关于TFRecords的文章,以便更好地了解TFRecords TFRecords for humans

但最可能的情况是你无法访问既不修改TFRecord ... github上有关于此主题的请求TFRecords request

我的建议是让事情尽可能简单......你必须知道你在使用图表和会话......

在任何情况下......如果一切都失败了,尝试在tensorflow会话中不起作用的代码部分就像你可以做到的那样简单......可能所有这些操作都应该在tf.session运行时完成。 ..