我有一个这样创建的tf.data.Dataset
:
dataset = tf.data.Dataset.from_tensor_slices(({"reviews": x_train}, y_train))
我只想在空白处分割评论(字符串)。当我这样做时:
dataset = dataset.map(lambda string: tf.string_split([string]))
Python抱怨,告诉我:
TypeError: <lambda>() takes exactly 1 argument (2 given)
我看了看文档,为什么Python认为我给了两个参数……有什么主意呢?
谢谢!
答案 0 :(得分:1)
似乎与tensorflow定义map()
的方式有关。在此处查看文档:{{3}}
map()
的签名:
地图( map_func, num_parallel_calls =无 )
重要的是:
map_func的输入签名由该数据集中每个元素的结构决定。
因此,您的dataset
必须以某种方式排列为大小为2的元组元素,这使得map
将2个参数传递到map_func
中。但是,您可以定义lambda函数,例如:
lambda string: tf.string_split([string])
这意味着它需要1个输入,即string
。