从tf会话中获取变量

时间:2018-01-18 15:34:45

标签: python tensorflow

我已经介绍了这段代码片段,这是一个使用梯度下降的非常简单的线性回归模型。

让我感到困惑的是最后一行final_slope , final_intercept = sess.run([m,b]),这是从会话中获取变量的最佳方式,而不是再次运行会话吗?

我希望了解这句话如何在引擎盖下运作

我的代码:

import tensorflow as tf
import numpy as np

x_data = np.linspace(0,10,10) + np.random.uniform(-1.5,1.5,10)
y_label = np.linspace(0,10,10) + np.random.uniform(-1.5,1.5,10)

m = tf.Variable(0.29220241)
b = tf.Variable(0.84038402)

error = 0

for x,y in zip(x_data,y_label):

    y_hat = m*x + b

    error += (y-y_hat)**2 

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train = optimizer.minimize(error)

init = tf.global_variables_initializer()


with tf.Session() as sess:

    sess.run(init)

    epochs = 1

    for i in range(epochs):

        sess.run(train)


    # Fetch Back Results
    final_slope , final_intercept = sess.run([m,b])

根据文件

   a = tf.constant([10, 20])
   b = tf.constant([1.0, 2.0])
   # 'fetches' can be a singleton
   v = session.run(a)
   # v is the numpy array [10, 20]
   # 'fetches' can be a list.
   v = session.run([a, b])
   # v a Python list with 2 numpy arrays: the numpy array [10, 20] and the
   # 1-D array [1.0, 2.0]
   # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
   MyData = collections.namedtuple('MyData', ['a', 'b'])
   v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
   # v is a dict with
   # v['k1'] is a MyData namedtuple with 'a' the numpy array [10, 20] and
   # 'b' the numpy array [1.0, 2.0]
   # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
   # [10, 20].

仍然无法获取再次运行会话以提取变量的有意义信息。我想了解如果每次运行会话是获取变量的最佳情况还是有更好更好的另一种方式

1 个答案:

答案 0 :(得分:2)

通过使用“运行会话”这一短语来说明这里的误解。会话未运行'。会话'运行'一个东西。思考过程是,在会话中,计算图的某些部分运行,由您要求的特定节点确定。因此,当您执行session.run([y_hat])时,我们的tensorflow会话(基本上只是完全可以进行任何计算的必要条件)'运行'我们计算图的部分是计算张量y_hat所必需的。在你的情况下,y_hat需要抓取几个变量的值,并进行一些张量乘法和加法。

如果你想要图表中的其他张量,你可以“运行”。那些也是。有时,某些张量是在通往其他人的途中计算出来的。例如当我们计算(y-y_hat)**2时,会计算y_hat。我们可以session.run([y_hat, (y-y_hat)**2])y_hat(afaik)只计算一次,而不必为每个执行整个计算图。

这里的关键见解是在运行之间不存储张量。因此,如果您致电session.run([y_hat]) ; session.run([(y-y_hat)**2]),那么所有导致y_hat的计算都必须执行两次。

希望这有帮助。