1. 简介
我在 TensorFlow 的 Eager Execution(即时执行)模式下编写了一个简单的 MNIST 3D CNN 分类器。
我只是将 28x28 像素的灰度图像重复 10 次,组成一个新的 4D 图像,其 shape = (D, W, H, C)(即 (1, 28, 28, 10)),
然后将它们包装成 tf.dataset.TFRecord,其 shape = (batch_size, 1, 28, 28, 10)。
我将60000个 MNIST 样本分为 train_dataset 和 val_dataset,我的 3D CNN Keras 模型可以在单个(ONE)GPU 上分别对它们进行训练和验证。
2. 问题:
我的服务器有10个GPU。
1)当我使用单个GPU训练Keras模型时,nvidia-smi 显示 GPU-Util(即 “Volatile Uncorr. ECC / GPU-Util / Compute M.” 一栏)约为 80%,Memory-Usage 约为 3.4Gb。
2)当我使用 multi_gpu_model() 让我的模型在多GPU(例如GPU 0和1)上训练时,nvidia-smi 显示只有 GPU 0 在工作(GPU-Util 约为 60%,Memory-Usage 约为 3.4Gb),
而 GPU 1 的 GPU-Util 为 0%,Memory-Usage 约为 137Mb,如图片1所示。
我已经搜索过很多次,但是没有找到解决方案,我的代码有问题吗?
期待任何建议。 ^ ^
3. 我的代码
nvidia-smi 的输出如图片1所示。
我的main.py如下:
import tensorflow as tf
import yaml
import os
from sample.models.CNN_3D import get_model
from yaml import CLoader as Loader
from sample.dataset import load_tf_records
# import matplotlib.pyplot as plt
import tensorflow.contrib.eager as tfe
# import numpy as np
from tensorflow.python.keras.utils import multi_gpu_model
# from sklearn.metrics import confusion_matrix
# Load the first entry of the "DATASET" list from the YAML config. The path is
# resolved relative to the current working directory ("../config/..."), so the
# result depends on where the script is launched from.
# NOTE(review): yaml.load with the full CLoader can construct arbitrary Python
# objects; only point this at a trusted config file.
_config_path = os.path.abspath(os.path.join(os.getcwd(), "../config/mnist_config.yml"))
with open(_config_path) as _config_file:  # close the handle once parsed
    cfg = yaml.load(_config_file, Loader=Loader)["DATASET"][0]
print("cfg =\n", cfg)

# Set the environment before eager mode is enabled so that the variables are
# already in place when TensorFlow first sets up its devices.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # ignore warning
os.environ["CUDA_VISIBLE_DEVICES"] = cfg["DEVICES_IDS"]

# Ask TensorFlow to grow GPU memory on demand instead of reserving it up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# Hand the options to the eager context itself; the original only passed them
# to a Session.
tf.enable_eager_execution(config=config)  # start eager mode

# Kept so the module-level name `session` and the Keras backend session remain
# available exactly as before.
session = tf.Session(config=config)
tf.keras.backend.set_session(session)
def loss(net, inputs, gts):
    """
    Compute the summed softmax cross-entropy of a batch.

    :param net: callable model; ``net(inputs)`` is used as the logits
    :param inputs: a batch of input tensors fed to ``net``
    :param gts: a batch of ground truth labels, passed as ``labels``
    :return: the per-example cross-entropy values reduced with ``tf.reduce_sum``
    """
    predictions = net(inputs)
    per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=predictions,
        labels=gts,
    )
    return tf.reduce_sum(per_example)
def train_step(loss_f, net, opt, x, y):
    """
    Run one optimizer update on a single batch.

    :param loss_f: loss function, called as ``loss_f(net, x, y)``
    :param net: network model
    :param opt: optimizer exposing ``minimize(loss_callable, global_step=...)``
    :param x: a batch of input tensor
    :param y: a batch of ground truth labels
    :return: None
    """
    step_counter = tf.train.get_or_create_global_step()

    def batch_loss():
        # The optimizer receives a zero-argument callable and evaluates the
        # loss itself.
        return loss_f(net, x, y)

    opt.minimize(batch_loss, global_step=step_counter)
if __name__ == '__main__':
    # Script entry point: build the model (optionally wrapped for several
    # GPUs), then run a train/validation loop that logs loss and accuracy
    # summaries through tf.contrib.summary.
    # TODO save/load model; add AUC; add early stop
    train_dataset, val_dataset = load_tf_records()

    # DEVICES_IDS is a comma-separated string; its entry count is used as the
    # number of GPUs.
    num_gpus = len(cfg["DEVICES_IDS"].split(","))
    if num_gpus > 1:  # use multi_GPUs
        # Instantiate the template model under the CPU device scope, then wrap
        # it with multi_gpu_model.
        # NOTE(review): it is reported that only GPU 0 shows utilisation with
        # this setup; whether multi_gpu_model parallelises under eager
        # execution in this TF version is unverified -- confirm.
        with tf.device('/cpu:0'):
            train_model = get_model(summary=True, data_format=cfg["DATA_FORMAT"])
        train_model = multi_gpu_model(train_model, gpus=num_gpus)
    else:  # use single GPU
        train_model = get_model(summary=True, data_format=cfg["DATA_FORMAT"])

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

    # use TensorBoard
    tb_path = os.path.join(cfg["ROOT"], cfg["TB_FOLDER"])
    if not os.path.exists(tb_path):
        os.makedirs(tb_path)
    writer = tf.contrib.summary.create_file_writer(tb_path)
    global_step = tf.train.get_or_create_global_step()

    # Batching does not depend on the epoch, so build the batched datasets once.
    train_batches = train_dataset.batch(cfg["BATCH_SIZE"])
    val_batches = val_dataset.batch(cfg["BATCH_SIZE"])

    with writer.as_default(), tf.contrib.summary.always_record_summaries():
        # Loop over the epochs
        for epoch in range(cfg["EPOCHS"]):
            # Fresh metric objects every epoch, so accuracy is per-epoch.
            train_acc = tfe.metrics.Accuracy(name="train_acc")
            val_acc = tfe.metrics.Accuracy(name="val_acc")

            for xb, yb in tfe.Iterator(train_batches):
                # Save the loss on disk
                tf.contrib.summary.scalar("train_loss", loss(train_model, xb, yb))
                # Make a training step
                train_step(loss, train_model, optimizer, xb, yb)
                # Compare predicted and true class indices (argmax over axis 1).
                train_acc(
                    tf.argmax(train_model(tf.constant(xb)), axis=1),
                    tf.argmax(tf.constant(yb), axis=1),
                )
                train_acc.result(write_summary=True)
                if (global_step.numpy() + 1) % 10 == 0:
                    break  # TODO need to remove

            for xb, yb in tfe.Iterator(val_batches):
                tf.contrib.summary.scalar("val_loss", loss(train_model, xb, yb))
                # Save the validation accuracy on the batch
                val_acc(
                    tf.argmax(train_model(tf.constant(xb)), axis=1),
                    tf.argmax(tf.constant(yb), axis=1),
                )
                val_acc.result(write_summary=True)
                break  # TODO need to remove

    # TODO: record the per-epoch accuracy and plot its evolution (matplotlib)
    # once the debug `break` statements above are removed.