我在tensorflow中训练卷积模型。经过大约70个时期的训练模型,花了将近1.5小时,我无法保存模型。它给了我ValueError: GraphDef cannot be larger than 2GB
。我发现随着训练的进行,我的图表中的节点数量不断增加。
在时期0,3,6,9,图中的节点数分别为7214,7238,7262,7286。当我使用with tf.Session() as sess:
时,而不是将会话作为sess = tf.Session()
传递,节点数分别为3982,4006,4030,4054,分别位于0,3,6,9的时期。
在this回答中,据说当节点添加到图表时,它可能超过其最大大小。我需要帮助了解节点数量如何在我的图表中继续上升。
我使用以下代码训练我的模型:
def runModel(data):
'''
Defines cost, optimizer functions, and runs the graph
'''
X, y,keep_prob = modelInputs((755, 567, 1),4)
logits = cnnModel(X,keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name="cost")
optimizer = tf.train.AdamOptimizer(.0001).minimize(cost)
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1), name="correct_pred")
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for e in range(12):
batch_x, batch_y = data.next_batch(30)
x = tf.reshape(batch_x, [30, 755, 567, 1]).eval(session=sess)
batch_y = tf.one_hot(batch_y,4).eval(session=sess)
sess.run(optimizer, feed_dict={X: x, y: batch_y,keep_prob:0.5})
if e%3==0:
n = len([n.name for n in tf.get_default_graph().as_graph_def().node])
print("No.of nodes: ",n,"\n")
current_cost = sess.run(cost, feed_dict={X: x, y: batch_y,keep_prob:1.0})
acc = sess.run(accuracy, feed_dict={X: x, y: batch_y,keep_prob:1.0})
print("At epoch {epoch:>3d}, cost is {a:>10.4f}, accuracy is {b:>8.5f}".format(epoch=e, a=current_cost, b=acc))
导致节点数量增加的原因是什么?
答案 0 :(得分:2)
您正在训练循环中创建新节点。特别是,您正在调用tf.reshape
和tf.one_hot
,每个节点都会创建一个(或多个)节点。你可以:
我建议使用第二个,因为使用TensorFlow进行数据准备似乎没有任何好处。你可以有类似的东西:
import numpy as np
# ...
x = np.reshape(batch_x, [30, 755, 567, 1])
# ...
# One way of doing one-hot encoding with NumPy
classes_arr = np.arange(4).reshape([1] * batch_y.ndims + [-1])
batch_y = (np.expand_dims(batch_y, -1) == classes_arr).astype(batch_y.dtype)
# ...
PD:我还建议在tf.Session()
context manager中使用with
,以确保在最后调用close()
方法(除非您想要继续使用相同的会话。)
答案 1 :(得分:0)
另一个为我解决了类似问题的方法是使用tf.reset_default_graph()