将TensorFlow sess.run转换为@ tf.function

时间:2019-10-26 18:49:15

标签: python tensorflow tensorflow2.0

如何编辑sessions.run函数,使其在Tensorflow 2.0上运行?

  with tf.compat.v1.Session(graph=graph) as sess:
    start = time.time()
    results = sess.run(output_operation.outputs[0],
                      {input_operation.outputs[0]: t})

我阅读了文档over here,并了解到您必须更改如下功能:

  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
  sess = tf.compat.v1.Session()
  result = sess.run(normalized)

  return result

对此:

def myFunctionToReplaceSessionRun(resized,input_mean,input_std):
    return tf.divide(tf.subtract(resized, [input_mean]), [input_std])

normalized = myFunctionToReplaceSessionRun(resized,input_mean,input_std)

但是我不知道如何更改第一个。

这里有点上下文,我正在尝试this代码实验室,并发现sess.run给我带来了麻烦。

This is the command line output when running the label_images file.

And this is the function that gave errors.

1 个答案:

答案 0 :(得分:0)

使用TensorFlow 1.x,我们创建了tf.placeholder张量,数据可以通过它们进入图形。我们将feed_dict=对象与tf.Session()一起使用。

在TensorFlow 2.0中,由于默认情况下急于执行,我们可以将数据直接馈送到图形。使用@tf.function批注,我们可以将函数直接包含在图形中。 official docs说,

  

此合并的中心是tf.function,它使您可以   将Python语法的子集转换为可移植的高性能   TensorFlow图。

这是文档中的一个简单示例,

@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))


x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

simple_nn_layer(x, y)

现在,调查您的问题,您可以将函数转换为

@tf.function
def get_output_operation( input_op ):
    # The function goes here
    # from here return `results`

results = get_output_operation( some_input_op )

简而言之,占位符张量转换为函数参数,tensor中的sess.run( tensor )由函数返回。所有这些都发生在带有@tf.function注释的函数中。