enter image description here我首先检测到图像的显着性,然后使用抓取算法来分割显着性目标。然而,结果是一个显着的图像,但没有分割显着图。错误如下:错误:-5图像糊状在函数抓取中有cv_8uc3类型,这是我的源代码,我该怎么办?“
import tensorflow as tf
import numpy as np
import os
from scipy import misc
import argparse
import sys,cv2
from skimage.io import imread, imsave
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
output_folder = "./test_output"
def rgba2rgb(img):
if img.ndim == 2:
img = gray2rgb(img)
elif img.shape[2] == 4:
img = img[:, :, :3]
upper_dim = max(img.shape[:2])
if upper_dim > args.max_dim:
img = rescale(img, args.max_dim/float(upper_dim), order=3)
return img
def largest_contours_rect(saliency):
contours, hierarchy = cv2.findContours(saliency * 3,cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)
contours = sorted(contours, key = cv2.contourArea)
return cv2.boundingRect(contours[-1])
def refine_saliency_with_grabcut(img, saliency):
rect = largest_contours_rect(saliency)
bgdmodel = np.zeros((1, 65),np.float64)
fgdmodel = np.zeros((1, 65),np.float64)
saliency[np.where(saliency > 0)] = cv2.GC_FGD
mask = saliency
cv2.grabCut(img, mask, rect, bgdmodel, fgdmodel, 1, cv2.GC_INIT_WITH_RECT)
mask = np.where((mask==2)|(mask==0),0,1).astype('uint8')
return mask
def backprojection_saliency(img,args):
saliency =main(args)
#cv2.imshow("original", saliency)
#saliency=mpimg.imread('alpha1.png')
img = cv2.resize(img, (320, 232))
mask = refine_saliency_with_grabcut(img, saliency)
#misc.imsave(os.path.join(output_folder,'flowers2.png'),result)
return mask
def main(args):
if not os.path.exists(output_folder):
os.mkdir(output_folder)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_fraction)
with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./salience_model'))
image_batch = tf.get_collection('image_batch')[0]
pred_mattes = tf.get_collection('mask')[0]
if args.rgb_folder:
rgb_pths = os.listdir(args.rgb_folder)
for rgb_pth in rgb_pths:
rgb = misc.imread(os.path.join(args.rgb_folder,rgb_pth))
if rgb.shape[2]==4:
rgb = rgba2rgb(rgb)
origin_shape = rgb.shape
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)
feed_dict = {image_batch:rgb}
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
misc.imsave(os.path.join(output_folder,rgb_pth),final_alpha)
else:
rgb = misc.imread(args.rgb)
if rgb.shape[2]==4:
rgb = rgba2rgb(rgb)
origin_shape = rgb.shape[:2]
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)
feed_dict = {image_batch:rgb}
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
misc.imsave(os.path.join(output_folder,'alpha.png'),final_alpha)
#rgbs = mpimg.imread('flower1.jpg')
result=refine_saliency_with_grabcut(rgb, final_alpha)
misc.imsave(os.path.join(output_folder,'segmentation.png'),result)
#cv2.imshow("original", final_alpha)
#plt.imshow(final_alpha)
return final_alpha;
def parse_arguments(argv):
parser = argparse.ArgumentParser()
parser.add_argument('--rgb', type=str,
help='input rgb',default = None)
parser.add_argument('--rgb_folder', type=str,
help='input rgb',default = None)
parser.add_argument('--gpu_fraction', type=float,
help='how much gpu is needed, usually 4G is enough',default = 1.0)
return parser.parse_args(argv)
if __name__ == '__main__':
main(parse_arguments(sys.argv[1:]))``