使用以下示例代码创建带 tf.cond 的计算图:
<html>
<body>
<div style="border: 5px dashed crimson; color: maroon; background-color: darksalmon">
This is a footer
</div>
</body>
</html>
编译:
from __future__ import absolute_import
import tensorflow as tf
from tensorflow.compiler.tf2xla.tf2xla_pb2 import Config, Feed, Fetch, TensorId
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
def tf2xla_config_feed( feed ):
    """Build the tf2xla ``Feed`` message describing one graph input.

    Args:
        feed: object exposing ``name`` (a tensor name such as ``"a:0"``) and
            an iterable ``shape`` whose items are either plain ints / ``None``
            or objects carrying the size in a ``value`` attribute.

    Returns:
        A ``Feed`` whose ``id`` identifies the tensor (node name plus output
        index) and whose ``shape`` lists one ``Dim`` per item of ``feed.shape``.
    """
    # A tensor name is "<node_name>:<output_index>".  Keep the index instead of
    # discarding it, otherwise every feed would refer to output 0 of its node.
    node_name, _, index = feed.name.partition( ':' )
    pb_id = TensorId(
        node_name = node_name,
        output_index = int( index ) if index.isdigit() else 0,
    )

    pb_dim = []
    for dim in feed.shape:
        # Accept both wrapper objects with ``.value`` and bare ints / None.
        size = getattr( dim, 'value', dim )
        # TensorShapeProto marks an unknown dimension with -1; passing None
        # would leave the field unset, i.e. silently report a size of 0.
        pb_dim.append( TensorShapeProto.Dim( size = -1 if size is None else size ) )

    return Feed( id = pb_id, shape = TensorShapeProto( dim = pb_dim ) )
def tf2xla_config_fetch( fetch ):
    """Build the tf2xla ``Fetch`` message describing one graph output.

    Args:
        fetch: object exposing ``name``, a tensor name such as ``"node:0"``.

    Returns:
        A ``Fetch`` whose ``id`` carries the node name and the output index
        parsed from ``fetch.name``.
    """
    # A tensor name is "<node_name>:<output_index>".  Dropping the index would
    # make a fetch of "node:1" silently resolve to output 0 of that node.
    node_name, _, index = fetch.name.partition( ':' )
    pb_id = TensorId(
        node_name = node_name,
        output_index = int( index ) if index.isdigit() else 0,
    )
    return Fetch( id = pb_id )
def tf2xla_config( feeds, fetches ):
    """Assemble a tf2xla ``Config`` from the given input and output tensors.

    Args:
        feeds: iterable of tensors, each converted with ``tf2xla_config_feed``.
        fetches: iterable of tensors, each converted with ``tf2xla_config_fetch``.

    Returns:
        A ``Config`` whose ``feed`` and ``fetch`` fields hold the converted
        entries in the order the tensors were given.
    """
    return Config(
        feed = [ tf2xla_config_feed( tensor ) for tensor in feeds ],
        fetch = [ tf2xla_config_fetch( tensor ) for tensor in fetches ],
    )
# Graph input: a float64 placeholder of shape (2,) named 'a'.
a = tf.placeholder( tf.float64, shape = ( 2, ), name = 'a' )
a1, a2 = a[ 0 ], a[ 1 ]
one, two = tf.constant( 1 ), tf.constant( 2 )

# Pick `one` when a1 < a2, otherwise `two`.
res = tf.cond( a1 < a2, lambda: one, lambda: two )

# Write the serialized GraphDef of the graph that owns `res`.
with open( 'test_graph.pb', 'wb' ) as f:
    graph_def = res.graph.as_graph_def()
    f.write( graph_def.SerializeToString() )

# Write the tf2xla Config with `a` as the only feed and `res` as the only fetch.
with open( 'test_config.pb', 'wb' ) as f:
    config = tf2xla_config( [ a ], [ res ] )
    f.write( config.SerializeToString() )
用 tfcompile 编译时出现以下错误:
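2017-11-29 20:40:26.725164: F tensorflow/compiler/aot/tfcompile_main.cc:140] Non-OK-status: status status: Unimplemented: Conversion from TensorFlow graph to XLA resulted in 1 constant results. The configuration of the output args (i.e. fetch ids) is probably wrong.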
看来这个错误是没有根据的?或者我做错了什么?