我找不到文档的TensorFlow的tf.data.filter()的第二个参数是什么?

时间:2018-12-01 15:56:37

标签: python tensorflow

我最近在使用时有TypeError

def lie_filter(line):
    return tf.equal(line['lie_id'], 2)

dataset = (
    tf.data
    .TextLineDataset('shots.csv')
    .skip(1)
    .map(decode_line)
    .filter(lie_filter)
    .cache())

确切的错误是TypeError: lie_filter() takes 1 positional argument but 2 were given

仅将函数签名更改为lie_filter(line, x)即可使错误消失,并且过滤工作按预期进行。但是,这让我想知道这个神秘的第二个争论是什么。

TensorFlow manual for tf.data.filter()仅指定一个参数。 TensorFlow也有许多示例,其中过滤是按照我上面的尝试完成的。例如,imports85.py

x内打印lie_filter会产生Tensor("arg12:0", shape=(), dtype=float32)

第二个参数是什么?在哪里可以找到有关它的文档?

谢谢!

1 个答案:

答案 0 :(得分:0)

好的,当然,提交问题后,我终于明白了。我怀疑这是我自己做的。 map()返回(features, label)的元组。第二个参数当然是label作为张量。

希望这对将来的人有所帮助:)