将多个文件输入Tensorflow数据集

时间:2018-03-08 19:26:12

标签: tensorflow tensorflow-serving tensorflow-datasets tensorflow-estimator

我有以下input_fn。

def input_fn(filenames, batch_size):
    # Create a dataset containing the text lines.
    dataset = tf.data.TextLineDataset(filenames).skip(1)

    # Parse each line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

如果filenames=['file1.csv']filenames=['file2.csv'],则效果很好。如果filenames=['file1.csv', 'file2.csv'],它会给我一个错误。在Tensorflow documentation中,它表示filenames是包含一个或多个文件名的tf.string张量。我该如何导入多个文件?

以下是错误。它似乎忽略了上面.skip(1)中的input_fn

InvalidArgumentError: Field 0 in record 0 is not a valid int32: row_id
 [[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_INT32, DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13, DecodeCSV/record_defaults_14, DecodeCSV/record_defaults_15, DecodeCSV/record_defaults_16, DecodeCSV/record_defaults_17, DecodeCSV/record_defaults_18)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], ..., [?], [?], [?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT32, DT_STRING, DT_STRING, ..., DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]

1 个答案:

答案 0 :(得分:7)

使用tf.data.TextLineDataset您有正确的想法。但是,当前实现的作用是在文件名的输入张量中生成每个文件的每一行,除了第一个文件的第一个文件。跳过第一行的方式现在只影响第一个文件中的第一行。在第二个文件中,不会跳过第一行。

根据Datasets guide上的示例,您应该调整代码以首先从文件名创建常规Dataset,然后对每个文件名运行flat_map以使用{{ 1}},同时跳过第一行:

TextLineDataset

此处,d = tf.data.Dataset.from_tensor_slices(filenames) # get dataset from each file, skipping first line of each file d = d.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1)) d = d.map(_parse_line) # And whatever else you need to do 通过读取文件的内容并跳过第一行,从原始数据集的每个元素创建一个新数据集。