tensorlfow中的分发策略

时间:2019-11-20 22:59:16

标签: tensorflow machine-learning

我的伪代码如下:

 entire_data = fetch_and_preprocess_data()
 with tf.Session(config=config) as sess:
     for epoch in range(CONFIG.MAX_EPOCHS):
         for batch_of_data in entire_data:
         preprocessed_batch = preprocess_data(batch_of_data)
         loss, output = sess.run([loss_tr, outputs_tr, train_op_tr], feed_dict= 
         get_feed_dict(preprocessed_data))

这可以在单个GPU上完美运行。但是,我想将其扩展到多个GPU。因此,我决定使用tensorlflow提供的一些分发策略。我尝试使用“ MirroredStrategy”对其进行并行化,并在多个GPU上进行训练。我对上面的代码做了一些修改:

entire_data = fetch_and_preprocess_data()
strategy = tf.contrib.distribute.MirroredStrategy()
with strategy.scope():
#The entire code from above except the entire_data = fetch_and_preprocess_data() part

我在4个GPU上运行了这段代码,并希望这可以在所有GPU上进行培训。但是,我看到只有第一个GPU内存被使用,其余GPU的内存却微不足道。如果有人可以帮助我,那就太好了!

0 个答案:

没有答案