我有一些测试可以使用图形和会话。我还想用热切模式编写一些小测试来轻松测试一些功能。例如:
def test_normal_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.make_one_shot_iterator()
first_elem = iterator.get_next()
with tf.Session() as sess:
result = sess.run(first_elem)
assert (result == [1, 2, 3, 4]).all()
sess.close()
在另一个档案中:
def test_eager_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
tf.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.__iter__()
first_elem = iterator.next()
assert (first_elem.numpy() == [1, 2, 3, 4]).all()
有没有办法解决这个问题?当我试图急切地执行测试时,我得到ValueError: tf.enable_eager_execution must be called at program startup.
。我正在使用pytest
来运行我的测试。
修改:
在接受回复的帮助很少的情况下,我创建了一个装饰器,它可以很好地适应热切模式和pytest的灯具:
def run_eagerly(func):
@functools.wraps(func)
def eager_fun(*args, **kwargs):
with tf.Session() as sess:
sess.run(tfe.py_func(func, inp=list(kwargs.values()), Tout=[]))
return eager_fun
答案 0 :(得分:7)
请注意tf.contrib
命名空间中的任何内容都是subject to change between releases,您可以使用@tf.contrib.eager.run_test_in_graph_and_eager_modes
修饰您的测试。其他一些项目,如TensorFlow Probability seem to use this。
对于非测试,需要注意的一些事项是:
tf.contrib.eager.defun
:当您启用了预先执行但希望"编译"一些计算到图中以从内存和/或性能优化中受益。tf.contrib.eager.py_func
:在没有启用预先执行但希望在图中以Python函数执行某些计算时非常有用。有人可能质疑不允许撤消对tf.enable_eager_execution()
的调用的原因。想法是库作者不应该调用它,只有最终用户应该在main()
中调用它。减少了库以不兼容的方式编写的可能性(其中一个库中的函数禁用急切执行并返回符号张量,而另一个库中的函数启用急切执行并期望具体的值张量。这会使库混合成问题)。
希望有所帮助
答案 1 :(得分:2)
有use eager execution in a graph environment的官方方式。但我不确定这对你来说是否好又方便,因为你需要写一些代码来包装和运行你的测试函数。无论如何,这是你应该至少工作的例子:
import numpy as np
import tensorflow as tf
def test_normal_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.make_one_shot_iterator()
first_elem = iterator.get_next()
with tf.Session() as sess:
result = sess.run(first_elem)
assert (result == [1, 2, 3, 4]).all()
sess.close()
def test_eager_execution():
matrix_2x4 = np.array([[1, 2, 3, 4], [6, 7, 8, 9]])
dataset = tf.data.Dataset.from_tensor_slices(matrix_2x4)
iterator = dataset.__iter__()
first_elem = iterator.next()
assert (first_elem.numpy() == [1, 2, 3, 4]).all()
test_normal_execution()
# test_eager_execution() # Instead, you have to use the following three lines.
with tf.Session() as sess:
tfe = tf.contrib.eager
sess.run(tfe.py_func(test_eager_execution, [], []))
答案 2 :(得分:0)
有些未记录,但tensorflow 2具有用于装饰测试类的函数run_all_in_graph_and_eager_modes run_in_graph_and_eager_modes用于修饰测试方法:
view.findViewById(R.id.CardView_fridge_in_popup_fragment).setOnDragListener(new View.OnDragListener() {
@Override
public boolean onDrag(View view, DragEvent dragEvent) {
final int action = dragEvent.getAction();
switch(action) {
case DragEvent.ACTION_DRAG_STARTED:
Toast.makeText(getContext(), "received DragStarted", Toast.LENGTH_SHORT).show();
return true;
case DragEvent.ACTION_DRAG_ENTERED:
view.setBackgroundColor(getResources().getColor(R.color.colorAccent));
return true;
case DragEvent.ACTION_DRAG_EXITED: return true;
case DragEvent.ACTION_DRAG_LOCATION: return true;
case DragEvent.ACTION_DROP:
draggedFoodItem.setTable(FoodItem.TABLE_FRIDGE);
foodItemViewModel.update(draggedFoodItem);
return true;
case DragEvent.ACTION_DRAG_ENDED:
dismiss();
return true;
}
return false;
}
});
import tensorflow as tf
from tensorflow.python.framework import test_util
@test_util.run_all_in_graph_and_eager_modes
class MyTestCase(tf.test.TestCase):
#...