在Python中,我可以在其他函数中使用main函数里定义的变量吗?需要用全局变量吗?感谢任何帮助!
def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    # Builds the input projection (input_to_state) and the recurrent layer (RNN).
    # NOTE(review): seq_u is neither a parameter nor a local of main(); it must
    # be a module-level name defined elsewhere, otherwise this raises NameError.
    input_to_state = Linear(name='input_to_state',
                            input_dim=seq_u.shape[-1],
                            output_dim=n_h)
    # `global RNN` makes the assignment below rebind the module-level name RNN,
    # so other functions can read RNN after main() has run. input_to_state has
    # no matching `global` declaration, so it remains local to main().
    global RNN # correct?
    RNN = SimpleRecurrent(activation=Tanh(),
                          dim=n_h, name="RNN")
def predict(dev_X):
    # The asker lists two alternative forms for each lookup below.
    # NOTE(review): `main.input_to_state` looks up an attribute on the function
    # object `main`. A function's local variables are not exposed as attributes,
    # so this form raises AttributeError.
    dev_transform = main.input_to_state.apply(dev_X) #? call "input_to_state", which one is correct?
    # A bare name is resolved in module globals. main() as written does not
    # declare input_to_state global, so this raises NameError unless a
    # module-level input_to_state exists.
    dev_transform = input_to_state.apply(dev_X) #?
    # Same problem as main.input_to_state: this raises AttributeError.
    dev_h = main.RNN.apply(dev_transform) #? call "RNN", which one is correct?
    # This form works once main() has run, because main() declares `global RNN`.
    dev_h = RNN.apply(dev_transform) #?
    # NOTE(review): dev_h is computed but never returned, so callers receive None.
if __name__ == "__main__":
    # A `def` is allowed inside an `if` block. The function is then defined
    # only when this file runs as a script, not when it is imported.
    # NOTE(review): as shown, this def has no indented body (it was elided),
    # which is a SyntaxError until a body is supplied.
    def predict(dev_X): # one more question: can predict function be added here?
    # Placeholder: `....` is not valid Python; the real dataset loading goes here.
    dataset = ....
    # NOTE(review): n_h, n_y, batch_size and dev_split must already be defined
    # at module level for this call to work.
    main(dataset, n_h, n_y, batch_size, dev_split, 5000)
    # NOTE(review): this passes the Python function object `predict` as the
    # outputs argument. theano.function is expected to receive symbolic output
    # variable(s), e.g. predict(dev_X) once predict returns its result; confirm
    # against the Theano docs. dev_X must also be defined as a symbolic input.
    get_predictions = theano.function([dev_X], predict) # call predict function
答案 0(得分:0)
你需要先在'main'函数之外(模块级别)定义'input_to_state'和'RNN',然后在'main'中用global声明它们,再修改它们。像这样:
# Module-level placeholders. main() rebinds both names through `global`
# declarations, and predict() reads them afterwards.
input_to_state = RNN = None
def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the two layers and publish them as module-level globals.

    Only n_h is used in this snippet; the other parameters are part of the
    original signature.
    """
    # Without this declaration, the assignments below would create locals and
    # the module-level names would be left untouched.
    global input_to_state, RNN
    # NOTE(review): seq_u is neither a parameter nor a local; it must exist at
    # module scope, otherwise this raises NameError.
    input_to_state = Linear(
        name='input_to_state',
        input_dim=seq_u.shape[-1],
        output_dim=n_h,
    )
    RNN = SimpleRecurrent(activation=Tanh(), dim=n_h, name="RNN")
def predict(dev_X):
    """Run dev_X through the module-level layers that main() stores in globals.

    Returns:
        The output of RNN.apply applied to input_to_state.apply(dev_X).

    Raises:
        RuntimeError: if main() has not populated input_to_state and RNN yet.
    """
    # Fail with a clear message instead of "'NoneType' has no attribute 'apply'"
    # when the globals still hold their module-level None placeholders.
    if input_to_state is None or RNN is None:
        raise RuntimeError("call main() before predict()")
    dev_transform = input_to_state.apply(dev_X)
    dev_h = RNN.apply(dev_transform)
    # Previously dev_h was computed and then discarded.
    return dev_h
if __name__ == "__main__":
    # `args` and `dev_X` are placeholders carried over from the answer.
    # main() takes six positional parameters, so unpack a 6-item sequence
    # instead of passing it as one argument.
    main(*args)
    dev_h = predict(dev_X)
不过,我不建议这样做,全局变量应该尽量少用(更多细节见此处)。
更好的解决方案是在main函数的末尾返回'input_to_state'和'RNN',再把它们作为参数传给predict,如下所示:
def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the two layers and hand them back to the caller.

    Returning the layers (instead of storing them in globals) lets the caller
    pass them explicitly to predict().

    Returns:
        A tuple (input_to_state, RNN).
    """
    # NOTE(review): seq_u is neither a parameter nor a local; it must exist at
    # module scope, otherwise this raises NameError.
    projection = Linear(
        name='input_to_state',
        input_dim=seq_u.shape[-1],
        output_dim=n_h,
    )
    recurrent = SimpleRecurrent(activation=Tanh(), dim=n_h, name="RNN")
    return projection, recurrent
