我通过跟踪和调整tensorflow教程,为我的回归问题设计了一个tensorflow神经网络。但是,由于我的问题结构(约300,000个数据点和使用昂贵的FTRLOptimizer),我的问题花了太长时间才能执行,即使使用我的32 CPU机器(我也没有GPU)。
根据this comment和 htop 的快速确认,似乎我有一些单线程操作,它应该是feed_dict。
因此,正如here所述,我尝试使用队列来多线程化我的程序。
我写了一个带队列的简单代码文件来训练模型如下:
import numpy as np
import tensorflow as tf
import threading
#Function for enqueueing in parallel my data
def enqueue_thread():
sess.run(enqueue_op, feed_dict={x_batch_enqueue: x, y_batch_enqueue: y})
#Set the number of couples (x, y) I use for "training" my model
BATCH_SIZE = 5
#Generate my data where y=x+1+little_noise
x = np.random.randn(10, 1).astype('float32')
y = x+1+np.random.randn(10, 1)/100
#Create the variables for my model y = x*W+b, then W and b should both converge to 1.
W = tf.get_variable('W', shape=[1, 1], dtype='float32')
b = tf.get_variable('b', shape=[1, 1], dtype='float32')
#Prepare the placeholdeers for enqueueing
x_batch_enqueue = tf.placeholder(tf.float32, shape=[None, 1])
y_batch_enqueue = tf.placeholder(tf.float32, shape=[None, 1])
#Create the queue
q = tf.RandomShuffleQueue(capacity=2**20, min_after_dequeue=BATCH_SIZE, dtypes=[tf.float32, tf.float32], seed=12, shapes=[[1], [1]])
#Enqueue operation
enqueue_op = q.enqueue_many([x_batch_enqueue, y_batch_enqueue])
#Dequeue operation
x_batch, y_batch = q.dequeue_many(BATCH_SIZE)
#Prediction with linear model + bias
y_pred=tf.add(tf.mul(x_batch, W), b)
#MAE cost function
cost = tf.reduce_mean(tf.abs(y_batch-y_pred))
learning_rate = 1e-3
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
available_threads = 1024
#Feed the queue
for i in range(available_threads):
threading.Thread(target=enqueue_thread).start()
#Train the model
for step in range(1000):
_, cost_step = sess.run([train_op, cost])
print(cost_step)
Wf=sess.run(W)
bf=sess.run(b)
这段代码不起作用,因为每次调用x_batch时,一个y_batch也会出列,反之亦然。然后,我不会将这些功能与相应的"结果"。
进行比较有没有一种简单的方法可以避免这个问题?
答案 0 :(得分:1)
我的错误,一切正常。 我被误导了,因为我在算法的每一步估计了我在不同批次上的表现,也因为我的模型对于虚拟模型来说太复杂了(我应该有y = W * x或y = x + b)。 然后,当我尝试在控制台中打印时,我在不同的变量上多次sess.run,并且得到了明显不一致的结果。
答案 1 :(得分:0)
尽管如此,您的问题已经解决,希望向您展示代码中的低效率。创建RandomShuffleQueue时,您指定了capacity=2**20
。在所有队列中capacity:
可存储在其中的元素数量的上限 队列中。
队列将尝试在队列中放置尽可能多的元素,直到它达到此限制。所有这些元素都在吃你的RAM。如果每个元素只包含1个字节,那么您的队列将占用1Mb的数据。如果队列中有10Kb图像,则会占用10Gb的RAM。
这非常浪费,特别是因为队列中从不需要这么多元素。所有你需要确保你的队列永远不会是空的。因此,找到一个合理的队列容量,不要使用大数字。