TensorFlow Estimator graph size limit for large inputs

Time: 2019-01-04 09:30:58

Tags: python tensorflow tensorflow-estimator

I think all of my training data is being stored inside the graph, which is hitting the 2GB limit. How can I use feed_dict with the Estimator API? FYI, I am using the TensorFlow Estimator API to train my model.

Input function:

import tensorflow as tf

def input_fn(X_train, epochs, batch_size):
    '''Input X_train is a scipy sparse matrix of large input dimension (200000) and number of rows = 600000.'''
    # convert_sparse_matrix_to_sparse_tensor and shuffle_to_batch are
    # assumed to be defined elsewhere in the script.
    X_train_tf = tf.data.Dataset.from_tensor_slices(
        convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32))
    X_train_tf = X_train_tf.apply(
        tf.data.experimental.shuffle_and_repeat(shuffle_to_batch * batch_size, epochs))
    X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
    return X_train_tf

Error:

Traceback (most recent call last):
  File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in <module>
    AE_Regressor.train(lambda: input_fn(X_train, epochs, batch_size), hooks=[time_hist, logging_hook])
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed
    saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get
    logdir, graph=ops.get_default_graph())
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in __init__
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in __init__
    self.add_graph(graph=graph, graph_def=graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph
    true_graph_def = graph.as_graph_def(add_shapes=True)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def
    c_api.TF_GraphToGraphDef(self._c_graph, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852 bytes) would be larger than the limit (2147483647 bytes)

2 Answers:

Answer 0 (score: 0):

I'm usually against quoting documentation verbatim, but this is explained word for word in the TF documentation, and I can't find a way to put it any better than they already have:

Note that [using Dataset.from_tensor_slices() on the features and labels numpy arrays] will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory (because the contents of the array will be copied multiple times) and can run into the 2GB limit for the tf.GraphDef protocol buffer.

As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

(Both the code and the text are from the link above, with one assert removed from the code as irrelevant to this question.)


Update

If you're trying to use this together with the Estimator API, you're out of luck. On the same linked page, a couple of sections above the part quoted earlier:

Note: Currently, one-shot iterators are the only type that is usable with an Estimator.

As you pointed out in the comments, this is because the Estimator API hides the sess.run() call, which is where you would need to pass the feed_dict for the iterator.
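
For context, here is a minimal sketch of what the note says does work: a one-shot iterator has no separate initialization step (and hence needs no feed_dict), which is why it is the only kind the Estimator can drive by itself. The toy data here is purely illustrative; with large real NumPy arrays this is exactly the pattern that embeds the data as graph constants and can hit the 2GB limit.

import tensorflow as tf

def one_shot_input_fn():
    # Toy in-graph data for illustration; with large arrays this is the
    # pattern that embeds everything as tf.constant() operations.
    dataset = tf.data.Dataset.from_tensor_slices(tf.range(100)).batch(10)
    # A one-shot iterator has no initializer op to run, so the session
    # the Estimator creates internally can consume it directly.
    return dataset.make_one_shot_iterator().get_next()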

Answer 1 (score: 0):

If you are using an Estimator, you can do this via a SessionRunHook.
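
The answer gives no code, but the usual shape of this trick is sketched below; the names IteratorInitializerHook and make_input_fn are mine for illustration, not from the answer or any library. The hook captures the session that the Estimator otherwise hides and uses it to run the iterator initializer with a feed_dict:

import tensorflow as tf

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs a stored iterator-initializer function once the session exists."""

    def __init__(self):
        self.initializer_func = None  # set by input_fn below

    def after_create_session(self, session, coord):
        # The Estimator normally hides sess.run(); this callback hands the
        # MonitoredSession back to us so we can feed the placeholders.
        self.initializer_func(session)

def make_input_fn(features, labels, batch_size, hook):
    def input_fn():
        features_ph = tf.placeholder(features.dtype, features.shape)
        labels_ph = tf.placeholder(labels.dtype, labels.shape)
        dataset = tf.data.Dataset.from_tensor_slices(
            (features_ph, labels_ph)).batch(batch_size)
        iterator = dataset.make_initializable_iterator()
        # Defer initialization until the hook receives the live session.
        hook.initializer_func = lambda sess: sess.run(
            iterator.initializer,
            feed_dict={features_ph: features, labels_ph: labels})
        return iterator.get_next()
    return input_fn

# Hypothetical usage with an already-constructed estimator:
# hook = IteratorInitializerHook()
# estimator.train(input_fn=make_input_fn(X, y, 128, hook), hooks=[hook])

This recovers the sess.run(iterator.initializer, feed_dict=...) step that answer 0 said the Estimator hides, without embedding the arrays in the graph, at the cost of threading the hook through to both the input_fn and the train() call.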