从tensorflow 1.x升级到2.0

时间:2019-08-19 19:17:05

标签: python python-3.x tensorflow python-3.6

我是Tensorflow的新手。 尝试过这个简单的例子:

import tensorflow as tf
sess = tf.Session()
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = x + y
print(sess.run(z, feed_dict={x: 3.0, y: 4.5}))

并收到一些警告The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.和正确的答案-7.5

在阅读here之后,我了解到警告是由于从tf 1.x升级到2.0而引起的,描述的步骤“简单”,但没有给出任何示例。...

我尝试过:

@tf.function
def f1(x1, y1):
    return tf.math.add(x1, y1)


print(f1(tf.constant(3.0), tf.constant(4.5)))
  1. 我的代码是否正确(按照链接中定义的含义)?
  2. 现在,我得到Tensor("PartitionedCall:0", shape=(), dtype=float32)作为输出,如何获得实际值?

2 个答案:

答案 0 :(得分:2)

您的代码确实正确。您收到的警告表明,从Tensorflow 2.0开始,API中不存在tf.Session()。因此,如果您希望代码与Tensorflow 2.0兼容,则应改用tf.compat.v1.Session。因此,只需更改此行:

sess = tf.Session()

收件人:

sess = tf.compat.v1.Session()

然后,即使将Tensorflow从1.xx更新为2.xx,您的代码也将以相同的方式执行。至于Tensorflow 2.0中的代码:

@tf.function
def f1(x1, y1):
    return tf.math.add(x1, y1)

print(f1(tf.constant(3.0), tf.constant(4.5)))

如果在Tensorflow 2.0中运行它很好。如果要运行相同的代码而不安装Tensorflow 2.0,则可以执行以下操作:

import tensorflow as tf
tf.enable_eager_execution()

@tf.function
def f1(x1, y1):
    return tf.math.add(x1, y1)

print(f1(tf.constant(3.0), tf.constant(4.5)).numpy())

这样做的原因是因为从Tensorflow 2.0开始执行Tensorflow操作的默认方式是急切模式。在Tensorflow 1.xx中激活渴望模式的方法是在导入Tensorflow之后立即启用它,就像我在上面的示例中所做的那样。

答案 1 :(得分:1)

根据Tensorflow 2.0,您的代码是正确的。 Tensorflow 2.0与numpy更加紧密地结合在一起,因此,如果要获取操作结果,可以使用numpy()方法:

print(f1(tf.constant(3.0), tf.constant(4.5)).numpy())