正如您在下面的代码中看到的那样,我正在尝试使用Tensorflow数据集在Tensorflow上训练一个简单的模型。数据集非常庞大,我需要进行填充,重复和批处理,以便为训练我的模型做一个随机的梯度下降。
但是我可以观察到优化步骤的周期开销(在我的代码中是sess.run(train))。
正如你在这里看到的,每5个步骤,它需要3s而不是0.5来进行优化。
步骤105持续时间:3.5233473777770996
步骤106持续时间:0.5653283596038818
步骤107持续时间:0.5391891002655029
步骤108持续时间:0.5480048656463623
步骤109持续时间:0.0415492057800293
步骤110持续时间:3.032115936279297
步骤111持续时间:0.5407207012176514
步骤112持续时间:0.5276811122894287
步骤113持续时间:0.5448746681213379
步骤114持续时间:0.04253268241882324
步骤115持续时间:3.1273345947265625
此外,我的GPU几乎一直处于0%的利用率,大约有90%的内存使用。
当Iterator完成查看所有数据集时,似乎这个开销就到了。
我在Ubuntu 16.04上使用带有Tensorflow 1.4的Python 3.6。
你知道如何加快训练速度吗?
最佳,
import tensorflow as tf
import numpy as np
import os, time, multiprocessing
import matplotlib.pyplot as plt
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value.reshape(-1)))
def parser(record):
num_features = 2000
size_group = 300
num_classes= 10
class_indice = 0
keys_to_features={
'X': tf.FixedLenFeature([size_group*num_features],tf.float32),
'label' : tf.FixedLenFeature([num_classes],tf.float32)}
parsed = tf.parse_single_example(record, keys_to_features)
label = parsed['label']
label = tf.slice(label,[class_indice],[1])
label = tf.squeeze(label) # To get a vector one dimension
X = parsed['X']
X= tf.reshape(X, [size_group,num_features])
return X, label
def test_train_w_dataset():
# Definition of the size
num_features = 2000
num_ex = 2000
size_group = 300
num_classes = 10
batch_size= 480
max_iters = 300
buffer_size = 10000
# Creation of the Dataset
filename_tfrecords = 'tmp.tfrecords'
if not(os.path.isfile(filename_tfrecords)): # If the file doesn't exist we will create it
print("Start creating the Dataset")
writer = tf.python_io.TFRecordWriter(filename_tfrecords)
for i in range(num_ex):
if i % 1000 == 0: print("Step :",i)
X = np.random.normal(size=(size_group,num_features))
vectors = 2*np.random.randint(0,2,(num_classes,1))-1
features=tf.train.Features(feature={
'X': _floats_feature(X),
'label' : _floats_feature(vectors)})
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
writer.close()
else:
print("The dataset tfrecords already exist")
train_dataset = tf.data.TFRecordDataset(filename_tfrecords)
num_proc = multiprocessing.cpu_count()
train_dataset = train_dataset.map(parser,
num_parallel_calls=num_proc)
dataset_shuffle = train_dataset.shuffle(buffer_size=buffer_size,
reshuffle_each_iteration=True)
dataset_shuffle = dataset_shuffle.batch(batch_size)
dataset_shuffle = dataset_shuffle.repeat()
dataset_shuffle = dataset_shuffle.prefetch(batch_size)
shuffle_iterator = dataset_shuffle.make_initializable_iterator()
X_, y_ = shuffle_iterator.get_next()
W=tf.Variable(tf.random_normal([num_features], stddev=1.),name="weights")
W=tf.reshape(W,(1,1,num_features))
Prod=tf.reduce_sum(tf.multiply(W,X_),axis=2)
Max=tf.reduce_max(Prod,axis=1)
Tan= tf.reduce_sum(tf.multiply(tf.tanh(Max),y_))
loss= tf.add(Tan,tf.reduce_sum(tf.multiply(W,W)))
LR = 0.01
restarts = 1
optimizer = tf.train.GradientDescentOptimizer(LR)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
train = optimizer.minimize(loss)
print("The graph is defined")
sess = tf.Session(config=config)
durationTab = []
for essai in range(restarts+1):
# To do need to reinitialiszed
t0 = time.time()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(shuffle_iterator.initializer)
t1 = time.time()
duration = t1 - t0
print('Duration of initialization : ',duration)
for step in range(max_iters):
t0 = time.time()
sess.run(train)
t1 = time.time()
duration = t1 - t0
print("Step ",str(step),' duration : ',duration)
durationTab += [duration]
plt.plot(durationTab)
plt.ylabel('Duration')
plt.xlabel('Iteration')
plt.show()
if __name__ == '__main__':
test_train_w_dataset()
答案 0 :(得分:0)
对于GPU利用率,请确保使用gpu优化二进制文件。检查操作位置(例如,在tensorboard中)。强制在gpu上放置操作(参见tf.device)。
对于周期性尖峰,可能有以下几个原因:
由于很多原因与RAM有关,你应该尝试一个较小的模型(较小的批次,较少的层,较少的节点/层),看看它是否消失。如果确实如此,那么你需要出去买更多的RAM。
答案 1 :(得分:0)
似乎在批处理和重复功能之间添加dataset_shuffle = dataset_shuffle.cache()会消除这些周期性开销。不过,我不确定使用此命令是否完全读取了数据集。