张量随时间推移缓慢的图像分类推理速度

时间:2018-07-12 12:23:09

标签: python tensorflow image-processing machine-learning computer-vision

我正在为自己的用例重新训练tensorflow image classifier。我必须通过分类器传递大量图像(超过500k),我已经相应地修改了推理脚本,以便图形文件仅加载一次。

最初的推理速度非常高(10张图像/秒),但随着时间的推移,推理速度会逐渐降低并达到大约1张图像/秒。时间在图像加载部分和前进通道处也增加。我一直都有足够的CPU和GPU

用于调用图像读取功能并进行推断的代码段

with tf.Session(graph=graph) as sess:
    while True:
        a = []
        try:
            files = os.listdir(folder_name)

            for f in files:
                try:
                    #CALLING IMAGE READ FUNCTION 
                    t = sess.run(read_tensor_from_image_file(
                        (folder_name+"/"+f),
                        input_height=input_height,
                        input_width=input_width,
                        input_mean=input_mean,
                        input_std=input_std))

                    #GETTING INFERENCE
                    results = sess.run(output_operation.outputs[0], {
                        input_operation.outputs[0]: t
                    })



                    results = np.squeeze(results)

                    top_k = results.argsort()[-5:][::-1]

读取图像的代码

def read_tensor_from_image_file(file_name, input_height=299, input_width=299, input_mean=0, input_std=255):
    input_name = "file_reader"
    output_name = "normalized"
    file_reader = tf.read_file(file_name, input_name)
    if file_name.endswith(".png"):
        image_reader = tf.image.decode_png(file_reader, channels=3, name="png_reader")
    elif file_name.endswith(".gif"):
        image_reader = tf.squeeze(tf.image.decode_gif(file_reader, name="gif_reader"))
    elif file_name.endswith(".bmp"):
        image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
    else:
        image_reader = tf.image.decode_jpeg(file_reader, channels=3, name="jpeg_reader")
    float_caster = tf.cast(image_reader, tf.float32)
    dims_expander = tf.expand_dims(float_caster, 0)
    resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
    normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])


    return normalized

0 个答案:

没有答案