我需要推断一个包含两个子网的网络,并且我试图避免从两个GPU中复制中间数据。第一个子网使用数据并在一个[242,324,32]数组中累积结果(输入数据可以按任何顺序到达此子网)。馈入一定数量的数据后,我想运行另一个子网,该子网使用相同的数组作为输入并计算输出预测。根据结果,可以从外部产生新数据并将其馈送到第一个子网(这将导致中间阵列的某些元素发生修改,然后第二个子网将需要重新计算一些取决于中间阵列更改的输出。非常简单,因此很容易计算出需要重新评估的输出。
如果中间阵列位于CPU内存中并且两个子网分别在单独的session.run()调用中运行,则实现将非常简单,但最好将其保留在GPU中以减少内存流量,尤其是在更改时该数组中可能只有几个元素。
对于中间数组,正确使用什么变量(Variable,ResourceVariable或tf.contrib.eager.Variable,如果与ResourceVariable不同)是正确的?
在所有数据馈入第一个子网并将结果传播到该中间阵列之前,如何控制阶段2不运行?
馈入第一个子网的数据量可能很大,因此我真的希望第一个子网在第二个子网消耗掉之前填充中间阵列(并且该数组应该是持久性的,因为可能会有更多的整体迭代需要)。