使用tf.assign()

时间:2017-11-14 01:23:01

标签: tensorflow

我学习Tensor流程并尝试将每个迭代持续时间捕获到摘要变量中。我已将其识别为一个特殊问题,我在下面的代码中总结了这个问题

1)定义张量(捕捉开始/结束时间)

t = tf.Variable(0.0, tf.float64)

2)在每次迭代时运行迭代和输出时间

with tf.Session() as sess:
    sess.run(tf.variables_initializer(tf.global_variables()))
    for i in range(20):
        _ = time.time()
        sess.run(tf.assign(t,_)) #update 't' tensor value to start time
        print("time_1: {}, time_2_tensor: {} ".format(_,sess.run(t)))

当我运行代码时,我想知道为什么 time_1 之间的差异很大(4-5秒) time_2_tensor 值。这里的输出(time_1似乎更正确,并且想知道为什么time_2_tensor看起来是未来的时间和所有相同的价值!

  

time_1:1510622147.797711,time_2_tensor:1510622208.0

     

time_1:1510622147.823721,time_2_tensor:1510622208.0

     

time_1:1510622147.846073,time_2_tensor:1510622208.0

     

time_1:1510622147.872359,time_2_tensor:1510622208.0

     

time_1:1510622147.893345,time_2_tensor:1510622208.0

     

time_1:1510622147.913889,time_2_tensor:1510622208.0

     

time_1:1510622147.94033,time_2_tensor:1510622208.0

     

time_1:1510622147.960254,time_2_tensor:1510622208.0

     

time_1:1510622147.98226,time_2_tensor:1510622208.0

     

time_1:1510622148.007267,time_2_tensor:1510622208.0

     

time_1:1510622148.045414,time_2_tensor:1510622208.0

     

time_1:1510622148.072437,time_2_tensor:1510622208.0

     

time_1:1510622148.104469,time_2_tensor:1510622208.0

     

time_1:1510622148.124364,time_2_tensor:1510622208.0

     

time_1:1510622148.143735,time_2_tensor:1510622208.0

     

time_1:1510622148.161832,time_2_tensor:1510622208.0

     

time_1:1510622148.179756,time_2_tensor:1510622208.0

     

time_1:1510622148.216838,time_2_tensor:1510622208.0

     

time_1:1510622148.235228,time_2_tensor:1510622208.0

     

time_1:1510622148.254686,time_2_tensor:1510622208.0

非常感谢您对此有任何见解!

1 个答案:

答案 0 :(得分:2)

在声明变量(它取代参​​数trainable时,至少在v1.4中)并且变为float32时,不会考虑您的数据类型,这会导致精确问题。

如果您添加关键字dtype

,则此功能正常
t = tf.Variable(0.0, dtype=tf.float64)

'实验':

t_bad = tf.Variable(0.0, tf.float64)
t = tf.Variable(0.0, dtype=tf.float64)

print(t_bad.dtype)  # <dtype: 'float32_ref'>
print(t.dtype)  # <dtype: 'float64_ref'>