张量流测试中的急切和图形执行

时间:2018-05-02 21:36:08

标签: unit-testing tensorflow

我有一些测试可以使用图形和会话。我还想用热切模式编写一些小测试来轻松测试一些功能。例如:

def test_normal_execution():
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.make_one_shot_iterator()
    first_elem = iterator.get_next()
    with tf.Session() as sess:
        result = sess.run(first_elem)
        assert (result == [1, 2, 3, 4]).all()
    sess.close()

在另一个档案中:

def test_eager_execution():
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    tf.enable_eager_execution()
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.__iter__()
    first_elem = iterator.next()
    assert (first_elem.numpy() == [1, 2, 3, 4]).all() 

有没有办法解决这个问题?当我试图急切地执行测试时,我得到ValueError: tf.enable_eager_execution must be called at program startup.。我正在使用pytest来运行我的测试。

修改

在接受回复的帮助很少的情况下,我创建了一个装饰器,它可以很好地适应热切模式和pytest的灯具:

def run_eagerly(func):
    @functools.wraps(func)
    def eager_fun(*args, **kwargs):
        with tf.Session() as sess:
            sess.run(tfe.py_func(func, inp=list(kwargs.values()), Tout=[]))

    return eager_fun

3 个答案:

答案 0 :(得分:7)

请注意tf.contrib命名空间中的任何内容都是subject to change between releases,您可以使用@tf.contrib.eager.run_test_in_graph_and_eager_modes修饰您的测试。其他一些项目,如TensorFlow Probability seem to use this

对于非测试,需要注意的一些事项是:

  • tf.contrib.eager.defun:当您启用了预先执行但希望"编译"一些计算到图中以从内存和/或性能优化中受益。
  • tf.contrib.eager.py_func:在没有启用预先执行但希望在图中以Python函数执行某些计算时非常有用。

有人可能质疑不允许撤消对tf.enable_eager_execution()的调用的原因。想法是库作者不应该调用它,只有最终用户应该在main()中调用它。减少了库以不兼容的方式编写的可能性(其中一个库中的函数禁用急切执行并返回符号张量,而另一个库中的函数启用急切执行并期望具体的值张量。这会使库混合成问题)。

希望有所帮助

答案 1 :(得分:2)

use eager execution in a graph environment的官方方式。但我不确定这对你来说是否好又方便,因为你需要写一些代码来包装和运行你的测试函数。无论如何,这是你应该至少工作的例子:

import numpy as np
import tensorflow as tf

def test_normal_execution():
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.make_one_shot_iterator()
    first_elem = iterator.get_next()
    with tf.Session() as sess:
        result = sess.run(first_elem)
        assert (result == [1, 2, 3, 4]).all()
    sess.close()

def test_eager_execution():
    matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
    dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
    iterator = dataset.__iter__()
    first_elem = iterator.next()
    assert (first_elem.numpy() == [1, 2, 3, 4]).all()

test_normal_execution()
# test_eager_execution() # Instead, you have to use the following three lines.
with tf.Session() as sess:
    tfe = tf.contrib.eager
    sess.run(tfe.py_func(test_eager_execution, [], []))

答案 2 :(得分:0)

有些未记录,但tensorflow 2具有用于装饰测试类的函数run_all_in_graph_and_eager_modes run_in_graph_and_eager_modes用于修饰测试方法:

view.findViewById(R.id.CardView_fridge_in_popup_fragment).setOnDragListener(new View.OnDragListener() {
       @Override
       public boolean onDrag(View view, DragEvent dragEvent) {

           final int action = dragEvent.getAction();
           switch(action) {
               case DragEvent.ACTION_DRAG_STARTED:
                   Toast.makeText(getContext(), "received DragStarted", Toast.LENGTH_SHORT).show();
                    return true;

               case DragEvent.ACTION_DRAG_ENTERED:
                   view.setBackgroundColor(getResources().getColor(R.color.colorAccent));
                   return true;

               case DragEvent.ACTION_DRAG_EXITED: return true;

               case DragEvent.ACTION_DRAG_LOCATION: return true;

               case DragEvent.ACTION_DROP:
                   draggedFoodItem.setTable(FoodItem.TABLE_FRIDGE);
                   foodItemViewModel.update(draggedFoodItem);
                   return true;

               case DragEvent.ACTION_DRAG_ENDED:
                   dismiss();
                   return true;
           }
            return false;
       }
   });
import tensorflow as tf
from tensorflow.python.framework import test_util

@test_util.run_all_in_graph_and_eager_modes
class MyTestCase(tf.test.TestCase):
#...