我正在导入一个带有一些连接到数据库的函数的模块。我像这样调用该模块中的函数:
main.py:
import data_source as DS
DS.get_data(type_of_data)
DS.get_some_other_data(type_of_data)
data_source.py:
def get_data(type_of_data):
# connect to db and return a dictionary
def get_some_other_data(type_of_data):
# connect to db and return a dictionary
每个函数都建立了数据库连接。我正在尝试减少代码重复,所以我可以有一个连接到DB的函数,如下所示:
data_source.py
中的:
def connect_to_db():
# connect to bd and return connection to caller
但我必须从每个connect_to_db()
拨打function
。有没有办法确保是否调用导入模块中的函数另一个函数默认运行?所以我不必在connect_to_db
内的每个函数中都有data_source.py
作为顶部?所以每次通话都没有连接?
我知道这不会影响我的目的,但我只是好奇。我知道我可以在第一次调用之后将连接传递给调用者,然后对于后续调用,我可以将连接传递给导入模块中的函数,但这是我想要避免的事情。
就像当你python main.py
if __name__ == "main":
下的所有内容都运行时,可能就像
if __function_is_being_called__:
?
答案 0 :(得分:2)
您可以使用decorator首先拨打connect_to_db
,然后正常调用该功能:
import functools
def with_db_connection(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
connect_to_db()
return f(*args, **kwargs)
return wrapper
@with_db_connection
def get_data(type_of_data):
...
您还可以通过数据库参数来避免全局变量:
import functools
def with_db_connection(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
db = connect_to_db()
return f(db, *args, **kwargs)
return wrapper
@with_db_connection
def get_data(db, type_of_data):
...
get_data('string') # db argument is passed in through the decorator
答案 1 :(得分:2)
我不知道你如何定义data_source.py
,但只要你立即执行为所有其他功能设置舞台的功能,它的全局变量就足够了,例如:
# data_source.py
connection_status = 0 # initial status is set to 0
def connect_to_db():
global connection_status
connection_status = 1 # this can be your dependent variable
connect_to_db() # call it immediately, this will execute on first import
def get_data():
print("Current connection status: {}".format(connection_status))
def get_some_other_data():
print("Current connection status: {}".format(connection_status))
现在,如果你将其导入任何地方:
import data_source as DS
# even at this point DS.connection_status is already 1
DS.get_data() # Current connection status: 1
DS.get_some_other_data() # Current connection status: 1
这是一个使用内存中SQLite实例的简单设置:
# data_source.py
import sqlite3
connection = None # holds connection to the database
def connect_to_db():
global connection
connection = sqlite3.connect(":memory:") # create an in-memory DB
cursor = connection.cursor() # create a local cursor
# let's create a simple squares table from 1 to 100
cursor.execute("CREATE TABLE squares (num INTEGER PRIMARY KEY, square INTEGER)")
for i in range(1, 101): # fill the table
cursor.execute("INSERT INTO squares (num, square) VALUES ({}, {})".format(i, i**2))
connect_to_db()
def get_square(num):
cursor = connection.cursor() # create a local cursor
cursor.execute("SELECT square FROM squares WHERE num={}".format(num))
data = cursor.fetchone()
return data[0] if data else None
def get_square_root(num):
cursor = connection.cursor() # create a local cursor
cursor.execute("SELECT num FROM squares WHERE square={}".format(num))
data = cursor.fetchone()
return data[0] if data else None
当你想要使用它时:
import data_source as DS
print("43 squared: {}".format(DS.get_square(43))) # 43 squared: 1849
print("sqrt of 4489: {}".format(DS.get_square_root(4489))) # sqrt of 4489: 67