Tensorflow DecodeCSV field error when reading a CSV

Date: 2018-11-21 13:06:26

Tags: python tensorflow

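I am currently trying to train a Tensorflow model and am having trouble reading a very simple CSV file. When I run the training script, I get the following error: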

InvalidArgumentError (see above for traceback): Expect 2 fields but have 1 in record 0

The CSV file looks like this:

github,Awesome per directory history for ZSH
github,PHP class which implements the Elo rating system
github,Comic Sans Everything

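I have checked the training and validation sets for "extra" commas that could break the field delimiting during the read, but I did not find any. Is there a way to find out which line in my dataset breaks the read function?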

def read_dataset(prefix):
    # use prefix to create filename
    filename = 'gs://{}/{}*csv*'.format(BUCKET, prefix)
    if prefix == 'train':
        mode = tf.contrib.learn.ModeKeys.TRAIN
    else:
        print('EvalSet')
        mode = tf.contrib.learn.ModeKeys.EVAL

    # the actual input function passed to TensorFlow
    def _input_fn():
        # could be a path to one file or a file pattern.
        input_file_names = tf.train.match_filenames_once(filename)
        filename_queue = tf.train.string_input_producer(input_file_names, shuffle=True)

        # read CSV
        reader = tf.TextLineReader(skip_header_lines=0)
        _, value = reader.read_up_to(filename_queue, num_records=BATCH_SIZE)
        print(value)
        #value = tf.train.shuffle_batch([value], BATCH_SIZE, capacity=10*BATCH_SIZE, min_after_dequeue=BATCH_SIZE, enqueue_many=True, allow_smaller_final_batch=False)
        value_column = tf.expand_dims(value, -1)

        columns = tf.decode_csv(value_column, record_defaults = DEFAULTS, field_delim=',', use_quote_delim=False, na_value="navalue")

        features = dict(zip(CSV_COLUMNS, columns))
        label = features.pop(LABEL_COLUMN)

        # make targets numeric
        table = tf.contrib.lookup.index_table_from_tensor(
                                     mapping=tf.constant(TARGETS), num_oov_buckets=0, default_value=-1)

        target = table.lookup(label)

        return features, target

    return _input_fn
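
One way to approach the question of which line breaks the read is to scan the files outside TensorFlow. The "record 0" in the error appears to be the position of the record inside the batch handed to tf.decode_csv, not a line number in the file, so the message alone does not identify the line. "have 1" also suggests a record with no comma at all (for example a fragment left by a line break inside a description) rather than an extra comma. The sketch below is a minimal standard-library approximation, assuming the CSV files have been copied to local disk; find_bad_records and the sample lines are hypothetical and not part of the original code.

```python
def find_bad_records(lines, expected_fields=2, delim=","):
    """Return (line_number, field_count, text) for lines with the wrong field count.

    Hypothetical helper: it only approximates tf.decode_csv with
    use_quote_delim=False, where quotes are not special and a record
    is simply divided at every delimiter.
    """
    bad = []
    for line_number, raw in enumerate(lines, start=1):
        text = raw.rstrip("\r\n")
        # Assumption: a blank line is counted as 0 fields here.
        field_count = len(text.split(delim)) if text else 0
        if field_count != expected_fields:
            bad.append((line_number, field_count, text))
    return bad


# Made-up sample lines covering the cases of interest.
sample = [
    "github,Awesome per directory history for ZSH\n",
    "github\n",                          # no delimiter: 1 field
    "\n",                                # blank line
    "github,Comic Sans, Everything\n",   # extra comma: 3 fields
]

for line_number, field_count, text in find_bad_records(sample):
    print(line_number, field_count, repr(text))
```

An open file object can be passed in place of the sample list, since the helper only iterates over lines. Every file matched by the pattern in read_dataset would need to be checked, because any one of them can supply the failing record.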

0 answers:

No answers