同时使用pytest和tf.test.TestCase的问题

时间:2018-12-18 00:06:14

标签: python tensorflow pytest

我有一个简单的单元测试,检查是否可以使用略有不同的参数实例化Tensorflow类。对于@pytest.mark.parametrize,这似乎是一个很好的用例。

但是,如果我的单元测试是parametrize的方法,我发现tf.test.TestCase被忽略了。

例如,当我在以下代码上运行pytest时:

class TestBasicRewardNet(tf.test.TestCase):                                                                                                                          
    @pytest.mark.parametrize("env", ['FrozenLake-v0', 'CartPole-v1',                                                                                               
        'CarRacing-v0', 'LunarLander-v2'])                                                                                                                           
    def test_init_no_crash(self, env):                                                                                                                               
        for i in range(3):                                                                                                                                    
            x = BasicRewardNet(env)  

我收到错误TypeError: test_init_no_crash() missing 1 required positional argument: 'env'

要解决此问题,我尝试摆脱类包装器,但是这使我错过了一些自动Tensorflow测试初始化​​的机会。特别是,现在每个BasicRewardNet都建立在相同的TensorFlow图中,因此我需要执行类似添加可变范围的操作以避免 冲突。在这个可变范围内添加似乎很困难。

@pytest.mark.parametrize("env", ['FrozenLake-v0', 'CartPole-v1',                                                                                               
     'CarRacing-v0', 'LunarLander-v2'])  
def test_init_no_crash(env):                                                                                                                                         
    for i in range(3):                                                                                                                                               
        with tf.variable_scope(env+str(i)):                                                                                                                          
            x = BasicRewardNet(env)   

我想知道是否有人知道我可以完美地兼得两全的方法?我希望能够使用parametrize并同时获得tf.test.TestCase的自动Tensorflow初始化。

1 个答案:

答案 0 :(得分:1)

hoefling的注释中所述,可以使用tf.test.TestCase.subTest来解决。

class TestBasicRewardNet(tf.test.TestCase):

    @staticmethod
    def my_sub_test(env):
        for i in range(3):                                                                                                                                               
            with tf.variable_scope(env+str(i)):                                                                                                                          
                x = BasicRewardNet(env)

    def test_init_no_crash(env):
        for env in ['FrozenLake-v0', 'CartPole-v1','CarRacing-v0', 'LunarLander-v2']:
            with self.subTest(env):                                                                                                                                                                                                                    
                 self.my_sub_test(env)

要在与subTest一起运行时能够使用pytest功能,应在需求中添加pytest-subtests ,否则就不会有它们!