'tuple'对象没有属性'gpu_fraction'

时间:2019-12-02 10:35:05

标签: python tensorflow

我现在正在使用colab复制我的显着性检测! 我是学生,所以请理解我的知识还不够, 我使用tensorflow找到了代码,因此我试图使用该代码来重现项目。 但是,作者说代码是在tensorflow 1.00上编写的,但是如果我只是从colab tensorflow起,我就不知道import tensorflow as tf的版本。 我收到错误

  

'tuple'对象没有属性'gpu_fraction'

  

模块'tensorflow'没有属性'GPUOptions'

这是我的源代码,请查看我的问题

import tensorflow as tf

import numpy as np

import os

from scipy import misc

import argparse

import sys


g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])

output_folder = "./test_output"

def rgba2rgb(img):
    return img[:,:,:3]*np.expand_dims(img[:,:,3],2)

def main(args):

    if not os.path.exists(output_folder):
        os.mkdir(output_folder) 

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_fraction)
    with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
        saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
        saver.restore(sess,tf.train.latest_checkpoint('./salience_model'))
        image_batch = tf.get_collection('image_batch')[0]
        pred_mattes = tf.get_collection('mask')[0]

        if args.rgb_folder:
            rgb_pths = os.listdir(args.rgb_folder)
            for rgb_pth in rgb_pths:
                rgb = misc.imread(os.path.join(args.rgb_folder,rgb_pth))
                if rgb.shape[2]==4:
                    rgb = rgba2rgb(rgb)
                origin_shape = rgb.shape
                rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)

                feed_dict = {image_batch:rgb}
                pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
                final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
                misc.imsave(os.path.join(output_folder,rgb_pth),final_alpha)

        else:
            rgb = misc.imread(args.rgb)
            if rgb.shape[2]==4:
                rgb = rgba2rgb(rgb)
            origin_shape = rgb.shape[:2]
            rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)

            feed_dict = {image_batch:rgb}
            pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
            final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
            misc.imsave(os.path.join(output_folder,'alpha.png'),final_alpha)

def parse_arguments(argv):
    parser = argparse.ArgumentParser()

    parser.add_argument('--rgb', type=str,
        help='input rgb',default = None)
    parser.add_argument('--rgb_folder', type=str,
        help='input rgb',default = None)
    parser.add_argument('--gpu_fraction', type=float,
        help='how much gpu is needed, usually 4G is enough',default = 1.0)
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))

1 个答案:

答案 0 :(得分:0)

您可以尝试摆脱显式的GPU配置,即打开您的

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

进入

    with tf.Session() as sess: