Python - 如何强制使用工厂方法来实例化对象?

时间:2013-01-01 06:24:42

标签: python class-factory

我有一组相关的类,它们都从一个基类继承。我想使用工厂方法来实例化这些类的对象。我想这样做,因为在将对象返回给调用者之前,我可以将对象存储在由类名称键入的字典中。然后,如果有对特定类的对象的请求,我可以检查我的字典中是否已存在。如果没有,我将实例化它并将其添加到字典中。如果是这样,那么我将从字典中返回现有对象。这将基本上将我模块中的所有类转换为单例。

我想这样做是因为它们都继承的基类会对子类中的函数进行一些自动包装,而且我不希望函数被多次包装,这就是当前发生的事情。创建了同一个类的两个对象。

我能想到这样做的唯一方法是检查基类的__init__()方法中的堆栈跟踪,它总是被调用,如果堆栈跟踪没有显示请求,则抛出异常使对象来自工厂功能。

这是个好主意吗?

编辑:以下是我的基类的源代码。我被告知我需要弄清楚元类以更优雅地完成这个,但这就是我现在所拥有的。所有Page对象都使用相同的Selenium Webdriver实例,该实例位于顶部导入的驱动程序模块中。初始化此驱动程序非常昂贵 - 它是在第一次创建LoginPage时初始化的。初始化后,initialize()方法将返回现有驱动程序,而不是创建新驱动程序。想法是用户必须首先创建一个LoginPage。最终将定义数十个Page类,单元测试代码将使用它们来验证网站的行为是否正确。

from driver import get_driver, urlpath, initialize
from settings import urlpaths

class DriverPageMismatchException(Exception):
    pass

class URLVerifyingPage(object):
    # we add logic in __init__() to check the expected urlpath for the page
    # against the urlpath that the driver is showing - we only want the page's
    # methods to be invokable if the driver is actualy at the appropriate page.
    # If the driver shows a different urlpath than the page is supposed to
    # have, the method should throw a DriverPageMismatchException

    def __init__(self):
        self.driver = get_driver()
        self._adjust_methods(self.__class__)

    def _adjust_methods(self, cls):
        for attr, val in cls.__dict__.iteritems():
            if callable(val) and not attr.startswith("_"):
                print "adjusting:"+str(attr)+" - "+str(val)
                setattr(
                    cls,
                    attr,
                    self._add_wrapper_to_confirm_page_matches_driver(val)
                )
        for base in cls.__bases__:
            if base.__name__ == 'URLVerifyingPage': break
            self._adjust_methods(base)

    def _add_wrapper_to_confirm_page_matches_driver(self, page_method):
        def _wrapper(self, *args, **kwargs):
            if urlpath() != urlpaths[self.__class__.__name__]:
                raise DriverPageMismatchException(
                    "path is '"+urlpath()+
                    "' but '"+urlpaths[self.__class.__name__]+"' expected "+
                    "for "+self.__class.__name__
                )
            return page_method(self, *args, **kwargs)
        return _wrapper


class LoginPage(URLVerifyingPage):
    def __init__(self, username=username, password=password, baseurl="http://example.com/"):
        self.username = username
        self.password = password
        self.driver = initialize(baseurl)
        super(LoginPage, self).__init__()

    def login(self):
        driver.find_element_by_id("username").clear()
        driver.find_element_by_id("username").send_keys(self.username)
        driver.find_element_by_id("password").clear()
        driver.find_element_by_id("password").send_keys(self.password)
        driver.find_element_by_id("login_button").click()
        return HomePage()

class HomePage(URLVerifyingPage):
    def some_method(self):
        ...
        return SomePage()

    def many_more_methods(self):
        ...
        return ManyMorePages()

如果一个页面被实例化了几次也没什么大不了的 - 这些方法只会被包裹几次,并且会发生一些不必要的检查,但一切都会有效。但如果一个页面被实例化数十次,数百次或数万次,那就太糟糕了。我可以在每个页面的类定义中放置一个标志,并检查方法是否已经被包装,但我喜欢保持类定义纯净和干净的想法,并将所有hocus-pocus推到一个深的角落我的系统,没有人能看到它,它只是工作。

4 个答案:

答案 0 :(得分:3)

在Python中,几乎不值得尝试“强制”任何东西。无论你想出什么,有人可以通过monkeypatching你的类,复制和编辑源代码,使用字节码等等来绕过它。

所以,只需编写你的工厂,并记录这是获得你的类实例的正确方法,并期望任何使用你的类编写代码的人理解TOOWTDI,除非她真的知道她在做什么,否则不要违反它。愿意弄清楚并处理后果。

如果你只是想防止意外,而不是故意“滥用”,那就是另一回事了。事实上,它只是标准的合同设计:检查不变量。当然,此时SillyBaseClass已经搞砸了,修复它已经太晚了,你所能做的就是assertraiselog或其他什么其他是合适的。但这就是你想要的:它是应用程序中的逻辑错误,唯一要做的就是让程序员修复它,所以assert可能就是你想要的。

所以:

class SillyBaseClass:
    singletons = {}

class Foo(SillyBaseClass):
    def __init__(self):
        assert self.__class__ not in SillyBaseClass.singletons

def get_foo():
    if Foo not in SillyBaseClass.singletons:
        SillyBaseClass.singletons[Foo] = Foo()
    return SillyBaseClass.singletons[Foo]

如果你真的想阻止事情发展,你可以<{>> 更早地检查__new__方法中的不变量,但除非“SillyBaseClass搞砸了“相当于”发射核武器“,为什么要这么麻烦?

答案 1 :(得分:2)

听起来你想提供一个__new__实现:像:

class MySingledtonBase(object):
    instance_cache = {}
    def __new__(cls, arg1, arg2):
        if cls in MySingletonBase.instance_cache:
            return MySingletonBase.instance_cache[cls]
        self = super(MySingletonBase, cls).__new__(arg1, arg2)
        MySingletonBase.instance_cache[cls] = self
        return self

答案 2 :(得分:2)

我不是在运行时添加复杂的代码来捕获错误,而是首先尝试使用约定来引导模块的用户自己做正确的事情。

为您的班级提供“私人”名称(以下划线为前缀),给他们提供不应该实例化的名称(例如_Internal ...),并使您的工厂功能“公开”。

就是这样:

class _InternalSubClassOne(_BaseClass):
    ...

class _InternalSubClassTwo(_BaseClass):
    ...

# An example factory function.
def new_object(arg):
    return _InternalSubClassOne() if arg == 'one' else _InternalSubClassTwo()

我还会为每个类添加文档字符串或注释,例如“不要手动实例化此类,使用工厂方法new_object。”

答案 3 :(得分:0)

您也可以在工厂方法中嵌套类,如下所述: https://python-3-patterns-idioms-test.readthedocs.io/en/latest/Factory.html#preventing-direct-creation

来自上述来源的工作示例:

# Factory/shapefact1/NestedShapeFactory.py
import random

class Shape(object):
    types = []

def factory(type):
    class Circle(Shape):
        def draw(self): print("Circle.draw")
        def erase(self): print("Circle.erase")

    class Square(Shape):
        def draw(self): print("Square.draw")
        def erase(self): print("Square.erase")

    if type == "Circle": return Circle()
    if type == "Square": return Square()
    assert 0, "Bad shape creation: " + type

def shapeNameGen(n):
    for i in range(n):
        yield factory(random.choice(["Circle", "Square"]))

# Circle() # Not defined

for shape in shapeNameGen(7):
    shape.draw()
    shape.erase()

我不喜欢这种解决方案,只是想将其添加为另一个选项。