Tensorflow:是否可以打印会话数?

时间:2019-08-30 11:06:20

标签: python tensorflow

张量流中可能有多个会话吗?可以在Tensorflow中打印会话数吗?

def test_print_number_of_sessions():
    sess1 = tf.Session()
    sess2 = tf.Session()

    //print_number_of_sessions

2 个答案:

答案 0 :(得分:1)

每个图可以有多个会话,但是没有直接的方法来获取图中的所有打开的会话。该图的内部C数据结构确实具有a collection with all the existing sessions,但是不幸的是,它的Python对应部分(._c_graph对象的tf.Graph属性)只是一个没有类型信息的不透明指针。

一种可能的解决方案是使用您自己的会话包装器,该包装器跟踪每个图形的打开会话。这是一种可行的方法。

import tensorflow as tf
import collections

class TrackedSession(tf.Session):
    _sessions = collections.defaultdict(list)
    def __init__(self, target='', graph=None, config=None):
        super(tf.Session, self).__init__(target=target, graph=graph, config=config)
        TrackedSession._sessions[self.graph].append(self)
    def close(self):
        super(tf.Session, self).close()
        TrackedSession._sessions[self.graph].remove(self)
    @classmethod
    def get_open_sessions(cls, g=None):
        g = g or tf.get_default_graph()
        return list(cls._sessions[g])

print(TrackedSession.get_open_sessions())
# []
sess1 = TrackedSession()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C77F0>]
sess2 = TrackedSession()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C77F0>, <__main__.TrackedSession object at 0x000001D75B0C7A58>]
sess1.close()
print(TrackedSession.get_open_sessions())
# [<__main__.TrackedSession object at 0x000001D75B0C7A58>]
sess2.close()
print(TrackedSession.get_open_sessions())
# []

但是,这限制了您使用此自定义会话类型,根据情况的不同,这种类型可能不够好(例如,如果该会话是由某些外部代码打开的,例如使用Keras时)。

答案 1 :(得分:0)

愿这一切都好:

tf.InteractiveSession._active_session_count