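I'm currently trying to train a TensorFlow model and am having trouble reading a very simple CSV file. When I run the training script, I get the following error: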
InvalidArgumentError (see above for traceback): Expect 2 fields but have 1 in record 0
The CSV file looks like this:
```
github,Awesome per directory history for ZSH
github,PHP class which implements the Elo rating system
github,Comic Sans Everything
```
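I have checked both the training and validation sets for "extra" commas that could break the delimiting during reading, but I found none. Is there a way to find out which line of my dataset is breaking the read function?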
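One way to locate the offending line is to scan the files outside TensorFlow. Note that "record 0" in the error most likely refers to the position inside the batch returned by `read_up_to`, not to line 0 of a file, so the error message alone does not identify the line. The sketch below assumes the files have first been copied locally (for example with `gsutil cp`), that every valid row has exactly 2 fields, and that quotes need no special handling (matching `use_quote_delim=False`). The helper names `find_bad_records` and `scan_files` are hypothetical. A blank line splits into a single empty field, which is a common cause of "Expect 2 fields but have 1".

```python
import glob


def find_bad_records(lines, expected_fields=2, delim=","):
    """Return (line_number, field_count, line) for every line whose field
    count differs from expected_fields. Line numbers are 1-based.

    Splitting on every delimiter mirrors decode_csv with use_quote_delim=False,
    where quote characters are not treated specially.
    """
    bad = []
    for line_no, raw in enumerate(lines, start=1):
        line = raw.rstrip("\r\n")  # strip only the line terminator
        n_fields = len(line.split(delim))  # "" splits into [""] -> 1 field
        if n_fields != expected_fields:
            bad.append((line_no, n_fields, line))
    return bad


def scan_files(pattern, expected_fields=2):
    """Scan every local file matching a glob pattern (hypothetical helper).

    Returns {path: [(line_number, field_count, line), ...]} for files
    that contain at least one bad line.
    """
    report = {}
    for path in sorted(glob.glob(pattern)):
        with open(path, encoding="utf-8") as f:
            bad = find_bad_records(f, expected_fields)
        if bad:
            report[path] = bad
    return report


if __name__ == "__main__":
    sample = [
        "github,Awesome per directory history for ZSH\n",
        "github,Comic Sans Everything\n",
        "\n",                             # blank line: 1 field
        "github,Title, with a comma\n",   # extra comma: 3 fields
    ]
    for line_no, n, text in find_bad_records(sample):
        print(line_no, n, repr(text))
```

Running it on local copies, e.g. `scan_files("train*csv*")`, lists every line that `decode_csv` would reject under these assumptions.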
```python
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and queue-based readers)

# BUCKET, BATCH_SIZE, DEFAULTS, CSV_COLUMNS, LABEL_COLUMN and TARGETS
# are defined elsewhere in the training script.

def read_dataset(prefix):
    # use prefix to create filename
    filename = 'gs://{}/{}*csv*'.format(BUCKET, prefix)
    if prefix == 'train':
        mode = tf.contrib.learn.ModeKeys.TRAIN
    else:
        print('EvalSet')
        mode = tf.contrib.learn.ModeKeys.EVAL

    # the actual input function passed to TensorFlow
    def _input_fn():
        # could be a path to one file or a file pattern.
        input_file_names = tf.train.match_filenames_once(filename)
        filename_queue = tf.train.string_input_producer(input_file_names, shuffle=True)

        # read CSV
        reader = tf.TextLineReader(skip_header_lines=0)
        _, value = reader.read_up_to(filename_queue, num_records=BATCH_SIZE)
        print(value)
        #value = tf.train.shuffle_batch([value], BATCH_SIZE, capacity=10*BATCH_SIZE, min_after_dequeue=BATCH_SIZE, enqueue_many=True, allow_smaller_final_batch=False)
        value_column = tf.expand_dims(value, -1)
        # every record must split into exactly len(DEFAULTS) fields
        columns = tf.decode_csv(value_column, record_defaults=DEFAULTS, field_delim=',',
                                use_quote_delim=False, na_value="navalue")
        features = dict(zip(CSV_COLUMNS, columns))
        label = features.pop(LABEL_COLUMN)

        # make targets numeric
        table = tf.contrib.lookup.index_table_from_tensor(
            mapping=tf.constant(TARGETS), num_oov_buckets=0, default_value=-1)
        target = table.lookup(label)
        return features, target

    return _input_fn
```
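If the scan shows that blank lines are the culprit, one possible workaround is to write a cleaned copy of each file before uploading it back to the bucket. The sketch below uses hypothetical helpers `drop_blank_lines` and `clean_file`; it removes only empty or whitespace-only lines and deliberately keeps rows with extra commas, since silently dropping those would lose data and they are better fixed by hand.

```python
def drop_blank_lines(lines):
    """Return (kept_lines, number_dropped), removing empty or
    whitespace-only lines. `lines` must be a list."""
    kept = [line for line in lines if line.strip()]
    return kept, len(lines) - len(kept)


def clean_file(src_path, dst_path):
    """Copy src_path to dst_path without blank lines (hypothetical helper).

    Returns the number of lines dropped.
    """
    with open(src_path, encoding="utf-8") as src:
        kept, dropped = drop_blank_lines(src.readlines())
    with open(dst_path, "w", encoding="utf-8") as dst:
        dst.writelines(kept)  # original line endings are preserved
    return dropped
```

The cleaned files could then be copied back with `gsutil cp` so that the `gs://` pattern used by `read_dataset` picks them up.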