用于读取稀疏数据的TensorFlow输入函数(以libsvm格式)

时间:2017-11-07 13:02:56

标签: tensorflow

我是TensorFlow的新手,并尝试使用Estimator API进行一些简单的分类实验。我在libsvm format中有一个稀疏数据集。以下输入函数适用于小型数据集:

def libsvm_input_function(file):

    def input_function():

        indexes_raw = []
        indicators_raw = []
        values_raw = []
        labels_raw = []
        i=0

        for line in open(file, "r"):
            data = line.split(" ")
            label = int(data[0])
            for fea in data[1:]:
                id, value = fea.split(":")
                indexes_raw.append([i,int(id)])
                indicators_raw.append(int(1))
                values_raw.append(float(value))
            labels_raw.append(label)
            i=i+1

        indexes = tf.SparseTensor(indices=indexes_raw,
                              values=indicators_raw,
                              dense_shape=[i, num_features])

        values = tf.SparseTensor(indices=indexes_raw,
                             values=values_raw,
                             dense_shape=[i, num_features])

        labels = tf.constant(labels_raw, dtype=tf.int32)

        return {"indexes": indexes, "values": values}, labels

    return input_function

但是,对于几GB大小的数据集,我收到以下错误:

  

ValueError:无法创建内容大于2GB的张量原型。

如何避免此错误?我应该如何编写输入函数来读取中等大小的稀疏数据集(采用libsvm格式)?

2 个答案:

答案 0 :(得分:0)

使用估计器时,对于libsvm数据输入,可以创建密集的If you want to read file from any arbitrary location, then you have to specify the full path for that.(assuming that you are trying to access the file outside)列表,密集的index列表,然后使用valuefeature_column.categorical_column_with_identity创建功能列,最后,将要素列放入估算器。也许您输入要素的长度是可变的,您可以使用padded_batch来处理它。 这里是一些代码:

feature_column.weighted_categorical_column

另一种方法是,您可以创建自定义功能列,例如:_SparseArrayCategoricalColumn

答案 1 :(得分:0)

我一直在使用tensorflow.contrib.libsvm。这是一个示例(我在发电机上使用急切执行)

import os
import tensorflow as tf
import tensorflow.contrib.libsvm as libsvm


def all_libsvm_files(folder_path):
    for file in os.listdir(folder_path):
        if file.endswith(".libsvm"):
            yield os.path.join(folder_path, file)

def load_libsvm_dataset(path_to_folder):
    return tf.data.TextLineDataset(list(all_libsvm_files(path_to_folder)))


def libsvm_iterator(path_to_folder):
    dataset = load_libsvm_dataset(path_to_folder)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    yield libsvm.decode_libsvm(tf.reshape(next_element, (1,)),
                               num_features=666,
                               dtype=tf.float32,
                               label_dtype=tf.float32)

libsvm_iterator在您指定的文件夹中的多个文件中,在每次迭代中都为您提供了功能标签对。