Tensorflow Multiple Graph&补丁关注

时间:2018-04-02 22:25:35

标签: python tensorflow graph deep-learning

我的情况是我使用两个不同的网络,一个网络告诉我在给定的补丁中是否有重要信息,另一个网络告诉我补丁中的重要信息在哪里使用分段。

如果我在同一个TF图形/会话中操作它们,我最终必须使用tf.where或tf.cond来告诉我实际上想要使用哪些补丁,但我的优化器正在为每个条件创建渐变整个网,或者至少那是我的工作理论。

这是使用segmentation_logit = tf.where(is_useful_patch,coarse_log,negative_log) 负日志是与粗略logit相同形状的0的张量。

如果我使用192(128x128)补丁,优化器会尝试创建一个参数超过1亿的参数(例如:[192,222,129,128]),它会破坏我的GPU内存并导致崩溃。

因此,如果没有实际定义两个不同的会话,图形,储蓄器,恢复器和张量板编写器,是否有更好的方法来解决这个问题,更好的方法来计算渐变,或者在同一个会话中组合多个图形的方法?

提前致谢!

1 个答案:

答案 0 :(得分:0)

我假设你得到一个192长的is_useful_patch向量,其值为0到1(概率),这是第一个网络的结果。

首先,忘记tf.cond的{​​{1}}。我建议采用较小的数字,例如16个左右(根据您的经验,通常有多少有用的补丁),并使用tf.nn.top_k这样的最佳16个补丁的索引:< / p>

tf.where

然后使用tf.gather_nd收集最佳补丁:

values, idx_best_patches = tf.nn.top_k( is_useful_patch,
    k = 16, sorted = False, name = 'idx_best_patches' )

这将收集你最好的16个补丁,然后你继续只有那些16进入分段器,而不是192,只是将分段器的内存需求减少到1/12。这是这个想法的核心。

如果少于16个补丁包含有用信息,则可以屏蔽部分输出。此外,我不知道您的补丁是如何构建的,因此请务必查看tf.gather_nd参数的正确性,这可能会非常棘手。