我正在训练CNN,我相信我对sess.run()
的使用导致我的训练非常缓慢。
实质上,我使用的是mnist
数据集......
from tensorflow.examples.tutorials.mnist import input_data
...
...
features = input_data.read_data_sets("/tmp/data/", one_hot=True)
问题是,CNN的第一层必须接受[batch_size, 28, 28, 1]
形式的图像,这意味着我必须先将每张图像转换为CNN。
我用我的脚本执行以下操作......
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])
...
...
with tf.Session() as sess:
for epoch in range(25):
total_batch = int(features.train.num_examples/500)
avg_cost = 0
for i in range(total_batch):
batch_xs, batch_ys = features.train.next_batch(10)
# Notice this line.
_, c = sess.run([train_op, loss], feed_dict={x:sess.run(tf.reshape(batch_xs, [10, 28, 28, 1])), y:batch_ys})
avg_cost += c / total_batch
if (epoch + 1) % 1 == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
注意注释行。我正在从训练集中获取第一批,我正在重塑为正确的格式[batch_size, 28, 28, 1]
。我每次都要拨打sess.run()
,我相信这是训练速度如此之慢的原因。
我该如何防止这种情况发生。我尝试使用numpy
在另一个脚本中重新格式化数据,但它仍然给我带来了问题,因为我无法在不运行numpy
的情况下提供sess.run()
数组。有人可以告诉我如何在训练课程之外格式化数据吗?也许我可以在另一个脚本中格式化数据并将其加载到包含我的CNN的那个?
答案 0 :(得分:2)
在每次迭代中你绝对不应该在新的操作上有内部sess.run()
(虽然我不确定它真的减慢了多少)。你应该做以下其中一个:
[None, 28*28*1]
,后跟tf.reshape([None, 28, 28, 1])
,位于您网络的开头(而不是tf.placeholder([None, 28, 28, 1])
)OR
_, c = sess.run([train_op, loss], feed_dict={x:batch_xs.reshape( [-1, 28, 28, 1]), y:batch_ys})
如果你只是写_, c = sess.run([train_op, loss], feed_dict={x:tf.reshape(batch_xs, [10, 28, 28, 1]), y:batch_ys})
,它可能也有效,但是不那样做,因为它会在每次迭代时在你的图形中创建一个新的op。
答案 1 :(得分:1)
您可以做的另一件事是重新设置开头本身的所有输入,然后将其提供给占位符。
import math
import numpy as np
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
y = tf.placeholder(tf.float32, [None, 10])
...
...
with tf.Session() as sess:
X_train=mnist.train.images.reshape(-1,28,28,1)
y_train=mnist.train.labels
train_indicies = np.arange(X_train.shape[0])
num_epochs = 25 // number of epochs
batch_size = 50
total_batch = int(math.ceil(X_train.shape[0]/batch_size))
for epoch in range(25):
for i in np.arange(total_batch):
start_idx = (i*batch_size)%X_train.shape[0]
idx = train_indicies[start_idx:start_idx+batch_size]
_, c = sess.run([train_op, loss], feed_dict={x:X_train[idx,:], y:y_train[idx]})
avg_cost += c / total_batch
if (epoch + 1) % 1 == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
因为我们无法使用mnist.train.next_batch,所以我们需要手动计算和增加索引。
希望这有效:)