我在tensorflow中编写了一个用于读取csv格式数据的自定义内核操作。
它在TestCase中正常工作,sess
对象返回test_session()
函数。
当我转向普通代码时,阅读器op每次都返回相同的结果。然后我在MyOp:Compute
函数的开头放了一些调试打印。似乎在第一次运行后,sess.run(myop)
从不调用MyOp:Compute
函数。
然后我回到我的测试用例,如果我用tf.Session()
而不是self.test_session()
替换会话对象,它会以同样的方式失败。
有人对此有任何想法吗?
分享更多细节,这里是我的迷你演示代码: https://github.com/littleDing/mini_csv_reader
在测试用例中:
def testSimple(self):
input_data_schema, feas, batch_size = self.get_simple_format()
iter_op = ops.csv_iter('./sample_data.txt', input_data_schema, feas, batch_size=batch_size, label='label2')
with self.test_session() as sess:
label,sign = sess.run(iter_op)
print label
self.assertAllEqual(label.shape, [batch_size])
self.assertAllEqual(sign.shape, [batch_size, len(feas)])
self.assertAllEqual(sum(label), 2)
self.assertAllEqual(sign[0,:], [7,0,4,1,1,1,5,9,8])
label,sign = sess.run(iter_op)
self.assertAllEqual(label.shape, [batch_size])
self.assertAllEqual(sign.shape, [batch_size, len(feas)])
self.assertAllEqual(sum(label), 1)
self.assertAllEqual(sign[0,:], [9,9,3,1,1,1,5,4,8])
正常通话:
def testing_tf():
path = './sample_data.txt'
input_data_schema, feas, batch_size = get_simple_format()
with tf.device('/cpu:0'):
n_data_op = tf.placeholder(dtype=tf.float32)
iter_op = ops.csv_iter(path, input_data_schema, feas, batch_size=batch_size, label='label2')
init_op = [tf.global_variables_initializer(), tf.local_variables_initializer() ]
with tf.Session() as sess:
sess.run(init_op)
n_data = 0
for batch_idx in range(3):
print '>>>>>>>>>>>>>> before run batch', batch_idx
## it should be some debug printing here, but nothing come out when batch_idx>0
label,sign = sess.run(iter_op)
print '>>>>>>>>>>>>>> after run batch', batch_idx
## the content of sign remain the same every time
print sign
if len(label) == 0:
break
答案 0 :(得分:1)
查看tf.test.TestCase.test_session()
的{{3}}提供了一些线索,因为它配置的会话与直接调用tf.Session
的方式略有不同。特别是test_session()
implementation 常量折叠优化。默认情况下,TensorFlow会将图表的无状态部分转换为tf.constant()
个节点,因为每次运行它们时都会生成相同的结果。
在"CsvIter"
操作注册中,有SetIsStateful()
注释,因此TensorFlow会将其视为无状态,因此会受到不断折叠的影响。但是,它的实现是非常有状态的:通常,任何你希望用相同的输入张量生成不同结果的op,或者任何在成员变量中存储可变状态的op都应该标记为有状态。
解决方案是对REGISTER_OP
的{{1}}进行单行更改:
"CsvIter"