我有一个函数(下面为foo
),当它直接运行时与在tf.test.TestCase
内运行时的行为不同。
代码应该创建一个带有elems [1..5]的数据集并对其进行随机播放。然后它重复3次:从数据创建一个迭代器,并使用它来打印5个元素。
当它自己运行时,它会输出所有列表都被混洗的输出,例如:
[4, 0, 3, 2, 1]
[0, 2, 1, 3, 4]
[2, 3, 4, 0, 1]
但是当在测试用例中运行时,它们总是相同的,即使在运行之间:
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
[0, 4, 2, 3, 1]
我想这与测试用例如何处理随机种子有关,但我在TensorFlow文档中看不到任何相关内容。谢谢你的帮助!
import tensorflow as tf
def foo():
sess = tf.Session()
dataset = tf.data.Dataset.range(5)
dataset = dataset.shuffle(5, reshuffle_each_iteration=False)
for _ in range(3):
data_iter = dataset.make_one_shot_iterator()
next_item = data_iter.get_next()
with sess.as_default():
data_new = [next_item.eval() for _ in range(5)]
print(data_new)
class DatasetTest(tf.test.TestCase):
def testDataset(self):
foo()
if __name__ == '__main__':
foo()
tf.test.main()
我使用Python 3.6和TensorFlow 1.4运行它。不需要其他模块。
答案 0 :(得分:2)
我认为你是对的; tf.test.TestCase
正在setup使用固定种子。
class TensorFlowTestCase(googletest.TestCase):
# ...
def setUp(self):
self._ClearCachedSession()
random.seed(random_seed.DEFAULT_GRAPH_SEED)
np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
ops.reset_default_graph()
ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED
和
DEFAULT_GRAPH_SEED = 87654321
tensorflow/tensorflow/python/framework/random_seed.py
中的this行。{/ 1>