我有一个非常简单的程序,带有Python控制流语句
@tf.function
def mandelbrot(T, max_iter):
for i in range(10):
if (tf.abs(T)) >= 4:
return 5
return max_iter
T=tf.complex(10.,2.)
mandelbrot(T, 100)
但是它不起作用,并引发大量跟踪错误。这样简单的代码有什么问题?
-------------------------------------------------- ---------------------------- AssertionError Traceback(最近的呼叫 最后) 2 T = tf.complex(10.,2。) 3 ----> 4个mandelbrot(T,100)
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ eager \ def_function.py 在通话中(自己,* args,** kwds) 424#这是 call 的第一个调用,因此我们必须进行初始化。 425 initializer_map = {} -> 426 self._initialize(args,kwds,add_initializers_to = initializer_map) 第427章 428尝试:
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ eager \ def_function.py 在_initialize(self,args,kwds,add_initializers_to)中 第368章 第369章
pylint:disable =受保护的访问权限
-> 370 * args,** kwds)) 371 372 def invalid_creator_scope(* unused_args,** unused_kwds):
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ eager \ function.py 在_get_concrete_function_internal_garbage_collected(self,* args, ** kwargs)1311 if self._input_signature:1312 args,kwargs = None,None -> 1313 graph_function,_,_ = self._maybe_define_function(args,kwargs)1314返回graph_function 1315
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ eager \ function.py 在_maybe_define_function(self,args,kwargs)1578中或 call_context_key不在self._function_cache.missed中):1579
self._function_cache.missed.add(call_context_key) -> 1580 graph_function = self._create_graph_function(args,kwargs)1581 self._function_cache.primary [cache_key] = graph_function 1582返回graph_function,args,kwargs〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ eager \ function.py 在_create_graph_function(self,args,kwargs, 1510 arg_names = arg_names,
1511 overlay_flat_arg_shapes = override_flat_arg_shapes, -> 1512 capture_by_value = self._capture_by_value),1513 self._function_attributes)1514〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ framework \ func_graph.py 在func_graph_from_py_func(name,python_func,args,kwargs,签名, func_graph,亲笔签名,autograph_options,add_control_dependencies, arg_names,op_return_value,集合,capture_by_value, Override_flat_arg_shapes) 第692章 693 -> 694 func_outputs = python_func(* func_args,** func_kwargs) 695 696#不变量:
func_outputs
仅包含张量,IndexedSlices〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ eager \ def_function.py 在wrapd_fn(* args,** kwds)中 315#包裹允许AutoGraph交换转换后的功能。我们给予 316#该函数对其自身进行弱引用以避免引用循环。 -> 317 return weak_wrapped_fn()。包裹(* args,** kwds) 第318章 319
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ framework \ func_graph.py 在包装器中(* args,** kwargs) 第684章真相 685 force_conversion = True, -> 686),args,kwargs) 687 688#环绕装饰器可以进行tf_inspect.getargspec之类的检查
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ autograph \ impl \ api.py 在convert_call(f,owner,options,args,kwargs)中 390 return _call_unconverted(f,args,kwargs) 391 -> 392结果= convert_f(* effective_args,** kwargs) 393 394#转换后的函数的闭包仅插入到函数的闭包中
〜\ AppData \ Local \ Temp \ tmp95dcry6m.py in tf__mandelbrot(T,max_iter) 20 retval__1,do_return_1 = ag __。if_stmt(cond,if_true,if_false) 21 return retval__1,do_return_1 ---> 22 retval_,do_return = ag __。for_stmt(ag __。converted_call(range,None, ag __。ConversionOptions(recursive = True,verbose = 0, strip_decorators =(tf.function,defun,ag __。convert, ag __。do_not_convert,ag __。converted_call),force_conversion = False, optional_features = {),internal_convert_user_code = True),(10,),{}), extra_test,loop_body((retval_,do_return)) 23 cond_1 = ag __。not_(do_return) 24
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ autograph \ operators \ control_flow.py 在for_stmt(iter_,extra_test,body,init_state)中 79 return _dataset_for_stmt(iter_,extra_test,body,init_state) 其他80个 ---> 81 return _py_for_stmt(iter_,extra_test,body,init_state) 82 83
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ autograph \ operators \ control_flow.py 在_py_for_stmt中(iter_,extra_test,body,init_state) 88如果extra_test不为None并且not extra_test(* state): 89休息 ---> 90状态=正文(目标,*状态) 91返回状态 92
〜\ AppData \ Local \ Temp \ tmp95dcry6m.py在loop_body(loop_vars,retval__1, do_return_1) 18 def if_false(): 19 return retval__1,do_return_1 ---> 20 retval__1,do_return_1 = ag __。if_stmt(cond,if_true,if_false) 21 return retval__1,do_return_1 22 retval_,do_return = ag __。for_stmt(ag __。converted_call(range,None, ag __。ConversionOptions(recursive = True,verbose = 0, strip_decorators =(tf.function,defun,ag __。convert, ag __。do_not_convert,ag __。converted_call),force_conversion = False, optional_features = {),internal_convert_user_code = True),(10,),{}), extra_test,loop_body,(retval_,do_return))
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ autograph \ operators \ control_flow.py 在if_stmt(cond,body,orelse)中 243“”“ 244如果tensor_util.is_tensor(cond): -> 245 return tf_if_stmt(cond,body,orelse) 246其他: 247 return _py_if_stmt(cond,body,orelse)
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ autograph \ operators \ control_flow.py 在tf_if_stmt(cond,body,orelse)中 254 branch_name ='else') 255 -> 256返回control_flow_ops.cond(cond,protected_body,protected_orelse) 257 258
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ util \ deprecation.py 在new_func(* args,** kwargs)中 505'在将来的版本中',如果日期为其他(在%s之后的%日期之后), 506条指令) -> 507 return func(* args,** kwargs) 508 509文档= _add_deprecated_arg_notice_to_docstring(
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ ops \ control_flow_ops.py 在cond(pred,true_fn,false_fn,strict,name,fn1,fn2)中1916如果 (util.EnableControlFlowV2(ops.get_default_graph())和1917
不是context.executing_eagerly()): -> 1918 return cond_v2.cond_v2(pred,true_fn,false_fn,name)1919 1920#我们需要使true_fn / false_fn关键字参数 为〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ ops \ cond_v2.py 在cond_v2中(pred,true_fn,false_fn,名称) 84 true_graph.external_captures, 85 false_graph.external_captures, -> 86 name = scope) 87 88
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ ops \ cond_v2.py 在_build_cond中(pred,true_graph,false_graph,true_inputs, false_inputs,名称) 185个中间输出。 186“”“ -> 187 _check_same_outputs(true_graph,false_graph) 188 189#将输入添加到true_graph和false_graph以使其匹配。请注意
〜.conda \ envs \ alphagpu \ lib \ site-packages \ tensorflow \ python \ ops \ cond_v2.py 在_check_same_outputs(true_graph,false_graph)中 584错误(str(e)) 585 -> 586断言len(true_graph.outputs)== len(false_graph.outputs) 587中的true_out,false_out(true_graph.outputs,false_graph.outputs): 588如果true_out.dtype!= false_out.dtype:
AssertionError:
答案 0 :(得分:2)
看起来像2.0尚不能处理早期的有条件回报。我想这会在某个时候解决(可以随意检查是否有自己的错误报告/文件),但与此同时,以下对我有用。它不允许提早退出,但至少应给出正确的结果。
@tf.function
def mandelbrot(T, max_iter):
out = max_iter
for i in range(10):
if (tf.abs(T)) >= 4:
out = 5
return out
T = tf.complex(10.,2.)
m = mandelbrot(T, 100)
对于多个T
值,我认为您不得不求助于tf.where
def mandelbrot(T, max_iter):
ones = tf.ones(tf.shape(T), dtype=tf.int64)
out = ones * max_iter
fives = ones * 5
for i in range(10):
out = tf.where(tf.greater_equal(tf.abs(T), 4), fives, out)
return out
您可以使用tf.while_loop
和tf.TensorArray
做更复杂的事情,但是我怀疑其中会涉及一些开销,这会使小问题变得更昂贵(并且代码复杂度不会琐碎的。)
请注意,这不是计算mandelbrot集的方法-我假设这是因为您已将其简化为一个最小的示例。 T
从未在此更新,因此您可以删除i
上的循环。