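I am building a deep learning model optimized with SGD. I keep 10 minibatches in a Python thread-safe queue (`Queue.Queue`). Every time I call the gradient descent step function, I pop one minibatch from the queue and pass it to the step. But this takes a very long time, and I don't know why. I am trying to pass a reference to the queue into the function. Below is a small snippet of the code showing how I do it. The overhead is almost 100 times larger than that of the other operations.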
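```python
import threading


class DataLoaderThread(threading.Thread):
    def __init__(self, queue, data_file_list, batch_size, is_training, ae):
        threading.Thread.__init__(self)
        self.queue = queue  # <-- the shared Queue.Queue created in main()
        self.data_file_list = data_file_list
        self.batch_size = batch_size
        self.is_training = is_training
        self.ae = ae

    def run(self):  # assumed: the loading loop lives in run()
        # ... code block not mentioned, not important ...
        if self.ae:
            self.queue.put(train_features1)
        else:
            self.queue.put(train_features1)
```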
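For comparison, here is a minimal self-contained sketch of the same producer pattern, written for Python 3 (where the module is `queue` rather than `Queue`). `MinibatchProducer` and `make_minibatch` are hypothetical stand-ins for `DataLoaderThread` and its elided loading code; the fake minibatches are just lists of floats.

```python
import queue
import threading


class MinibatchProducer(threading.Thread):
    """Hypothetical stand-in for DataLoaderThread: puts minibatches on a shared queue."""

    def __init__(self, q, num_batches, batch_size):
        super().__init__(daemon=True)  # daemon thread: does not keep the process alive
        self.q = q
        self.num_batches = num_batches
        self.batch_size = batch_size

    def make_minibatch(self, i):
        # Hypothetical loader: a real one would read and preprocess data files here.
        return [float(i)] * self.batch_size

    def run(self):
        for i in range(self.num_batches):
            # put() blocks while the queue already holds maxsize items,
            # so the producer never gets more than maxsize batches ahead.
            self.q.put(self.make_minibatch(i))


q = queue.Queue(maxsize=10)
producer = MinibatchProducer(q, num_batches=25, batch_size=4)
producer.start()
batches = [q.get() for _ in range(25)]  # each get() blocks until a batch is available
producer.join()
print(len(batches), batches[0])  # 25 [0.0, 0.0, 0.0, 0.0]
```

Because `maxsize=10`, the queue acts as a bounded buffer: the consumer only ever pulls finished minibatches, and the producer pauses once it is 10 batches ahead.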
The main function:

```python
import Queue  # Python 2 name of the thread-safe queue module
import time


def main():
    start1 = time.time()
    q = Queue.Queue(maxsize=10)  # <-- holds up to 10 minibatches
    data_loader = DataLoaderThread(q, data_file_list, batch_size,
                                   is_training=(not test), ae=(not is_full_model))
    data_loader.daemon = True
    data_loader.start()
    print 'time for dataloader', (time.time() - start1)
    do_training(q, numbertotal_samples, batch_size, weights_share)  # <-- pass the queue on
    data_loader.terminate()  # custom method: threading.Thread itself has no terminate()
    data_loader.join()
    return 0
```
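In Python, passing `q` as an argument passes a reference to the same object; the queue and its contents are not copied. A minimal check of this (Python 3 syntax; `do_training` here is a hypothetical one-line stand-in, not the real training function):

```python
import queue


def do_training(data_q):
    # Hypothetical stand-in: data_q is the very same Queue object
    # the caller created, not a copy of it.
    return data_q


q = queue.Queue(maxsize=10)
q.put("batch-0")
received = do_training(q)
print(received is q, received.qsize())  # True 1
```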
The training function:

```python
import sys
import time


def do_training(data_q, numbertotal_samples, batch_size, weights_share):  # <-- receives the queue
    start_train = time.time()
    model = gatedAE.FactoredGatedAutoencoder(...)  # necessary args omitted
    learningrate = 0.00005
    trainer = GraddescentMinibatch(model, batch_size, learningrate)
    print 'time for first portion of do training', (time.time() - start_train)
    loss = 0
    epoch = 0
    try:
        while True:
            # TRAIN MODEL
            cost = trainer.step(data_q)  # <-- hands the queue to the step function
            loss += cost
            print 'cost_at_every_batch', cost
            epoch += 1
            if epoch % progress_report == 0:  # progress_report is defined elsewhere
                loss /= progress_report
                print '%d\t%g' % (epoch, loss)
                sys.stdout.flush()
                loss = 0
    except KeyboardInterrupt:  # placeholder: the original except clause was not shown
        pass
```
Finally, the step function of the gradient descent trainer:

```python
def step(self, data_q):  # <-- method of GraddescentMinibatch
    time_step = time.time()
    val = data_q.get()  # <-- blocks until the loader thread has put a minibatch
    print 'time outside set_sharedval in step :', (time.time() - time_step)
    time_sharedval = time.time()
    cost = self.set_shareddata(val)
    print 'time inside sharedval in step :', (time.time() - time_sharedval)
    #cost = (1.0-1.0/stepcount)*cost + (1.0/stepcount)*self._updateincs(np.asarray(val))
    #ipdb.set_trace()
    self._trainmodel(0)
    self.model.normalizefilters()
    return cost
```
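The `print 'time outside set_sharedval in step :', (time.time() - time_step)` statement in the step function reports about 200 seconds for that part of the code, even though it does nothing except pop one value from the queue. This makes me think there is something wrong with the way I construct the data loader and pass it around. I have highlighted the code that carries the queue between the functions to make it easier to follow.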
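Since `get()` blocks until an item is available, the time measured around `data_q.get()` includes any time spent waiting for the loader thread to produce the next minibatch. A sketch of one way to separate the two cases, assuming a hypothetical `slow_producer` that needs `delay` seconds per minibatch (Python 3 syntax):

```python
import queue
import threading
import time


def slow_producer(q, num_batches, delay):
    # Hypothetical loader that needs `delay` seconds to prepare each minibatch.
    for i in range(num_batches):
        time.sleep(delay)
        q.put(i)


def timed_get(q):
    # Record how many batches were already waiting (qsize() is approximate),
    # then time the blocking get().
    waiting = q.qsize()
    start = time.perf_counter()
    item = q.get()
    return item, waiting, time.perf_counter() - start


q = queue.Queue(maxsize=10)
threading.Thread(target=slow_producer, args=(q, 5, 0.05), daemon=True).start()
results = [timed_get(q) for _ in range(5)]
for item, waiting, elapsed in results:
    print(f"batch {item}: {waiting} queued before get(), waited {elapsed:.3f}s")
```

With a consumer that does no other work, most lines should show 0 batches queued and a wait close to the producer's delay. If printing `data_q.qsize()` before `get()` in `step` showed the same pattern, that would suggest the time is spent waiting for `DataLoaderThread` to fill the queue rather than in `get()` itself.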