Tensorflow:内存在尝试将numpy稀疏矩阵加载到input_fn时出错

时间:2017-10-31 05:44:23

标签: tensorflow

我正在构建一个文本分类模型,并构建了一个形状为(81062,100000)的大型稀疏矩阵。

input_fn函数定义为:

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'tfidf': X_train_tfidf.todense()}, y=y_train.values,
    batch_size=batch_size, num_epochs=None, shuffle=True)

当我尝试执行它时,它会给我以下错误:

MemoryError                               Traceback (most recent call last)

然后我尝试使用data.Dataset模块构建一个input_fn:

def input_fn():
    dataset = tf.contrib.data.Dataset.from_sparse_tensor_slices((X_train_tfidf, y_train.values))
    dataset = dataset.repeat().shuffle(buff).batch(batch_size)
    x, y = dataset.make_one_shot_iterator().get_next()
    return x, y

然而,它给了我以下信息:

TypeError: `sparse_tensor` must be a `tf.SparseTensor` object.

基本上我想要做的是使用来自numpy稀疏矩阵的SGD将较小批量的训练数据提供给神经网络。但我找不到正确的方法。

有人可以帮忙吗?

1 个答案:

答案 0 :(得分:1)

TypeError表示from_sparse_tensor_slices要求其输入为tf.SparseTensor的实例。请参阅:https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_sparse_tensor_slices

将训练矩阵与标签一起打包到一个SparseTensor中可以解决问题。