自定义Tensorflow操作采用可变输入张量的列表

时间:2019-05-24 15:32:42

标签: python c++ tensorflow

我正在尝试使用C ++编写自定义Tensorflow操作。此操作应将张量列表作为输入并修改其内容。我以为使用Assign操作的示例,该示例在Tensorflow代码中注册如下:

REGISTER_OP("Assign")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    ...

作为参考,Assign操作(input(0))的ref是要分配给的张量,input(1)value)是其新值。输出张量(output_ref)只是对传播的input(0)的引用。

在其定义中,Assign操作还具有以下代码来检查第一个输入是可变张量:

OP_REQUIRES(context, IsRefType(context->input_type(0)),
errors::InvalidArgument("lhs input needs to be a ref type"));

与Assign操作相反,我的自定义操作应使用一个可变张量列表(而不是单个张量),其内容将被该操作修改。

我尝试通过以下方式注册我的操作:

REGISTER_OP("MyCustomOperation")
    .Input("refs: list(Ref(T))")
    .Attr("T: type")
    ...

但是在加载库时,Tensorflow给我以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Reference to unknown attr 'list' from Input("refs: list(Ref(T))") for Op MyCustomOperation

我也尝试使用属性refs: list(T)T: Ref(type),但这也不起作用(Tensorflow会显示错误Trouble parsing type string at 'Ref(type)' from Attr("T: Ref(type)"))。

所以我切换到以下注册:

REGISTER_OP("MyCustomOperation")
    .Input("refs: list(Ref(T))")
    .Attr("T: type")
    ...

但是,使用此定义,IsRefType断言失败。请注意,我在Python级别传递了tf.RefVariable的列表,我认为这是可变的。

如何使我的操作正确预期可变张量列表?

1 个答案:

答案 0 :(得分:0)

经过一番调查,我找到了一个执行此操作的示例。这是传递可变张量列表的解决方案:

REGISTER_OP("MyCustomOperation")
    .Input("refs: Ref(N * T)")
    .Attr("T: type")
    .Attr("N: int")
    ...