I believe all of my training data is being stored inside the graph, which is hitting the 2GB limit. How can I use feed_dict with the Estimator API? FYI, I am using the TensorFlow Estimator API to train my model.

Input function:
def input_fn(X_train, epochs, batch_size):
    '''Input X_train is a scipy sparse matrix with a large input
    dimension (200,000 columns) and 600,000 rows.'''
    X_train_tf = tf.data.Dataset.from_tensor_slices(
        convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32))
    X_train_tf = X_train_tf.apply(tf.data.experimental.shuffle_and_repeat(
        shuffle_to_batch * batch_size, epochs))
    X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
    return X_train_tf
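(For reference, convert_sparse_matrix_to_sparse_tensor is my own helper. A minimal sketch of such a helper is shown below, assuming the input is any scipy.sparse matrix; the exact implementation may differ, but any version built this way bakes the matrix contents into the graph as constants, which is what matters for the error that follows.)

import numpy as np
import tensorflow as tf

def convert_sparse_matrix_to_sparse_tensor(X, dtype):
    # Go through COO format: its (row, col, data) triplets map directly
    # onto the indices/values/dense_shape fields of a tf.SparseTensor.
    coo = X.tocoo()
    indices = np.stack([coo.row, coo.col], axis=-1).astype(np.int64)
    values = tf.cast(coo.data, dtype)  # becomes a constant op in the graph
    return tf.SparseTensor(indices, values, dense_shape=coo.shape)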
Error:
Traceback (most recent call last):
  File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in <module>
    AE_Regressor.train(lambda: input_fn(X_train, epochs, batch_size), hooks=[time_hist, logging_hook])
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed
    saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get
    logdir, graph=ops.get_default_graph())
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in __init__
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in __init__
    self.add_graph(graph=graph, graph_def=graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph
    true_graph_def = graph.as_graph_def(add_shapes=True)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def
    c_api.TF_GraphToGraphDef(self._c_graph, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852 bytes) would be larger than the limit (2147483647 bytes)
Answer 0 (score: 0)
I'm normally against quoting documentation verbatim, but this is explained word for word in the TF documentation, and I can't find a better way to phrase it than they already have:
Note that [using Dataset.from_tensor_slices() on the features and labels numpy arrays] will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory, because the contents of the array will be copied multiple times, and it can run into the 2GB limit for the tf.GraphDef protocol buffer. As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
    features = data["features"]
    labels = data["labels"]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})
(Both the code and the text are taken from the link above, with one assert removed from the code as irrelevant to this question.)
If you are trying to use this with the Estimator API, you're out of luck. On the same linked page, a couple of sections above the one quoted previously:
Note: Currently, one-shot iterators are the only type that can be used with an Estimator.
As you point out in your comment, this is because the Estimator API hides the sess.run() call into which you would need to pass the feed_dict for your iterator.
Answer 1 (score: 0)
If you are using an Estimator, you can do this via a SessionRunHook.
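The answer stops there, so here is a minimal sketch of that idea, assuming the simple case of a dense NumPy array (the names IteratorInitializerHook and get_input_fn are illustrative, not part of the TensorFlow API). The hook runs the iterator's initializer, together with its feed_dict, as soon as the Estimator's MonitoredSession is created, which is the earliest point where user code gets hold of a session:

import tensorflow as tf

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs a deferred iterator initializer once the session exists."""

    def __init__(self):
        self.iterator_initializer_func = None  # set inside input_fn below

    def after_create_session(self, session, coord):
        # Called by the MonitoredSession right after session creation;
        # this is where the feed_dict finally reaches a sess.run() call.
        self.iterator_initializer_func(session)

def get_input_fn(X_train, epochs, batch_size):
    init_hook = IteratorInitializerHook()

    def input_fn():
        # Only the placeholder's dtype and shape go into the GraphDef;
        # the actual data is fed in at initialization time.
        X_pl = tf.placeholder(X_train.dtype, X_train.shape)
        dataset = tf.data.Dataset.from_tensor_slices(X_pl)
        dataset = dataset.shuffle(10 * batch_size).repeat(epochs)
        dataset = dataset.batch(batch_size).prefetch(2)
        iterator = dataset.make_initializable_iterator()
        # Capture the initializer so the hook can run it with the data.
        init_hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer, feed_dict={X_pl: X_train})
        return iterator.get_next()

    return input_fn, init_hook

You would then train with something like input_fn, init_hook = get_input_fn(X_train, epochs, batch_size) followed by estimator.train(input_fn, hooks=[init_hook]). For the scipy sparse matrix in the question, the same pattern should work with tf.sparse_placeholder and a tf.SparseTensorValue fed in place of the dense array, though I have not tested that variant.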