我不明白为什么它不起作用以及为什么有错误,请帮助我。
def get_data_set_with_csv(file_name, header_lines, batch_size, repeat_size):
dataset = tf.data.TextLineDataset(file_name).skip(header_lines)
def parse_csv(line):
col_types = [tf.ones(shape=(1,),dtype=tf.float32)]+[tf.zeros(shape=(1,),dtype=tf.float32)]*784
data= tf.decode_csv(tf.expand_dims(line,axis=0),col_types)
label = data[0]
img = data[1:]
return label, img
dataset = dataset.map(parse_csv)
return dataset
我和