我正在创建一个项目,我需要创建一个类实例,该实例具有连接到db并从中获取数据的方法(我使用SQLite作为后端)。
我有一些关于flask-sqlalchemy的经验,但是当谈到纯粹的SQLAlchemy时,我迷失了方向。
概念如下:
用户创建DataSet
的实例,并将路径作为__init__
参数传递给数据库。如果数据库已经存在,我只想连接它并进行查询,如果没有,我想用模型创建一个新的。但我无法理解如何这样做。
以下是DataSet
代码:
from os.path import normcase, split, join, isfile
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import errors
import trainset
import testset
class DataSet:
def __init__(self, path_to_set, path_to_db, train_set=False, path_to_labels=None, label_dict=None,
custom_name=None):
self.__path_to_set = path_to_set
self.__label_dict = label_dict
if custom_name is None:
dbpath = join(path_to_db, 'train.db')
if train_set is False:
dbpath = join(path_to_db, 'test.db')
else:
dbpath = join(path_to_db, custom_name)
if isfile(dbpath):
self.__prepopulated = True
else:
self.__prepopulated = False
self.__dbpath = dbpath
if train_set is True and path_to_labels is None:
raise errors.InsufficientData('labels', 'specified')
if train_set is True and not isfile(path_to_labels):
raise errors.InsufficientData('labels', 'found at specified path', path_to_labels)
def prepopulate(self):
engine = create_engine('sqlite:////' + self.__dbpath)
self.__prepopulated = True
以下是trainset
代码:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, PickleType, Integer, MetaData
Base = declarative_base()
metadata = MetaData()
class TrainSet(Base):
__tablename__ = 'train set'
id = Column(Integer, primary_key=True)
real_id = Column(String(60))
path = Column(String(120))
labels = Column(PickleType)
features = Column(PickleType)
以下是testset
代码:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, PickleType, Integer, MetaData
Base = declarative_base()
metadata = MetaData()
class TestSet(Base):
__tablename__ = 'test set'
id = Column(Integer, primary_key=True)
real_id = Column(String(60))
path = Column(String(120))
features = Column(PickleType)
因此,如果用户在创建train_set=True
实例时传递DataSet
,我想使用TrainSet
模型创建数据库,否则创建一个TestSet
数据库。我希望在prepopulate
方法中发生这种情况,但是,我不明白该怎么做 - 文档要求这样做:Base.metadata.create_all(engine)
,但我很遗憾在哪里放这个代码。
答案 0 :(得分:2)
首先保存参数train_set
:
class DataSet:
def __init__(self, path_to_set, path_to_db, train_set=False, path_to_labels=None, label_dict=None,
custom_name=None):
self._train_set = train_set
# ...
然后,在prepopulate
中使用它来创建正确的模型:
def prepopulate(self):
engine = create_engine('sqlite:////' + self.__dbpath)
if self._train_set:
trainset.Base.create_all(engine)
else:
testset.Base.create_all(engine)
self.__prepopulated = True
还有一件事:不要在“私有”变量前加上双下划线。请阅读PEP 8 -- Style Guide for Python Code
以供参考。