在tf1中,我可以定义所需提取的列表并使用
sess.run(myList, feed_dict)
通过图表获得tf1同时计算的列表的所有元素。在tf2.0中如何做?
tf1中的示例代码:
import tensorflow as tf
a = [None]*5
for i in range(5):
a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(fetch_list)
没有检查上面的代码是否运行,但我希望您明白了。谢谢
答案 0 :(得分:1)
因为默认情况下tf 2.x会急切地执行,所以您可以这样做:
a = [None]*5
for i in range(5):
a[i] = tf.Variable(tf.random.normal([3,3]))
fetch_list = [None]*5
for i in range(5):
fetch_list[i] = tf.add(tf.gather(a, i), tf.ones([3,3]))
然后像以前一样填充fetch_list
。
根据实际单词示例的复杂性,您还可以考虑使用@tf.function
来构建执行图,然后再将数据推送通过,以类似于tf1进行优化(这是一个极大的简化,但是你明白了)。
您可能会考虑对代码进行一些简化/重做以简化此过程。最好只在可能的情况下使用张量而不是张量列表。很难确切建议如何完成此操作,因为我不知道您为示例简化了什么。
例如,如果我们认为您的fetch_list
是(5,3,3)
张量,而不是5个(3,3)
张量的列表,那么我确定您已经意识到简化的示例代码(更多或更少)归结为以下内容:
@tf.function
def get_list(n):
return tf.random.normal((n,3,3))
fetch_list = get_list(5)