我在我的GPU实例中一直使用Tensorflow版本1.12,我大约有130个TfRecords文件,其中包含120万张ImageNet数据。首先,我应用了地图函数,然后应用flat_map
来扩展数据集,最终将得到120万x 2048张图像。
self.filenames = tf.placeholder(tf.string, shape=[None])
self.eval_filenames = tf.placeholder(tf.string, shape = [None])
dataset = tf.data.TFRecordDataset(self.filenames)
eval_dataset = tf.data.TFRecordDataset(self.eval_filenames)
print("inside dataset ", dataset.output_shapes)
dataset = dataset.map(self.decode, num_parallel_calls=10)
dataset = dataset.flat_map(self.apply_flip_crop)
dataset = dataset.batch(self.config["batch_size"])
dataset = dataset.prefetch(2)
iterator = dataset.make_initializable_iterator()
此处解码函数返回图像的展平数组和一键编码的标签。但是,在flat_map
中传递的函数做了很重的工作,就像:两个循环来创建切片并将它们反转以产生1024个张量。单个图像的最终输出将是[2048, 224, 224, 3]
张量。该函数如下所示:
def apply_flip_crop(self, tf_example, lable):
"""
Calls a helper function random_crop flips which randomly crops and flips
the images, and returns the agumented tensors.
Parameters
----------
:param tf_example: A tensor of shape [batchsize, flattedimageshape]
:type tf_example: Tensors [batchsize, flattedimageshape]
:param lable: A Constant integer representing the class_id of the image.
:type lable: tf.int32
:return: Tensors of shape [flattedimageshape], label of image tf.int32
:rtype: Tensors
"""
data = tf.reshape(tf_example, [256, 256, 3])
data = self.random_crop_flip(data)
lables = [lable for i in range(2048)]
return tf.data.Dataset.from_tensor_slices((data, lables))
def random_crop_flip(self, image):
"""
Apply random crop and random flip to the image tensor.
Parameters
----------
:param image: A tensor representing a flattened image array.
:type image: Tensor of shape [imageflattenedarray]
:return: List of 2048 tensors of shape [imageflattenedarray]
:rtype: List
"""
crops = []
for i in range(256 - 224):
for j in range(256 - 224):
crop = tf.slice(image, [i, j, 0], [224, 224, 3])
crop2 = tf.reverse(crop, axis=[1])
crops.append(crop)
crops.append(crop2)
return crops
现在的问题是训练过程非常缓慢。我已经读过dataset.from_tensor_slices
对于这种需要非常不好。但是我认为有很多地方可以改进。为此,我需要可视化每个操作的性能。主要是flat_map
函数。
我正在使用这样的张量流的运行时统计:
sess.run(iterator.initializer, feed_dict={data_gen.filenames:
training_filenames},
options=run_options, run_metadata=run_metadata)
next_element = iterator.get_next()
for i in range(1):
datapoint = sess.run(next_element, options=run_options,
run_metadata=run_metadata)
summary_writer.add_run_metadata(run_metadata, 'step%d' % i)
哪个记录了准备数据集所花费的时间,但是没有记录执行flat_map
操作所花费的时间,我怀疑这是我关心的,那是性能所在的地方滞后的。
感谢您在性能建议以及 flat_map函数所用时间的度量方面的帮助。
先谢谢了。