我正在用Python编写“批处理”过程(不使用任何框架)。
项目配置位于config.ini
文件中
[db]
db_uri = mysql+pymysql://root:password@localhost:3306/manage
我还有另一个文件config.test
在测试过程中要交换
[db]
db_uri = sqlite://
我有一个简单的test_sample.py
# tests/test_sample.py
import pytest
import shutil
import os
import batch
import batch_manage.utils.getconfig as getconfig_class
class TestClass():
def setup_method(self, method):
""" Rename the config """
shutil.copyfile("config.ini", "config.bak")
os.remove('config.ini')
shutil.copyfile("config.test", "config.ini")
def teardown_method(self, method):
""" Replace the config """
shutil.copyfile("config.bak", "config.ini")
os.remove('config.bak')
def test_can_get_all_data_from_table(self):
conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')
assert db_uri == "sqlite://"
# This pass! ok!
people = batch.get_all_people()
assert len(people) == 0
# This fails, because counts the records in production database
db_uri
肯定可以(在测试时间是sqlite而不是mysql),但是len不是0,而是42(MySql数据库中的记录数。)
我怀疑SqlAlchemy ORM会话存在问题。我做了几次尝试,都无法覆盖/删除它。
其余代码非常简单:
# batch_manage/models/base.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import batch_manage.utils.getconfig as getconfig_class
conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')
engine = create_engine(db_uri)
Session = sessionmaker(bind=engine)
Base = declarative_base()
# batch_manage/models/persone.py
from sqlalchemy import Column, String, Integer, Date
from batch_manage.models.base import Base
class Persone(Base):
__tablename__ = "persone"
idpersona = Column(Integer, primary_key=True)
nome = Column(String)
created_at = Column(Date)
def __init__(self, nome, created_at):
self.nome = nome
self.created_at = created_at
还有batch.py
本身
# batch.py
import click
from batch_manage.models.base import Session
from batch_manage.models.persone import Persone
def get_all_people():
""" Get all people from database """
session = Session()
people = session.query(Persone).all()
return people
@click.command()
def batch():
click.echo("------------------------------")
click.echo("Running Batch")
click.echo("------------------------------")
people = get_all_people()
for item in people:
print(f"Persona con ID {item.idpersona} creata il {item.created_at}")
if __name__ == '__main__':
batch()
我暂时通过以下方法进行了测试:
def test_can_get_all_data_from_table(self):
conf = getconfig_class.get_config('db')
db_uri = conf.get('db_uri')
assert db_uri == "sqlite://"
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
engine = create_engine(db_uri)
Session = sessionmaker(bind=engine)
session = Session()
people = batch.get_all_people(session)
assert len(people) == 0
和get_all_people
方法和
def get_all_people(session = None):
""" Get all people from database """
if session is None:
session = Session()
people = session.query(Persone).all()
return people
但是这种解决方案不是很好,并且会降低代码覆盖率(如果没有遵循if路径的话)。
答案 0 :(得分:1)
因此,如果我正确地遵循您的代码,则可能是您在设置测试之前导入了ORM内容。这是您当前的操作顺序:
batch.py
已导入。models/base.py
文件的顶级模块代码中,配置要使用的数据库。因此,解决方案:
如果只是想更改操作顺序,请在测试之前不要导入代码。无论如何,这通常是很好的测试实践:
class TestClass():
... (your existing code) ...
def test_can_get_all_data_from_table(self):
# ONLY import stuff inside your test
from batch_manage.models.base import Session
from batch_manage.models.persone import Persone
这可能会解决您眼前的问题,但可能有一个更优雅的解决方案
我不知道您是否正在使用Flask,但是无论哪种方式,the Flask testing documentation都具有一些有关如何设置测试数据库的良好说明。导入模块后,您需要配置数据库URL。
例如:
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
(your models)
请注意,我尚未定义引擎。我可以在运行时 。
def setup_engine():
engine = create_engine(db_url)
Base.metadata.bind = engine
在您的主要代码中,向用户提供内容之前,请致电setup_engine
。在测试环境中,您将调用自己的setup_engine
来绑定到测试环境。