为什么这个TensorFlow代码在测试用例中表现不同?

时间:2017-11-16 00:04:10

标签: python unit-testing testing tensorflow

我有一个函数(下面为foo),当它直接运行时与在tf.test.TestCase内运行时的行为不同。

代码应该创建一个带有elems [1..5]的数据集并对其进行随机播放。然后它重复3次:从数据创建一个迭代器,并使用它来打印5个元素。

当它自己运行时,它会输出所有列表都被混洗的输出,例如:

[4, 0, 3, 2, 1]
[0, 2, 1, 3, 4]
[2, 3, 4, 0, 1]

但是当在测试用例中运行时,它们总是相同的,即使在运行之间:

[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]

我想这与测试用例如何处理随机种子有关,但我在TensorFlow文档中看不到任何相关内容。谢谢你的帮助!

代码:

import tensorflow as tf

def foo():
    sess = tf.Session()
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.shuffle(5, reshuffle_each_iteration=False)

    for _ in range(3):
        data_iter = dataset.make_one_shot_iterator()
        next_item = data_iter.get_next()
        with sess.as_default():
            data_new = [next_item.eval() for _ in range(5)]
        print(data_new)


class DatasetTest(tf.test.TestCase):
    def testDataset(self):
        foo()

if __name__ == '__main__':
    foo()
    tf.test.main()

我使用Python 3.6和TensorFlow 1.4运行它。不需要其他模块。

1 个答案:

答案 0 :(得分:2)

我认为你是对的; tf.test.TestCase正在setup使用固定种子。

class TensorFlowTestCase(googletest.TestCase):
# ...
def setUp(self):
  self._ClearCachedSession()
  random.seed(random_seed.DEFAULT_GRAPH_SEED)
  np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
  ops.reset_default_graph()
  ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED

DEFAULT_GRAPH_SEED = 87654321 tensorflow/tensorflow/python/framework/random_seed.py中的this行。{/ 1>