为什么自定义读取操作仅适用于test_session

时间:2017-10-18 13:10:26

标签: python tensorflow

我在tensorflow中编写了一个用于读取csv格式数据的自定义内核操作。

它在TestCase中正常工作,sess对象返回test_session()函数。

当我转向普通代码时,阅读器op每次都返回相同的结果。然后我在MyOp:Compute函数的开头放了一些调试打印。似乎在第一次运行后,sess.run(myop)从不调用MyOp:Compute函数。

然后我回到我的测试用例,如果我用tf.Session()而不是self.test_session()替换会话对象,它会以同样的方式失败。

有人对此有任何想法吗?

分享更多细节,这里是我的迷你演示代码:  https://github.com/littleDing/mini_csv_reader

在测试用例中:

def testSimple(self):
  input_data_schema, feas, batch_size = self.get_simple_format()
  iter_op = ops.csv_iter('./sample_data.txt', input_data_schema, feas, batch_size=batch_size, label='label2')
  with self.test_session() as sess:
    label,sign = sess.run(iter_op)
    print label

    self.assertAllEqual(label.shape, [batch_size])
    self.assertAllEqual(sign.shape, [batch_size, len(feas)])
    self.assertAllEqual(sum(label), 2)
    self.assertAllEqual(sign[0,:], [7,0,4,1,1,1,5,9,8])

    label,sign = sess.run(iter_op)
    self.assertAllEqual(label.shape, [batch_size])
    self.assertAllEqual(sign.shape, [batch_size, len(feas)])
    self.assertAllEqual(sum(label), 1)
    self.assertAllEqual(sign[0,:], [9,9,3,1,1,1,5,4,8])

正常通话:

def testing_tf():
    path = './sample_data.txt'
    input_data_schema, feas, batch_size = get_simple_format()
    with tf.device('/cpu:0'):
        n_data_op = tf.placeholder(dtype=tf.float32)
        iter_op = ops.csv_iter(path, input_data_schema, feas, batch_size=batch_size, label='label2') 
        init_op = [tf.global_variables_initializer(), tf.local_variables_initializer() ]

    with tf.Session() as sess:
      sess.run(init_op)
      n_data = 0
      for batch_idx in range(3):
        print '>>>>>>>>>>>>>> before run batch', batch_idx
        ## it should be some debug printing here, but nothing come out when batch_idx>0
        label,sign = sess.run(iter_op)
        print '>>>>>>>>>>>>>> after run batch', batch_idx
        ## the content of sign remain the same every time
        print sign
        if len(label) == 0:
          break

1 个答案:

答案 0 :(得分:1)

查看tf.test.TestCase.test_session()的{​​{3}}提供了一些线索,因为它配置的会话与直接调用tf.Session的方式略有不同。特别是test_session() implementation 常量折叠优化。默认情况下,TensorFlow会将图表的无状态部分转换为tf.constant()个节点,因为每次运行它们时都会生成相同的结果。

"CsvIter"操作注册中,有SetIsStateful()注释,因此TensorFlow会将其视为无状态,因此会受到不断折叠的影响。但是,它的实现是非常有状态的:通常,任何你希望用相同的输入张量生成不同结果的op,或者任何在成员变量中存储可变状态的op都应该标记为有状态。

解决方案是对REGISTER_OP的{​​{1}}进行单行更改:

"CsvIter"