def predict(dev_X, input_to_state, RNN):
    """Apply the input projection and then the recurrent layer to dev_X.

    Args:
        dev_X: input passed to input_to_state.apply.
        input_to_state: object with an apply() method (the Linear layer
            returned by main()).
        RNN: object with an apply() method (the SimpleRecurrent layer
            returned by main()).

    Returns:
        RNN.apply(input_to_state.apply(dev_X)).
    """
    dev_transform = input_to_state.apply(dev_X)
    dev_h = RNN.apply(dev_transform)
    # Previously dev_h was computed and then discarded.
    return dev_h
if __name__ == "__main__":
    # `args` and `dev_X` are placeholders carried over from the answer.
    # main() takes six positional parameters, so unpack a 6-item sequence
    # instead of passing it as one argument.
    input_to_state, RNN = main(*args)
    dev_h = predict(dev_X, input_to_state, RNN)
答案 1(得分:-1)
试试这个。
**main.py**
# Store main()'s return value under a module-level name, so that another
# module can fetch it through sys.modules['__main__'] (see sub.py below).
# NOTE(review): names of the form __x__ are reserved for Python's special
# names (PEP 8); a plain name would be safer. The main() shown in the question
# has no return statement, so this would be None unless main() returns a value.
__dataset__ = main(dataset, n_h, n_y, batch_size, dev_split, 5000)
**sub.py**
import sys
# NOTE(review): `main` is not used below. If main.py is the script being run,
# it is registered as '__main__', and this import executes main.py a second
# time as a separate module named 'main'. Consider removing it.
import main

# Fetch the value that the entry script (main.py when run directly) stored at
# module level. This raises AttributeError if that script has not assigned
# __dataset__ before this module is imported.
__dataset__ = getattr(sys.modules['__main__'], '__dataset__')
**编辑**
另一种方法是使用带有类属性(相当于其他语言中的静态变量)的类。
**mclass.py**
class MClass:
    """Holder for a value shared by every module that imports this class.

    ``i`` is a class attribute, so reads and writes through ``MClass.i`` from
    any module operate on the same value.
    """

    i = 0


# Rebind at import time, so importers observe 1 as the starting value.
setattr(MClass, "i", 1)
**main.py**
import sub
from mclass import MClass
# In the main file
# MClass.i is a class attribute; mclass.py sets it to 1 when it is imported.
print(MClass.i) # Outputs 1
# Rebinding the class attribute is visible to every module that uses MClass.
MClass.i = 3
print(MClass.i) # Outputs 3
# In a subfile
# sub.mPrint() reads the same MClass.i, so it sees the 3 assigned above.
sub.mPrint() # Outputs 3
# sub.set() rebinds MClass.i from inside sub.py.
sub.set(10)
sub.mPrint() # Outputs 10
# And back in the main
# The change made inside sub.py is visible here because both modules share MClass.
print(MClass.i) # Outputs 10
**sub.py**
from mclass import MClass
def mPrint():
    """Print the current value of the shared class attribute MClass.i."""
    current = MClass.i
    print(current)
# NOTE(review): this name shadows the builtin set() inside this module. It is
# kept because main.py calls sub.set().
def set(n):
    """Store n in the shared class attribute MClass.i."""
    setattr(MClass, "i", n)