TF - 在model_fn中将global_step传递给种子

时间:2017-11-19 19:47:55

标签: tensorflow tensorflow-estimator

在Tensorflow Estimator API(tf.estimator)中有没有办法在model_fn中使用当前会话来评估张量并将值传递给python?我希望丢失的种子取决于global_step的值,但前者需要int,后者是tensor

1 个答案:

答案 0 :(得分:1)

我没有看到任何方式访问session内的model_fn以获取global_step的当前值。

即使有可能,在每个步骤更改tf.nn.dropout的种子也会在每个步骤中使用不同的种子创建新的Graph操作,这将使图形变得越来越大。即使没有tf.estimator,我也不知道如何实现这个目标?

我认为你想要的是确保你在两次跑步之间获得相同的随机性。使用tf.set_random_seed()设置图表级随机种子或仅使用丢失中的正常seed应创建可重现的掩码序列。以下是代码示例:

x = tf.ones(10)
y = tf.nn.dropout(x, 0.5, seed=42)

sess1 = tf.Session()
y1 = sess1.run(y)
y2 = sess1.run(y)

sess2 = tf.Session()
y3 = sess2.run(y)
y4 = sess2.run(y)

assert (y1 == y3).all()  # y1 and y3 are the same
assert (y2 == y4).all()  # y2 and y4 are the same

答案here提供了有关如何使图形随机性可重现的更多详细信息。