我正在尝试创建一个管道来使用TensorFlow Dataset API和Pandas读取多个CSV文件。但是,使用flat_map
方法会产生错误。但是,如果我使用map
方法,我可以构建代码并在会话中运行它。这是我正在使用的代码。我已经在TensorFlow Github存储库中打开了#17415问题。但显然,这不是一个错误,他们让我在这里发帖。
folder_name = './data/power_data/'
file_names = os.listdir(folder_name)
def _get_data_for_dataset(file_name,rows=100):#
print(file_name.decode())
df_input=pd.read_csv(os.path.join(folder_name, file_name.decode()),
usecols =['Wind_MWh','Actual_Load_MWh'],nrows = rows)
X_data = df_input.as_matrix()
X_data.astype('float32', copy=False)
return X_data
dataset = tf.data.Dataset.from_tensor_slices(file_names)
dataset = dataset.flat_map(lambda file_name: tf.py_func(_get_data_for_dataset,
[file_name], tf.float64))
dataset= dataset.batch(2)
fiter = dataset.make_one_shot_iterator()
get_batch = iter.get_next()
我收到以下错误:map_func must return a Dataset object
。当我使用map
时,管道正常运行,但它没有提供我想要的输出。例如,如果Pandas从我的每个CSV文件中读取N行,我希望管道连接B文件中的数据并给我一个形状为数组的数组(N * B,2)。相反,它给了我(B,N,2),其中B是批量大小。 map
正在添加另一个轴而不是在现有轴上连接。根据我在文档中所理解的flat_map
应该给出一个平坦的输出。在文档中,map
和flat_map
都返回类型数据集。那么我的代码如何使用map而不是flat_map?
如果您能指出数据集API与Pandas模块一起使用的代码,那也很棒。
答案 0 :(得分:6)
作为mikkola points out in the comments,Dataset.map()
和Dataset.flat_map()
期望具有不同签名的函数:Dataset.map()
采用将输入数据集的单个元素映射到单个新元素的函数,而Dataset.flat_map()
采用的函数将输入数据集的单个元素映射到Dataset
个元素。
如果你想要_get_data_for_dataset()
返回的数组的每一行
要成为一个单独的元素,您应该使用Dataset.flat_map()
并使用Dataset.from_tensor_slices()
将tf.py_func()
的输出转换为Dataset
:
folder_name = './data/power_data/'
file_names = os.listdir(folder_name)
def _get_data_for_dataset(file_name, rows=100):
df_input=pd.read_csv(os.path.join(folder_name, file_name.decode()),
usecols=['Wind_MWh', 'Actual_Load_MWh'], nrows=rows)
X_data = df_input.as_matrix()
return X_data.astype('float32', copy=False)
dataset = tf.data.Dataset.from_tensor_slices(file_names)
# Use `Dataset.from_tensor_slices()` to make a `Dataset` from the output of
# the `tf.py_func()` op.
dataset = dataset.flat_map(lambda file_name: tf.data.Dataset.from_tensor_slices(
tf.py_func(_get_data_for_dataset, [file_name], tf.float32)))
dataset = dataset.batch(2)
iter = dataset.make_one_shot_iterator()
get_batch = iter.get_next()