无法使dataset.filter()在model / official / resnet / resnet_run_loop.py文件中工作

时间:2018-11-22 05:18:39

标签: python tensorflow resnet

在正式的resnet模型中,当eval_only设置为True时,我想通过'label'的值过滤来自test.bin的数据集。我尝试使用tf.data.Dataset.filter()函数仅获取一类测试数据,但没有用。

dataset = dataset.filter(lambda inputs, label: tf.equal(label,15))

我将此代码放在resnet_run_loop.process_record_dataset函数中,但引发了错误

 raise ValueError("`predicate` must return a scalar boolean tensor.")

我发现张量'label'的形状是(?,):'Tensor(“ arg1:0”,shape =(?,),dtype = int32,device = / device:CPU:0)'< / p>

1 个答案:

答案 0 :(得分:0)

我在不同的情况下遇到了同样的问题,并且正如评论中所建议的那样,结果证明该问题是由过滤前的批处理引起的。

您可以使用以下示例重现此内容:

import pprint
import tensorflow as tf

dataset = tf.data.Dataset.zip((
    tf.data.Dataset.range(0, 5),
    tf.data.Dataset.from_tensor_slices([0, 10, 15, 20, 15])
))
pprint.pprint(list(dataset.as_numpy_iterator()))
# [(0, 0), (1, 10), (2, 15), (3, 20), (4, 15)]

filtered = dataset.filter(lambda x, y: y == 15)
pprint.pprint(list(filtered.as_numpy_iterator()))
# [(2, 15), (4, 15)]

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.filter(lambda x, y: y == 15)
# ValueError: `predicate` return type must be convertible to a scalar boolean tensor. Was [...]

一个简单的解决方案是unbatch您的数据集,然后过滤,最后再次批处理:

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = batched.unbatch().filter(lambda x, y: y == 15).batch(BATCH_SIZE)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
#  (array([4]), array([15], dtype=int32))]

如果您不知道或不想跟踪BATCH_SIZE的值,可以调整this solution来按需计算批次大小。

我最终将这两种解决方案合并在一起,

def calculate_batch_size(dataset):
    return next(iter(dataset))[0].shape[0]

def filter_batch(dataset, pred_fn):
    batch_size = calculate_batch_size(dataset)
    return dataset.unbatch().filter(pred_fn).batch(batch_size)

BATCH_SIZE = 2
batched = dataset.batch(BATCH_SIZE)
batched_filtered = filter_batch(batched, lambda x, y: y == 15)
pprint.pprint(list(batched_filtered.as_numpy_iterator()))
# [(array([1, 2]), array([15, 15], dtype=int32)),
#  (array([4]), array([15], dtype=int32))]