如何在tf2.0上同时进行多个提取?

时间:2019-09-02 08:53:06

标签: tensorflow tensorflow2.0

在tf1中,我可以定义所需提取的列表并使用

sess.run(myList, feed_dict)

通过图表获得tf1同时计算的列表的所有元素。在tf2.0中如何做?

tf1中的示例代码:

import tensorflow as tf
a = [None]*5
for i in range(5):
    a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
    fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(fetch_list)

没有检查上面的代码是否运行,但我希望您明白了。谢谢

1 个答案:

答案 0 :(得分:1)

因为默认情况下tf 2.x会急切地执行,所以您可以这样做:

a = [None]*5
for i in range(5):
    a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
    fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))

然后像以前一样填充fetch_list

根据实际单词示例的复杂性,您还可以考虑使用@tf.function来构建执行图,然后再将数据推送通过,以类似于tf1进行优化(这是一个极大的简化,但是你明白了)。

您可能会考虑对代码进行一些简化/重做以简化此过程。最好只在可能的情况下使用张量而不是张量列表。很难确切建议如何完成此操作,因为我不知道您为示例简化了什么。

例如,如果我们认为您的fetch_list(5,3,3)张量,而不是5个(3,3)张量的列表,那么我确定您已经意识到简化的示例代码(更多或更少)归结为以下内容:

@tf.function
def get_list(n):
  return tf.random.normal((n,3,3))

fetch_list = get_list(5)