有人可以帮我解决这个python代码。我是Python的初学者,但对java很好。
def train_nn(iterations, batch_size, use_tf_mnist=False):
def perPartition(it):
if not use_tf_mnist:
train_data = RowData(it)
test_data = train_data
else:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
train_data = mnist.train
test_data = mnist.train
return create_nn(train_data, test_data, iterations, batch_size)
return perPartition
来电者只拨打train_nn(ITERATIONS, BATCH_SIZE, USE_TF_MNIST)
。那么内部函数从哪里得到它的参数?
答案 0 :(得分:2)
train_nn
返回perPartition
函数。当调用者稍后调用函数时它会得到它的参数:
trained = train_nn(100, 50, False)
nn = trained(some_it)