将自定义tensorflow操作转换为tensorflow 2.0

时间:2019-11-26 17:25:54

标签: python c++ tensorflow

在教程here之后,我有两个用C ++编写的自定义Tensorflow操作。我的操作如下导入到python中:

import tensorflow as tf
myops = tf.load_op_library("libmyops.so")
opA = myops.op_a
opB = myops.op_b

使用Tensorflow 1.x,我将按以下方式调用这些操作:

my_opA_instance = opA(arg1=..., arg2=...)
my_opA_instance.run(sesstion=tf.get_session())

我使用run而不是eval,因为这些操作不会输出任何张量。

OpA是一种操作,它具有一些字符串属性和张量列表。它在其C ++文件中注册如下:

REGISTER_OP("OpA")
    .Attr("arg1: string")
    .Attr("arg2: string")
    .Input("tensors: T")
    .Attr("T: list(type)")
    .SetIsStateful()
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
            return Status::OK();
    });

OpB相似,但采用张量引用列表,因此,我没有.Input("tensors: T").Attr("T: list(type)")了:

.Input("tensors: Ref(N * T)")
.Attr("T: type")
.Attr("N: int")

当我尝试在Tensorflow 2.0中使用我的代码时,我的行为发生了以下变化:

get_session()不再存在,显然,急切的执行模式现在使opA(arg1=..., arg2=...)实际上执行了该操作,而不是实例化该操作。无需调用run()。这意味着我无法实例化该操作一次,并且无法通过多次调用run()来重用它。真的不是问题。

更麻烦的是,使用OpB时,调用它时出现以下错误:RuntimeError: op_b op does not support eager execution. Arg 'tensors' is a ref. 我试图在调用它之前禁用急切执行,但这会给我以下错误: RuntimeError: Attempting to capture an EagerTensor without building a function.

我应该如何在Tensorflow 2.0中使用OpB?

0 个答案:

没有答案