我正在gcp-ai平台上运行一个tensorflow模型。数据集很大,并且并非所有内容都可以同时保存在内存中,因此我使用以下代码将数据读入tf.dataset
:
def read_dataset(filepattern):
def decode_csv(value_column):
cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0])
features=[cols[1],cols[2]]
label = cols[0]
return features, label
# Create list of files that match pattern
file_list = tf.io.gfile.glob(filepattern)
# Create dataset from file list
dataset = tf.data.TextLineDataset(file_list).map(decode_csv)
return dataset
training_data=read_dataset(<filepattern>)
问题在于我数据中的第二列是分类的,我需要使用一种热编码。在函数decode_csv
中或以后再操作tf.dataset
怎么做。
答案 0 :(得分:0)
您可以使用tf.one_hot。假设第二列为cols[1]
,并且类别值已转换为整数,则可以执行以下操作:
def decode_csv(value_column):
cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0]])
features=[cols[1], tf.one_hot(cols[2], nb_classes)]
label = cols[0]
return features, label
注意:未测试。