tfcompile constant error with tf.cond

Time: 2017-11-30 22:23:02

Tags: tensorflow constants aot xla

The following example code creates a graph with a cond:

from __future__ import absolute_import

import tensorflow as tf

from tensorflow.compiler.tf2xla.tf2xla_pb2 import Config, Feed, Fetch, TensorId
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto


def tf2xla_config_feed( feed ):

  name = feed.name.split( ':' )[ 0 ]
  pb_id = TensorId( node_name = name )
  pb_dim = [ TensorShapeProto.Dim( size = x.value ) for x in feed.shape ]
  pb_tensor_shape_proto = TensorShapeProto( dim = pb_dim )
  pb_feed = Feed( id = pb_id, shape = pb_tensor_shape_proto )
  return pb_feed


def tf2xla_config_fetch( fetch ):

  name = fetch.name.split( ':' )[ 0 ]
  pb_id = TensorId( node_name = name )
  pb_fetch = Fetch( id = pb_id )
  return pb_fetch


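def tf2xla_config( feeds, fetches ):

  # Build explicit lists: on Python 3, map() returns a lazy iterator rather than a list.
  pb_feeds = [ tf2xla_config_feed( x ) for x in feeds ]
  pb_fetches = [ tf2xla_config_fetch( x ) for x in fetches ]
  return Config( feed = pb_feeds, fetch = pb_fetches )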


a = tf.placeholder( tf.float64, shape = ( 2, ), name = 'a' )

a1 = a[ 0 ]
a2 = a[ 1 ]

one = tf.constant( 1 )
two = tf.constant( 2 )

res = tf.cond( a1 < a2, lambda: one, lambda: two )

with open( 'test_graph.pb', 'wb' ) as f:
  f.write( res.graph.as_graph_def().SerializeToString() )

with open( 'test_config.pb', 'wb' ) as f:
  f.write( tf2xla_config( [ a ], [ res ] ).SerializeToString() )
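
Since the error message points at the fetch ids, it may help to see what ids the helpers above produce. The following is only a minimal sketch of the name handling in tf2xla_config_feed and tf2xla_config_fetch, without TensorFlow: the function name node_name_from_tensor_name is hypothetical, and 'cond/Merge:0' is an assumed tensor name for the tf.cond output, not something taken from the actual graph.

```python
def node_name_from_tensor_name(tensor_name):
    """Drop the ':<output index>' suffix, as the config helpers do."""
    return tensor_name.split(':')[0]


# 'a:0' is the tensor name of the placeholder created with name='a'.
# 'cond/Merge:0' is an ASSUMED name for the tf.cond output; check res.name
# in the real graph to see the actual value.
for tensor_name in ('a:0', 'cond/Merge:0'):
    print(tensor_name, '->', node_name_from_tensor_name(tensor_name))
```

Printing res.name in the script and comparing it with the fetch id written to test_config.pb would confirm whether the config refers to the intended node.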

Compiling the graph with tfcompile results in the following error:

  

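2017-11-29 20:40:26.725164: F tensorflow/compiler/aot/tfcompile_main.cc:140] Non-OK-status: status status: Unimplemented: Conversion from TensorFlow graph to XLA resulted in 1 constant results. The configuration of the output args (i.e. fetch ids) is probably wrong.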

It seems that this error is unfounded? Or am I doing something wrong?

0 answers:

No answers yet