使用抓取算法分离显着区域

时间:2017-11-27 02:32:05

标签: python opencv

enter image description here我首先检测到图像的显着性,然后使用抓取算法来分割显着性目标。然而,结果是一个显着的图像,但没有分割显着图。错误如下:错误:-5图像糊状在函数抓取中有cv_8uc3类型,这是我的源代码,我该怎么办?“

    import tensorflow as tf
    import numpy as np
    import os
    from scipy import misc
    import argparse
    import sys,cv2
    from skimage.io import imread, imsave
    import matplotlib.pyplot as plt 
    import matplotlib.image as mpimg 

g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
output_folder = "./test_output"

def rgba2rgb(img):
         if img.ndim == 2:
            img = gray2rgb(img)
         elif img.shape[2] == 4:
            img = img[:, :, :3]
         upper_dim = max(img.shape[:2])
         if upper_dim > args.max_dim:
            img = rescale(img, args.max_dim/float(upper_dim), order=3)
     return img

def largest_contours_rect(saliency):
    contours, hierarchy = cv2.findContours(saliency * 3,cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key = cv2.contourArea)
    return cv2.boundingRect(contours[-1])

def refine_saliency_with_grabcut(img, saliency):
    rect = largest_contours_rect(saliency)
    bgdmodel = np.zeros((1, 65),np.float64)
    fgdmodel = np.zeros((1, 65),np.float64)
    saliency[np.where(saliency > 0)] = cv2.GC_FGD
    mask = saliency
    cv2.grabCut(img, mask, rect, bgdmodel, fgdmodel, 1, cv2.GC_INIT_WITH_RECT)
    mask = np.where((mask==2)|(mask==0),0,1).astype('uint8')
    return mask

def backprojection_saliency(img,args):  
        saliency =main(args)
        #cv2.imshow("original", saliency)
        #saliency=mpimg.imread('alpha1.png')
        img = cv2.resize(img, (320, 232))
    mask = refine_saliency_with_grabcut(img, saliency)
        #misc.imsave(os.path.join(output_folder,'flowers2.png'),result)
    return mask

def main(args):

    if not os.path.exists(output_folder):
        os.mkdir(output_folder) 

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_fraction)
        with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
        saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
        saver.restore(sess,tf.train.latest_checkpoint('./salience_model'))
        image_batch = tf.get_collection('image_batch')[0]
        pred_mattes = tf.get_collection('mask')[0]

        if args.rgb_folder:
            rgb_pths = os.listdir(args.rgb_folder)
            for rgb_pth in rgb_pths:
                rgb = misc.imread(os.path.join(args.rgb_folder,rgb_pth))
                if rgb.shape[2]==4:
                    rgb = rgba2rgb(rgb)
                origin_shape = rgb.shape
                rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)

                feed_dict = {image_batch:rgb}
                pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
                final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
                misc.imsave(os.path.join(output_folder,rgb_pth),final_alpha)
        else:
            rgb = misc.imread(args.rgb)
            if rgb.shape[2]==4:
                rgb = rgba2rgb(rgb)
            origin_shape = rgb.shape[:2]
            rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)

            feed_dict = {image_batch:rgb}
            pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
            final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
            misc.imsave(os.path.join(output_folder,'alpha.png'),final_alpha)
                        #rgbs = mpimg.imread('flower1.jpg')
                        result=refine_saliency_with_grabcut(rgb, final_alpha)
                        misc.imsave(os.path.join(output_folder,'segmentation.png'),result)
                        #cv2.imshow("original", final_alpha)
                        #plt.imshow(final_alpha)                       
        return final_alpha;

def parse_arguments(argv):
    parser = argparse.ArgumentParser()

    parser.add_argument('--rgb', type=str,
        help='input rgb',default = None)
    parser.add_argument('--rgb_folder', type=str,
        help='input rgb',default = None)
    parser.add_argument('--gpu_fraction', type=float,
        help='how much gpu is needed, usually 4G is enough',default = 1.0)
    return parser.parse_args(argv)


if __name__ == '__main__':
        main(parse_arguments(sys.argv[1:]))``

0 个答案:

没有答案