将会话作为参数传递给Tensorflow中的用户功能

时间:2019-05-29 13:42:43

标签: python tensorflow neural-network

由于我必须在脚本中多次运行同一部分代码,所以我想知道是否可以将tf.Session作为参数传递给用户定义的函数,以便在函数内运行某些操作,从而避免了重复的代码。

更新 使用示例

with tf.Session as sess: 
        my_training(init, epochs, Xtrain, Ytrain, batch,\
                    optimizer, loss, predictions, X, Y, sess)
        validation_error = sess.run(loss, feed_dict={X:Xvalid, Y:Yvalid})

其中优化器,损失,输入,输出,Xtrain和Ytrain等是先前创建的张量,而my_training是我创建的用于在单独的脚本中训练NN的函数。

def my_training(init, epochs, Xtrain, Ytrain, batch,\

             optimizer, loss, predictions, X, Y, sess):   

    # Shuffle the training set 
    Xtrain, Ytrain = shuffle(Xtrain, Ytrain, random_state = 20)

    # Initialize all variables (network weights, biases and optimizer)
    sess.run(init)

    # Loop over the training epochs
    for epoch in range(epochs):

        # Loop over mini-batches in the epoch
        for offset in range(0, Xtrain.shape[0], batch):

            batch_x = Xtrain[offset: offset + batch]

            batch_y = Ytrain[offset: offset + batch]

            sess.run(optimizer, feed_dict = {X:batch_x, Y:batch_y})

0 个答案:

没有答案