I have an image dataset of 4644 color images, which I split into 50 x 50 patches and feed to my deep neural network.
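Is a buffer size of 10000 in the shuffle operation sufficient before the patches are passed to the network, or is there a more effective way to shuffle the 369765 patches?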
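To put the buffer size in perspective, here is the arithmetic using the numbers from the question. It assumes patches are emitted image by image (which is what unbatch() after the per-image map produces), treats every image as yielding the average number of patches, and assumes 50 x 50 x 3 float32 patches as produced by read_and_decode.

```python
# Numbers from the question; "average patches per image" is an assumption,
# since individual images may yield different patch counts.
NUM_IMAGES = 4644
NUM_PATCHES = 369765
BUFFER_SIZE = 10000
PATCH_SIZE = 50


def shuffle_buffer_coverage(num_images, num_patches, buffer_size):
    """Return (avg patches per image, images' worth of patches in the buffer,
    fraction of the whole dataset held by the buffer)."""
    patches_per_image = num_patches / num_images
    images_in_buffer = buffer_size / patches_per_image
    buffer_fraction = buffer_size / num_patches
    return patches_per_image, images_in_buffer, buffer_fraction


ppi, imgs, frac = shuffle_buffer_coverage(NUM_IMAGES, NUM_PATCHES, BUFFER_SIZE)
# Memory held by the buffer: each patch is 50 x 50 x 3 float32 values (4 bytes each).
buffer_bytes = BUFFER_SIZE * PATCH_SIZE * PATCH_SIZE * 3 * 4

print(f"~{ppi:.1f} patches per image")
print(f"buffer holds ~{imgs:.0f} images' worth of patches")
print(f"buffer covers {frac:.1%} of the dataset")
print(f"buffer memory: {buffer_bytes / 1e6:.0f} MB")
```

So a 10000-element buffer holds roughly 126 images' worth of patches (about 2.7% of the data) while already occupying around 300 MB of float32 patches.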
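The steps I follow:
1. Create a single TFRecord file that stores all 4644 images.
2. Use a tf.data pipeline to decode each image and create patches from it.
3. Shuffle the patches with a buffer of 10000 and pass them to the network.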
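Step 3 relies on a fixed-size shuffle buffer. The sketch below is a pure-Python model of how tf.data.Dataset.shuffle is documented to behave (fill the buffer, then repeatedly emit a uniformly random element and refill its slot with the next input); it is not TensorFlow's implementation. The toy sizes (100 images x 80 patches, buffer 1000, batch 32) are hypothetical and chosen only to make the effect visible.

```python
import random


def buffer_shuffle(items, buffer_size, seed=None):
    """Model of a fixed-size shuffle buffer: fill it, then for every new input
    emit a random buffered element and put the new input in its slot."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


# Hypothetical toy data: patches tagged as (source_image, patch_index), in file order.
NUM_IMAGES, PATCHES_PER_IMAGE, BUFFER, BATCH = 100, 80, 1000, 32
patches = [(img, k) for img in range(NUM_IMAGES) for k in range(PATCHES_PER_IMAGE)]
shuffled = list(buffer_shuffle(patches, BUFFER, seed=0))
first_batch = shuffled[:BATCH]
print("source images in first batch:", sorted({img for img, _ in first_batch}))
```

The k-th output can only come from the first buffer_size + k - 1 inputs, so the first batch here draws only from images 0 to 12 out of 100. With the real data this means early batches are dominated by the first ~126 images in the TFRecord, which is why an additional image-level shuffle is often combined with the patch-level one.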
Here is the code I am using (buffer_size = 10000, num_parallel_calls = 4):
import tensorflow as tf

dataset = (tf.data.TFRecordDataset( tfrecords_filename_image )
           .repeat( no_epochs )
           .map( read_and_decode, num_parallel_calls=num_parallel_calls )
           .map( get_patches_fn, num_parallel_calls=num_parallel_calls )
           .apply( tf.data.experimental.unbatch() )  # unbatch the patches we just produced
           .shuffle( buffer_size=buffer_size, seed=random_number_1 )
           .batch( batch_size )
           .prefetch( 1 ))
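One commonly used alternative is to shuffle at the image level before patching, and to mix patches from several images with Dataset.interleave instead of map + unbatch. The sketch below is only a pure-Python model of the resulting element order, not tf.data code; cycle_length=16 and the toy sizes are arbitrary assumptions, and it assumes every image yields the same number of patches (with unequal counts, interleave refills exhausted slots individually and the order differs).

```python
import random


def record_shuffle_then_interleave(num_images, patches_per_image, cycle_length, seed=None):
    """Model of: shuffle the image records, then emit patches round-robin from
    `cycle_length` images at a time (interleave with block_length=1)."""
    rng = random.Random(seed)
    order = list(range(num_images))
    rng.shuffle(order)  # like a .shuffle(num_images) on the serialized records
    for start in range(0, num_images, cycle_length):
        group = order[start:start + cycle_length]
        for k in range(patches_per_image):  # one patch from each open image in turn
            for img in group:
                yield (img, k)


NUM_IMAGES, PATCHES_PER_IMAGE, CYCLE = 100, 80, 16
stream = list(record_shuffle_then_interleave(NUM_IMAGES, PATCHES_PER_IMAGE, CYCLE, seed=0))
first_32 = stream[:32]
print("distinct images in first 32 patches:", len({img for img, _ in first_32}))
```

In the pipeline above this would roughly mean adding a .shuffle() on the records right after TFRecordDataset (the buffer then holds serialized records rather than decoded patches; a full shuffle needs room for all 4644) and using something like .interleave(lambda rec: tf.data.Dataset.from_tensor_slices(get_patches_fn(read_and_decode(rec))), cycle_length=16) in place of the two map calls and the unbatch. The existing patch-level shuffle can stay as a final mixing step.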
get_patches_fn and get_patches definitions:
get_patches_fn = lambda image: get_patches( image, patch_size=patch_size )

def get_patches( image, patch_size=16 ):
    # Function to compute patches for a given image
    # Input:  image      - image which has to be converted to patches
    #         patch_size - size of each patch
    # Output: patches of the image (4-D tensor)
    # with tf.device('/cpu:0'):
    pad = [ [ 0, 0 ], [ 0, 0 ] ]
    patches_image = tf.space_to_batch_nd( [ image ], [ patch_size, patch_size ], pad )
    patches_image = tf.split( patches_image, patch_size * patch_size, 0 )
    patches_image = tf.stack( patches_image, 3 )
    patches_image = tf.reshape( patches_image, [ -1, patch_size, patch_size, 3 ] )
    return patches_image
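As a sanity check of what get_patches computes, here is a pure-Python sketch of non-overlapping patch extraction in row-major order. It is written to match my reading of the space_to_batch_nd / split / stack / reshape sequence above, not taken from TensorFlow. Because the paddings are all zero, tf.space_to_batch_nd requires height and width to be multiples of patch_size, and the sketch raises in that case too; each image then yields (H // patch_size) * (W // patch_size) patches.

```python
def extract_patches(image, patch_size):
    """Split an H x W image (list of rows of pixels) into non-overlapping
    patch_size x patch_size patches, ordered row by row."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        # Zero padding means both dimensions must divide evenly.
        raise ValueError(f"{height}x{width} is not divisible by {patch_size}")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([row[left:left + patch_size]
                            for row in image[top:top + patch_size]])
    return patches


# Toy 4 x 6 "image" whose pixels are their own (row, col) coordinates.
image = [[(r, c) for c in range(6)] for r in range(4)]
patches = extract_patches(image, 2)
print(len(patches))  # (4 // 2) * (6 // 2) patches
print(patches[1])    # second patch: rows 0-1, columns 2-3
```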
read and decode function definition:
def read_and_decode( tf_record_file ):
    # Function to read the TFRecord example and return an image suitable for patching
    # Input:  tf_record_file - serialized tf record from which the image is extracted
    # Output: image
    features = {
        'height': tf.FixedLenFeature( [ ], tf.int64 ),
        'width': tf.FixedLenFeature( [ ], tf.int64 ),
        'image_raw': tf.FixedLenFeature( [ ], tf.string )
    }
    parsed = tf.parse_single_example( tf_record_file, features )
    image = tf.decode_raw( parsed[ 'image_raw' ], tf.uint8 )
    height = tf.cast( parsed[ 'height' ], tf.int32 )
    width = tf.cast( parsed[ 'width' ], tf.int32 )
    image_shape = tf.stack( [ height, width, -1 ] )
    image = tf.reshape( image, image_shape )
    image = image[ :, :, :3 ]  # keep only the first three (RGB) channels
    image = tf.cast( image, tf.float32 )
    return image
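The decoding step infers the channel count from the byte length (the -1 in the reshape) and then keeps the first three channels, e.g. dropping alpha from RGBA. A pure-Python sketch of that logic, under the assumption that image_raw holds interleaved uint8 pixels in row-major order, is below. Unlike the TensorFlow code, it raises if fewer than 3 channels are present; in the original, image[:, :, :3] on a grayscale image would simply yield a single channel, which would not fit the reshape to [-1, patch_size, patch_size, 3] in get_patches.

```python
def decode_to_rgb(image_raw, height, width):
    """Interpret raw uint8 bytes as height x width x C (C inferred from the
    byte count), keep the first three channels and cast to float."""
    num_pixels = height * width
    if len(image_raw) % num_pixels:
        raise ValueError("byte count is not a multiple of height * width")
    channels = len(image_raw) // num_pixels
    if channels < 3:
        raise ValueError(f"expected at least 3 channels, got {channels}")
    rows = []
    for r in range(height):
        row = []
        for c in range(width):
            start = (r * width + c) * channels
            # Slice the first three channel bytes and cast, like tf.cast(..., tf.float32).
            row.append(tuple(float(v) for v in image_raw[start:start + 3]))
        rows.append(row)
    return rows


# Hypothetical 1 x 2 RGBA image: alpha (255) is dropped.
rgba = bytes([10, 20, 30, 255, 40, 50, 60, 255])
print(decode_to_rgb(rgba, height=1, width=2))
```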
Answer: