在Python中,能否在其他函数中使用main函数里的变量——需要使用全局变量吗?

时间:2015-08-28 15:23:34

标签: python machine-learning theano

在Python中,能否在其他函数中使用main函数里创建的变量?是否需要使用全局变量?感谢任何帮助!

def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the input projection and the recurrent layer (question's version).

    Only ``n_h`` is used below; the other parameters are unused in this snippet.
    Nothing is returned.
    """
    # NOTE(review): `input_to_state` is a plain local variable here. It is
    # discarded when main() returns, so predict() cannot reach it by name.
    # NOTE(review): `seq_u` is not defined in this snippet -- presumably a
    # module-level array defined elsewhere; confirm.
    input_to_state = Linear(name='input_to_state',
                            input_dim=seq_u.shape[-1],
                            output_dim=n_h)
    # `global RNN` makes the assignment below rebind the module-level name
    # RNN, so other functions can read it after main() has run.
    global RNN # correct?
    RNN = SimpleRecurrent(activation=Tanh(),
                          dim=n_h, name="RNN")


def predict(dev_X):
    """Question's version: each pair of lines shows the alternatives being asked about."""
    # `main.input_to_state` raises AttributeError: a function object does not
    # expose its local variables as attributes.
    dev_transform = main.input_to_state.apply(dev_X) #?  call  "input_to_state", which one is correct?
    # A bare `input_to_state` raises NameError here, because main() above binds
    # it only as a local (no `global input_to_state`), unless some other
    # module-level input_to_state exists.
    dev_transform = input_to_state.apply(dev_X) #?
    # `main.RNN` fails for the same reason as `main.input_to_state`.
    dev_h = main.RNN.apply(dev_transform) #? call "RNN", which one is correct?
    # A bare `RNN` works once main() has run, because main() declares `global RNN`.
    dev_h = RNN.apply(dev_transform) #?
    # NOTE(review): dev_h is computed but never returned, so the result is lost.

if __name__ == "__main__":
    # NOTE(review): a function may be defined inside this block, but this `def`
    # has no body, so the file does not parse as written (IndentationError).
    # It would also shadow the module-level predict() defined above.
    def predict(dev_X): #  one more question: can predict function be added here?
    # `....` is a placeholder for loading the data; it is not valid Python.
    dataset =  ....
    # NOTE(review): n_h, n_y, batch_size, dev_split and dev_X are not defined
    # in this snippet -- they must be created before these lines run.
    main(dataset, n_h, n_y, batch_size, dev_split, 5000)
    # NOTE(review): theano.function normally takes symbolic output
    # expression(s) as its second argument (e.g. the result of predict(dev_X)),
    # not a Python function object -- confirm against the Theano docs.
    get_predictions = theano.function([dev_X], predict) # call predict function

2 个答案:

答案 0 :(得分:0)

你需要先在'main'函数之外(模块级别)定义'input_to_state'和'RNN',然后在'main'中用`global`声明它们,再对其重新赋值。像这样:

# Module-level placeholders; main() rebinds both names.
input_to_state = None
RNN = None


def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Build the model layers and publish them as module-level globals.

    Only ``n_h`` is used here, as the layers' output / hidden size. The other
    parameters are accepted but unused in this snippet. Returns ``None``.

    Side effects:
        Rebinds the module-level names ``input_to_state`` and ``RNN``, so
        other functions in this module can use them after main() has run.
    """
    # `global` is a declaration, not a call: it makes the assignments below
    # rebind the module-level names instead of creating new local variables.
    global input_to_state, RNN
    # NOTE(review): seq_u is not defined in this snippet -- presumably a
    # module-level array defined elsewhere; confirm.
    input_to_state = Linear(name='input_to_state',
                            input_dim=seq_u.shape[-1],
                            output_dim=n_h)
    RNN = SimpleRecurrent(activation=Tanh(),
                          dim=n_h, name="RNN")


def predict(dev_X):
    """Run ``dev_X`` through the globally stored layers and return the result.

    Relies on main() having been called first. Until then the module-level
    ``input_to_state`` and ``RNN`` are still ``None``, and ``.apply`` fails
    with AttributeError.

    Returns:
        The output of ``RNN.apply`` on the projected input.
    """
    dev_transform = input_to_state.apply(dev_X)
    dev_h = RNN.apply(dev_transform)
    # Hand the hidden states back to the caller; without this the computed
    # result would be thrown away.
    return dev_h

if __name__ == "__main__":
    # NOTE(review): `args` and `dev_X` are placeholders not defined in this
    # snippet. main() takes six parameters, so `main(args)` fails with a
    # TypeError unless it is replaced by the real arguments (e.g. `main(*args)`
    # when args is a sequence of all six) -- TODO confirm.
    main(args)
    # main() must run first so that predict() finds the layers main() stored
    # in the module-level globals.
    predict(dev_X)

不过,我不建议这样做,全局变量应该尽量少用。更多细节请参见原文中的链接(more detail here)。

更好的解决方案是在main函数的末尾返回'input_to_state'和'RNN',如下所示:

def main(dataset, n_h, n_y, batch_size, dev_split, n_epochs):
    """Create the model layers and hand them back to the caller.

    No globals are touched. The caller receives both layers and passes them
    on explicitly. Only ``n_h`` is consulted; the remaining parameters are
    accepted but unused in this snippet.

    Returns:
        tuple: ``(input_to_state, RNN)`` -- the ``Linear`` projection followed
        by the ``SimpleRecurrent`` layer.
    """
    # NOTE(review): seq_u comes from outside this function (not shown here);
    # presumably a module-level array -- confirm.
    projection = Linear(
        name='input_to_state',
        input_dim=seq_u.shape[-1],
        output_dim=n_h,
    )
    recurrent = SimpleRecurrent(activation=Tanh(), dim=n_h, name="RNN")
    return projection, recurrent

def predict(dev_X, input_to_state, RNN):
    """Apply the projection and recurrent layers to ``dev_X`` and return the result.

    Args:
        dev_X: input handed to ``input_to_state.apply``.
        input_to_state: the projection layer returned by main().
        RNN: the recurrent layer returned by main().

    Returns:
        The output of ``RNN.apply`` on the projected input. (The earlier
        version computed this value and then discarded it.)
    """
    dev_transform = input_to_state.apply(dev_X)
    dev_h = RNN.apply(dev_transform)
    return dev_h

if __name__ == "__main__":
    # NOTE(review): `args` and `dev_X` are placeholders not defined in this
    # snippet. main() takes six parameters, so `main(args)` needs the real
    # arguments (e.g. `main(*args)` if args holds all six) -- TODO confirm.
    input_to_state, RNN = main(args)
    # The layers are passed explicitly instead of being shared through globals.
    predict(dev_X, input_to_state, RNN)

答案 1 :(得分:-1)

试试这个。

**main.py**

# Keep the result of main(...) as a module-level name so that other modules
# can read it through sys.modules['__main__'].
# NOTE(review): names of the form __xxx__ are reserved for Python's own use
# by convention; a plain name would be safer. Also, the question's main() has
# no return statement, so this would be None -- confirm main() returns a value.
__dataset__ = main(dataset, n_h, n_y, batch_size, dev_split, 5000)

**sub.py**

# Read __dataset__ from whichever module is running as the top-level script.
# NOTE(review): `main` is imported but not used below. If main.py is the
# script being run, `import main` loads main.py a second time as a separate
# module named `main`, re-executing its top-level code (including the
# main(...) call) -- consider dropping it.
import sys, main
# NOTE(review): this only works if the running script assigned __dataset__
# before this module was imported; otherwise it raises AttributeError.
__dataset__ = sys.modules['__main__'].__dataset__


编辑:
另一种方法是使用带有类属性(作用类似静态变量)的类。

**mclass.py**

class MClass:
    """Holder for state shared between modules through a class attribute.

    Every module that does ``from mclass import MClass`` receives this same
    class object. A value assigned to ``MClass.i`` in one module is therefore
    seen by all the others.
    """

    i = 0  # default; replaced immediately below when the module is imported


# Override the default at import time, so importers first observe 1.
setattr(MClass, "i", 1)

**main.py**

# Demo: MClass.i is a single shared value. Changes made here or through the
# functions in sub.py are visible from both modules.
import sub
from mclass import MClass

# In the main file
print(MClass.i) # Outputs 1 (mclass.py sets MClass.i = 1 when it is imported)
MClass.i = 3
print(MClass.i) # Outputs 3

# Through functions defined in sub.py, which use the same MClass object
sub.mPrint() # Outputs 3
sub.set(10)
sub.mPrint() # Outputs 10

# And back in the main
print(MClass.i) # Outputs 10

**sub.py**

from mclass import MClass

def mPrint():
    """Print the current shared value ``MClass.i``."""
    current_value = MClass.i
    print(current_value)

def set(n):
    """Replace the shared value ``MClass.i`` with *n*.

    Note: this name shadows the built-in ``set`` inside this module. It is
    kept as-is because callers use ``sub.set(...)``.
    """
    setattr(MClass, "i", n)