tf.reset_default_graph()
清除默认图形。 如何在退出tf.Session()
上下文后清除图形?
示例(pytest):
import tensorflow as tf
def test_1():
x = tf.get_variable('x', initializer=1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(4 / 0)
print(sess.run(x))
def test_2():
x = tf.get_variable('x', initializer=1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(x))
答案 0 :(得分:2)
我建议使用pytest
提供的工具:
@pytest.fixture(autouse=True)
def reset():
yield
tf.reset_default_graph()
在每次测试之前和之后(标识autouse
)都会自动调用夹具,在测试之前/之后执行yield
之前/之后的代码。这样,您的问题中的测试将无需任何修改即可工作,并且您遵循DRY原理,拒绝在每个测试中编写重复的代码。另一个例子:
@pytest.fixture(autouse=True)
def init_graph():
with tf.Graph().as_default():
yield
将在测试执行之前为每个测试创建一个新图形。
pytest
中的修复功能非常强大,如果使用得当,可以完全消除代码重复。例如,您问题中的测试等同于:
@pytest.fixture
def x():
return tf.get_variable('x', initializer=1)
@pytest.fixture
def session(x):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
yield sess
@pytest.fixture(autouse=True)
def init_graph():
with tf.Graph().as_default():
yield
def test_1(session, x):
print(4 / 0)
print(session.run(x))
def test_2(session, x):
print(session.run(x))
如果您想了解更多信息,请从pytest fixtures: explicit, modular, scalable开始。
答案 1 :(得分:1)
这样的作品行吗?
sudo -u hive sqoop import --connect 'jdbc:sqlserver://test.goldman-invest.data:1433;databaseName=Investment_Banking' --username user_***_cqe --password ****** --table cases --target-dir /goldman/yahoo --hive-import --create-hive-table --hive-table topclient.mpool
答案 2 :(得分:1)
直接的解决方案是使用try
... finally
子句(实际上,最好将该子句放在运行单元测试的代码中,而不是直接在单元测试中) :
def test_1():
x = tf.get_variable('x', initializer=1)
try:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(4 / 0)
print(sess.run(x))
finally:
tf.reset_default_graph()
def test_2():
x = tf.get_variable('x', initializer=1)
try:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(x))
finally:
tf.reset_default_graph()
另一种干净的解决方案是为每个单元测试使用一个图形,如上一个答案所示。这是基于此想法的替代解决方案,语法略有简化:
def test_1():
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.get_variable('x', initializer=1)
sess.run(tf.global_variables_initializer())
print(4 / 0)
print(sess.run(x))
def test_2():
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.get_variable('x', initializer=1)
sess.run(tf.global_variables_initializer())
print(sess.run(x))
类似于第一个解决方案,with
语句也可以放在运行单元测试的代码周围,而不是在每个单元测试中都重复。