由于我必须在脚本中多次运行同一部分代码,所以我想知道是否可以将tf.Session
作为参数传递给用户定义的函数,以便在函数内运行某些操作,从而避免了重复的代码。
更新 使用示例
with tf.Session as sess:
my_training(init, epochs, Xtrain, Ytrain, batch,\
optimizer, loss, predictions, X, Y, sess)
validation_error = sess.run(loss, feed_dict={X:Xvalid, Y:Yvalid})
其中优化器,损失,输入,输出,Xtrain和Ytrain等是先前创建的张量,而my_training
是我创建的用于在单独的脚本中训练NN的函数。
def my_training(init, epochs, Xtrain, Ytrain, batch,\
optimizer, loss, predictions, X, Y, sess):
# Shuffle the training set
Xtrain, Ytrain = shuffle(Xtrain, Ytrain, random_state = 20)
# Initialize all variables (network weights, biases and optimizer)
sess.run(init)
# Loop over the training epochs
for epoch in range(epochs):
# Loop over mini-batches in the epoch
for offset in range(0, Xtrain.shape[0], batch):
batch_x = Xtrain[offset: offset + batch]
batch_y = Ytrain[offset: offset + batch]
sess.run(optimizer, feed_dict = {X:batch_x, Y:batch_y})