读入tf.dataset时的一种热编码

时间:2019-07-30 13:48:17

标签: python tensorflow google-cloud-platform tensorflow-datasets

我正在gcp-ai平台上运行一个tensorflow模型。数据集很大,并且并非所有内容都可以同时保存在内存中,因此我使用以下代码将数据读入tf.dataset

def read_dataset(filepattern):
    def decode_csv(value_column):
        cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0])
        features=[cols[1],cols[2]]
        label = cols[0]
        return features, label
    # Create list of files that match pattern
    file_list = tf.io.gfile.glob(filepattern)
    # Create dataset from file list
    dataset = tf.data.TextLineDataset(file_list).map(decode_csv)
    return dataset

training_data=read_dataset(<filepattern>)

问题在于我数据中的第二列是分类的,我需要使用一种热编码。在函数decode_csv中或以后再操作tf.dataset怎么做。

1 个答案:

答案 0 :(得分:0)

您可以使用tf.one_hot。假设第二列为cols[1],并且类别值已转换为整数,则可以执行以下操作:

def decode_csv(value_column):
    cols = tf.io.decode_csv(value_column, record_defaults=[[0.0],[0],[0.0]])
    features=[cols[1], tf.one_hot(cols[2], nb_classes)]
    label = cols[0]
    return features, label

注意:未测试。