我正在尝试使用C ++编写自定义Tensorflow操作。此操作应将张量列表作为输入并修改其内容。我以为使用Assign操作的示例,该示例在Tensorflow代码中注册如下:
REGISTER_OP("Assign")
.Input("ref: Ref(T)")
.Input("value: T")
.Output("output_ref: Ref(T)")
.Attr("T: type")
...
作为参考,Assign操作(input(0)
)的ref
是要分配给的张量,input(1)
(value
)是其新值。输出张量(output_ref
)只是对传播的input(0)
的引用。
在其定义中,Assign操作还具有以下代码来检查第一个输入是可变张量:
OP_REQUIRES(context, IsRefType(context->input_type(0)),
errors::InvalidArgument("lhs input needs to be a ref type"));
与Assign操作相反,我的自定义操作应使用一个可变张量列表(而不是单个张量),其内容将被该操作修改。
我尝试通过以下方式注册我的操作:
REGISTER_OP("MyCustomOperation")
.Input("refs: list(Ref(T))")
.Attr("T: type")
...
但是在加载库时,Tensorflow给我以下错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Reference to unknown attr 'list' from Input("refs: list(Ref(T))") for Op MyCustomOperation
我也尝试使用属性refs: list(T)
来T: Ref(type)
,但这也不起作用(Tensorflow会显示错误Trouble parsing type string at 'Ref(type)' from Attr("T: Ref(type)")
)。
所以我切换到以下注册:
REGISTER_OP("MyCustomOperation")
.Input("refs: list(Ref(T))")
.Attr("T: type")
...
但是,使用此定义,IsRefType
断言失败。请注意,我在Python级别传递了tf.RefVariable
的列表,我认为这是可变的。
如何使我的操作正确预期可变张量列表?
答案 0 :(得分:0)
经过一番调查,我找到了一个执行此操作的示例。这是传递可变张量列表的解决方案:
REGISTER_OP("MyCustomOperation")
.Input("refs: Ref(N * T)")
.Attr("T: type")
.Attr("N: int")
...