我正在为张量流估计器编写input_fn,就像下面的指南中所述:
def _parse_line(line):
fields = tf.decode_csv(line, FIELD_DEFAULTS)
features = dict(zip(COLUMNS, fields))
features.pop("DATE")
label = features.pop("LABEL")
return features, label
def csv_input_fn(csv_path):
filenames = [join(csv_path, f) for f in os.listdir(csv_path)]
dataset = tf.data.TextLineDataset(filenames).skip(0)
dataset = dataset.map(_parse_line)
return dataset
但是对于分类,我需要将我的标签设置为热门标签,而且它们还需要根据边界将其归为特殊类别。
例如,0到2之间的值是一个类别,而2到5之间的值则是另一类别。然后,如果我的标签为4.1,则输出应为[0,1],如果我的标签值为0.5,则输出应为[1,0]
我认为应该在_parse_line函数中添加代码,但是有什么想法怎么做?非常感谢!