tf.data.Dataset:映射无法拆分字符串

时间:2018-10-23 19:07:46

标签: python tensorflow tensorflow-datasets

我有一个这样创建的tf.data.Dataset

dataset = tf.data.Dataset.from_tensor_slices(({"reviews": x_train}, y_train))

我只想在空白处分割评论(字符串)。当我这样做时:

dataset = dataset.map(lambda string: tf.string_split([string]))

Python抱怨,告诉我:

TypeError: <lambda>() takes exactly 1 argument (2 given)

我看了看文档,为什么Python认为我给了两个参数……有什么主意呢?

谢谢!

1 个答案:

答案 0 :(得分:1)

似乎与tensorflow定义map()的方式有关。在此处查看文档:{​​{3}}

map()的签名:

  

地图(       map_func,       num_parallel_calls =无   )

重要的是:

  

map_func的输入签名由该数据集中每个元素的结构决定。

因此,您的dataset必须以某种方式排列为大小为2的元组元素,这使得map将2个参数传递到map_func中。但是,您可以定义lambda函数,例如:

lambda string: tf.string_split([string])

这意味着它需要1个输入,即string