sess.run()不运行?

时间:2018-05-03 02:54:06

标签: python tensorflow deep-learning

我是新来的,学习张量流并遇到问题。

import model_method
fittt(model_method.build(self,...),...parameters...)

以上是在main.py中导入model_method.py。 main.py中的函数 fittt

def fittt(model,...):
    model.fit(...)
model_method.py中的

build()

def build(self,...):
    self.op_C,self.op_A = self.function_A(...)
    self.op_B = self.function_B(self.op_C,...)
在model_method.py中

fit()

def fit(self,...):
    sess = tf.Session(graph=self.graph,config=config)
    BB,AA = sess.run([self.op_B,self.op_A],feed_dict)

为了检查运行过程,我在model_method.py中的 function_A() function_B()的开头添加了pdb.set_trace(),如下所示:

def function_A(self,...):
    pdb.set_trace()
    ......

def function_B(self,...):
    pdb.set_trace()
    ......

两个pdb.set_trace()仅在调用build()时停止,并且在调用sess.run([self.op_B,self.op_A],feed_dict)时不起作用。所以这意味着sess.run()实际上没有运行 function_A()和function_B()。我想知道为什么并想知道如何使这两个功能起作用?

1 个答案:

答案 0 :(得分:1)

通过调用model_method.build()函数,您可以创建计算图。在此调用中,每行代码都被执行(因此pdb停止的原因。)

但是,tf.Session.run(...)仅执行计算所获取值所需的计算图形部分(在您的示例中为self.op_Aself.op_B)。该函数不再执行整个build()函数。

因此,当您运行pdb.set_trace()sess.run(...)未执行的原因是因为它们不是有效的Tensor个对象,因此不属于计算图形。

<强>更新

请考虑以下事项:

class My_Model:

  def __init__(self):
      self.np_input = np.random.normal(size=(10,2)) # 10x2

  def build(self):
      self._in = tf.placeholder(dtype=tf.float32, shape=[10, None]) # matrix 10xN
      W_exception = tf.random_normal(dtype=tf.float32, shape=[3,3]) # matrix 3x3
      W_success = tf.random_normal(dtype=tf.float32, shape=[2,3]) # matrix 2x3
      self.op_exception = tf.matmul(self._in, W_exception) # [10x2] x [3x3] = ERROR
      self.op_success = tf.matmul(self._in, W_success) # [10x2] x [2x3] = [10x3]
      print('Computational Graph Built')

  def fit_success(self):
      with tf.Session() as sess:
          res = sess.run(self.op_success, feed_dict={self._in : self.np_input})
          print('Result shape: {}'.format(res.shape))

  def fit_exception(self):
      with tf.Session() as sess:
          res = sess.run(self.op_exception, feed_dict={self._in : self.np_input})
          print('Result shape: {}'.format(res.shape))

然后致电:

m = My_Model()
m.build()
#> Computational Graph Built

m.fit_success()
#> Result shape: (10, 3)

m.fit_exception()
#> InvalidArgumentError: Matrix size-incompatible: In[0]: [10,2], In[1]: [3,3]

所以要解释你在那里看到的东西。我们首先在build()函数中定义计算图。 _in是我们的输入张量; None表示维度1是动态确定的 - 即一旦我们提供具有指定值的张量。

然后我们定义了两个矩阵W_exceptionW_success,它们都指定了所有维度,并且会随机生成它们的值。

然后我们定义两个操作,矩阵乘法,每个操作返回一个张量。

我们调用了build()函数并创建了计算图,print()函数也被执行但未添加到图中。这里没有计算任何东西。事实上,它甚至不可能,因为_in的值没有指定。

现在要说明,只计算计算所需的必要部分,我们称之为fit_success()函数,它只是将输入张量_in乘以W_success张量(正确的尺寸) )。我们得到一个形状正确的张量:[10x3]。请注意,我们不会收到由于尺寸不匹配导致op_exception无法计算的错误。那是因为我们不需要它来评估op_success

最后,我只是表明当我们尝试使用相同的输入张量评估op_exception时,确实会抛出异常。