我正在研究变压器纸,在HarvardNLP(链接:http://nlp.seas.harvard.edu/2018/04/03/attention.html)中找到了一个代码实现。
我知道,当我们想更改已在函数外部初始化的变量时,应将其初始化为函数内部的全局变量。
x = 5
def foo():
global x
x = x * 2
print(x)
foo()
但是下面的代码中存在全局初始化,这使我感到困惑。
'''code link: http://nlp.seas.harvard.edu/2018/04/03/attention.html'''
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
"Keep augmenting batch and calculate total number of tokens + padding."
global max_src_in_batch, max_tgt_in_batch
if count == 1:
max_src_in_batch = 0
max_tgt_in_batch = 0
max_src_in_batch = max(max_src_in_batch, len(new.src))
max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
src_elements = count * max_src_in_batch
tgt_elements = count * max_tgt_in_batch
return max(src_elements, tgt_elements)
我希望函数进行全局初始化,让变量保留在我们的内存中。任何人都可以清楚地解释一下,如果我不使用函数外部的全局初始化怎么办?
这是另一个使用该功能的代码。
class MyIterator(data.Iterator):
def create_batches(self):
if self.train:
def pool(d, random_shuffler):
for p in data.batch(d, self.batch_size * 100):
p_batch = data.batch(
sorted(p, key=self.sort_key),
self.batch_size, self.batch_size_fn)
for b in random_shuffler(list(p_batch)):
yield b
self.batches = pool(self.data(), self.random_shuffler)
else:
self.batches = []
for b in data.batch(self.data(), self.batch_size,
self.batch_size_fn):
self.batches.append(sorted(b, key=self.sort_key))
答案 0 :(得分:0)
这是因为,当您不在函数/子例程/类解释器中时,此特定变量引用的是全局变量,而不是局部变量(因为您可以在本地和全局使用相同的变量名)。因此,在内部还需要声明要调用全局变量。
这与内存无关,而是与参考有关...
例如,检查以下代码:
>>> global x
>>> x
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'x' is not defined
>>> def foo():
... x=1
...
>>> foo()
>>> x
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'x' is not defined
>>> def foo2():
... global x
... x=1
...
>>> x
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name 'x' is not defined
>>> foo2()
>>> x
1