如何在Tensorflow中使用带有迭代器的tf.datasets

时间:2018-05-14 13:01:09

标签: python tensorflow tensorflow-datasets

我正在尝试使用tf.data.TextLineDataset从csv文件中读取,在多个工作节点上对数据集进行分片,然后创建迭代器以迭代它们以批量提供数据。我在TensorFlow(https://www.tensorflow.org/programmers_guide/datasets)的tf.datasets上使用了程序员指南。 在tf会话中运行代码时出现的问题是我收到以下错误:

*** tensorflow.python.framework.errors_impl.NotFoundError: Date,Open,High,Low,Last,Close,Total Trade Quantity,Turnover,close_pct_change_1d,KAMA7-KAMA30,KAMA15-KAMA30,HT_QUAD,TURNOVER,BOP,MFI,MINUS_DI,ROCP,STOCH_SLOWK,NATR,EMA7-EMA30-1d,DX-1d,PPO-1d,NATR-1d,HT_INPHASOR-2d,day_0,day_1,day_2,day_3; No such file or directory
     [[Node: IteratorGetNext_5 = IteratorGetNext[output_shapes=[[], [], [], [], [], ..., [], [], [], [], []], output_types=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_8)]]

现在,“日期”,“打开”,“高”等是我要加载的数据集中的列。因此,我知道错误与加载数据集无关。 加载数据集时,我使用tf.data.TextLineDataset(file).skip(1)但是根据错误,它似乎没有跳过我的数据集的第一行(它们是列头)。

有人知道这个错误来自哪里吗?有没有人解决这个问题?

请参阅以下代码以获得澄清:

def create_pipeline(bs, nr, ep):

    def _X_parse_csv(file):

        record_defaults=[[0]]*20
        splits = tf.decode_csv(file, record_defaults)
        input = splits 

        return input

    def _y_parse_csv(file):

        record_defaults=[[0]]*20
        splits = tf.decode_csv(file, record_defaults)
        label = splits[0] 

        return label


    # Dataset for input data
    file = tf.gfile.Glob("./NSEOIL.csv")

    num_workers = 1 # for testing; simulate 1 node for sharding below
    task_index = 0

    ds_file = tf.data.TextLineDataset(file)

    ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers
    ds = ds.shard(num_workers, task_index).repeat(ep)
    X_train = ds.map(_X_parse_csv)

    ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(2))) #remove CSV headers + shift forward 1 day
    ds = ds.shard(num_workers, task_index).repeat(ep)
    y_train = ds.map(_y_parse_csv)

    X_iterator = X_train.make_initializable_iterator()
    y_iterator = y_train.make_initializable_iterator()

    return X_iterator, y_iterator

1 个答案:

答案 0 :(得分:1)

这两行似乎是问题的根源:

ds_file = tf.data.TextLineDataset(file)

ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers

第一行根据file中命名的文件(或多个文件)的行创建数据集。然后,第二行为ds_file中的每个元素创建一个数据集,该数据集将每个元素(来自file的一行文本)视为另一个文件名。当NotFoundError的第一行(看起来像是CSV标题)被视为文件名时,会引发file

修复相对简单,幸运的是,您可以使用Dataset.list_files()创建与您的glob匹配的文件数据集,然后Dataset.flat_map()将对文件名进行操作:

# Create a dataset of filenames.
ds_file = tf.data.Dataset.list_files("./NSEOIL.csv")

# For each filename in `ds_file`, read the lines from that file (skipping the
# header).
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1)))