在单元测试中退出tf.Session()时重置默认图形

时间:2019-06-22 20:47:20

标签: python unit-testing tensorflow testing pytest

  • 在每个单元测试结束时,我调用tf.reset_default_graph()清除默认图形。
  • 但是,当单元测试失败时,该图不会被清除。这使得下一个单元测试也失败了。

如何在退出tf.Session()上下文后清除图形?

示例(pytest):

import tensorflow as tf


def test_1():
    x = tf.get_variable('x', initializer=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(4 / 0)
        print(sess.run(x))


def test_2():
    x = tf.get_variable('x', initializer=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(x))

3 个答案:

答案 0 :(得分:2)

我建议使用pytest提供的工具:

@pytest.fixture(autouse=True)
def reset():
    yield
    tf.reset_default_graph()

在每次测试之前和之后(标识autouse)都会自动调用夹具,在测试之前/之后执行yield之前/之后的代码。这样,您的问题中的测试将无需任何修改即可工作,并且您遵循DRY原理,拒绝在每个测试中编写重复的代码。另一个例子:

@pytest.fixture(autouse=True)
def init_graph():
    with tf.Graph().as_default():
        yield

将在测试执行之前为每个测试创建一个新图形。

pytest中的修复功能非常强大,如果使用得当,可以完全消除代码重复。例如,您问题中的测试等同于:

@pytest.fixture
def x():
    return tf.get_variable('x', initializer=1)


@pytest.fixture
def session(x):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        yield sess


@pytest.fixture(autouse=True)
def init_graph():
    with tf.Graph().as_default():
        yield


def test_1(session, x):
    print(4 / 0)
    print(session.run(x))


def test_2(session, x):
    print(session.run(x))

如果您想了解更多信息,请从pytest fixtures: explicit, modular, scalable开始。

答案 1 :(得分:1)

这样的作品行吗?


sudo -u hive sqoop import --connect 'jdbc:sqlserver://test.goldman-invest.data:1433;databaseName=Investment_Banking' --username user_***_cqe --password ****** --table cases --target-dir /goldman/yahoo --hive-import --create-hive-table --hive-table topclient.mpool

答案 2 :(得分:1)

直接的解决方案是使用try ... finally子句(实际上,最好将该子句放在运行单元测试的代码中,而不是直接在单元测试中) :

def test_1():
    x = tf.get_variable('x', initializer=1)
    try:
       with tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
           print(4 / 0)
           print(sess.run(x))
    finally:
       tf.reset_default_graph()


def test_2():
    x = tf.get_variable('x', initializer=1)
    try:
        with tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
            print(sess.run(x))
    finally:
        tf.reset_default_graph()

另一种干净的解决方案是为每个单元测试使用一个图形,如上一个答案所示。这是基于此想法的替代解决方案,语法略有简化:

def test_1():
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.get_variable('x', initializer=1)

        sess.run(tf.global_variables_initializer())
        print(4 / 0)
        print(sess.run(x))


def test_2():
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.get_variable('x', initializer=1)

        sess.run(tf.global_variables_initializer())
        print(sess.run(x))

类似于第一个解决方案,with语句也可以放在运行单元测试的代码周围,而不是在每个单元测试中都重复。