尝试使用Tensorflow实施对象检测时遇到错误

时间:2019-12-01 06:14:57

标签: python tensorflow

我正在查看的项目是:https://github.com/mdietrichstein/tensorflow-open_nsfw

我拉出了代码,当我运行它时,出现以下错误(特定文件也在下面运行):

回溯(最近通话最近一次):

文件“ classify_nsfw.py”,位于

的第58行

main(sys.argv)

文件“ classify_nsfw.py”,第37行,位于主目录中     将tf.Session()设置为sess:

AttributeError:模块'tensorflow'没有属性'Session'

#!/usr/bin/env python
import sys
import argparse
import tensorflow as tf

from model import OpenNsfwModel, InputType
from image_utils import create_tensorflow_image_loader
from image_utils import create_yahoo_image_loader
import cv2
import time
import os

import numpy as np

def main(argv):
    parser = argparse.ArgumentParser()

    parser.add_argument("input_file", help="Path to the input image.\
                        Only jpeg images are supported.")

    parser.add_argument("-m", "--model_weights", required=True,
                        help="Path to trained model weights file")


    parser.add_argument("-i", "--input_type",
                        default=InputType.TENSOR.name.lower(),
                        help="input type",
                        choices=[InputType.TENSOR.name.lower(),
                                 InputType.BASE64_JPEG.name.lower()])

    args = parser.parse_args()

    model = OpenNsfwModel()

    with tf.Session() as sess:

        input_type = InputType[args.input_type.upper()]

        model.build(weights_path=args.model_weights, input_type=input_type)
        fn_load_image = None

        if input_type == InputType.TENSOR:
            fn_load_image = create_yahoo_image_loader()
        if input_type == InputType.BASE64_JPEG:
            import base64
            fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])

        sess.run(tf.global_variables_initializer())

        image = fn_load_image(args.input_file)
        predictions = sess.run(model.predictions, feed_dict={model.input: image})
        print("Results for '{}'".format(args.input_file))
        print("\tSFW score:\t{}\n\tNSFW score:\t{}".format(*predictions[0]))

if __name__ == "__main__":
    main(sys.argv)

1 个答案:

答案 0 :(得分:0)

原因是您使用的是旧的TensorFlow 1.X 语法。 tf.Session()已在TensorFlow 2.0.0 中弃用。

模块tf.compat.v1引入所有公共TensorFlow接口。如果仍要使用tf.Session(),请改用语法tf.compat.v1.Session()