如何注射"内存中的SQLAlchemy sqlite3数据库进入Flask test_client?

时间:2018-06-13 04:08:46

标签: python flask sqlite sqlalchemy

我有一个Flask应用程序,我正在使用SQLAlchemy。我不想使用Flask-SQLAlchemy扩展。如果我使用"真实"用于测试的数据库,只需设置指向数据库测试实例的环境变量,一切运行良好。

但是,如果我想指向内存中的sqlite数据库进行测试,我就遇到了问题。在这种情况下,我可以在我的测试中设置数据,但是当我使用test_client在我的应用程序中执行给定路由时,它无法在服务器端找到我的数据库表"原样 - 即在hello.py代码中进一步向下显示。必须有一种方法来配置test_client才能使其工作,但我似乎无法弄清楚如何去做。

以下是可能相关的代码片段(db.py):

import os

from sqlalchemy import create_engine

from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker

engine = create_engine(os.environ['SQLALCHEMY_URL'])

Session = scoped_session(sessionmaker(bind=engine))

这里我正在设置scoped_session,以便我的数据库访问将是线程本地的。

引导我的应用程序的代码(__init __。py):

from flask import Flask

from .db import Session

from .hello import hello_blueprint

app = Flask(__name__)
app.register_blueprint(hello_blueprint)

@app.teardown_appcontext
def cleanup(resp_or_exc):
    Session.remove()

每次Flask弹出应用程序上下文时,我都会设置我的应用程序并注册清理回调。

蓝图中的示例路由(hello.py):

import json

from flask import Blueprint

from .db import Session

from .models import Message

hello_blueprint = Blueprint('hello', __name__)

@hello_blueprint.route('/messages')
def messages():
    values = Session.query(Message).all()

    results = []
    for value in values:
        results.append({ 'message': value.message })

    return (json.dumps(results), 200, { 'content_type': 'application/json' })

我在这里使用作用域会话从数据库中获取一些数据。

我的Message模型的定义只是vanilla SQLAlchemy(models.py):

from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy import Column, Integer, String

Base = declarative_base()

class Message(Base):
    __tablename__ = 'messages'
    id = Column(Integer, primary_key=True)
    message = Column(String)

    def __repr__(self):
        return "<Message(message='%s')>" % (self.message)

下面是一个非常原始的pytest单元测试,只是为了在一个地方展示问题(test_hello.py):

import os 

import json

import pytest

import app

from .models import Message

@pytest.fixture
def client():
    client = app.app.test_client()

    return client

def test_hello(client):
    response = client.get('/')
    data = json.loads(response.data.decode('utf-8'))
    assert data == { 'message': "Hello friend!" }

def test_messages(client):
    with app.app.app_context():
        from sqlalchemy import create_engine
        from sqlalchemy import MetaData
        engine = create_engine('sqlite://')

        from .models import Base
        Base.metadata.create_all(engine)

        print('***metadata tables***')
        print(Base.metadata.tables.keys())

        from sqlalchemy.orm import scoped_session
        from sqlalchemy.orm import sessionmaker

        Session = scoped_session(sessionmaker(bind=engine))

        message = Message(message='Hello there!')

        Session.add(message)
        Session.commit()

        values = Session.query(Message).all()

        results = []
        for value in values:
            results.append({ 'message': value.message })

        # This works, prints : [{"message": "Hello there!"}]
        print('*** result***')    
        print(json.dumps(results))

        # The code below doesn't work. Flask's app.py throws an exception
        # with the following at its root:
        # sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) no such table: 
        # messages [SQL: 'SELECT messages.id AS messages_id, messages.message AS messages_message, messages.new_field 
        # AS messages_new_field \nFROM messages'] (Background on this error at: http://sqlalche.me/e/e3q8)

        response = client.get('/messages')

        data =json.loads(response.data.decode('utf-8'))
        assert data == [{'message': 'Hello there!'}]

1 个答案:

答案 0 :(得分:0)

我以前不太确定我做错了什么,但我设法让我的考试成功。

单元测试的修改版本如下(test_hello.py):

import os 

import json

import pytest

import app

from .db import engine
from .db import Session

from .models import Base
from .models import Message

@pytest.fixture
def client():
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    client = app.app.test_client()

    return client

def test_messages(client):
    message = Message(message='Hello there!')

    Session.add(message)
    Session.commit()

    response = client.get('/messages')

    data = json.loads(response.data.decode('utf-8'))
    assert data == [{'message': 'Hello there!'}]