我最近在使用时有TypeError
def lie_filter(line):
return tf.equal(line['lie_id'], 2)
在
dataset = (
tf.data
.TextLineDataset('shots.csv')
.skip(1)
.map(decode_line)
.filter(lie_filter)
.cache())
确切的错误是TypeError: lie_filter() takes 1 positional argument but 2 were given
。
仅将函数签名更改为lie_filter(line, x)
即可使错误消失,并且过滤工作按预期进行。但是,这让我想知道这个神秘的第二个争论是什么。
TensorFlow manual for tf.data.filter()仅指定一个参数。 TensorFlow也有许多示例,其中过滤是按照我上面的尝试完成的。例如,imports85.py。
在x
内打印lie_filter
会产生Tensor("arg12:0", shape=(), dtype=float32)
。
第二个参数是什么?在哪里可以找到有关它的文档?
谢谢!
答案 0 :(得分:0)
好的,当然,提交问题后,我终于明白了。我怀疑这是我自己做的。 map()
返回(features, label)
的元组。第二个参数当然是label
作为张量。
希望这对将来的人有所帮助:)