在Tensorflow中添加新的Op:形状函数

时间:2017-01-19 07:55:23

标签: tensorflow

我正在尝试添加新操作(使用此操作方法:https://www.tensorflow.org/how_tos/adding_an_op/)。我在示例中使用简单的Op没有问题,但我需要添加更复杂的Op。它应该有2个输入,输出与输入的矩阵乘法具有相同的形状。 如何为这种情况编写形状函数? 如何分配具有适当形状的输出? 提前谢谢。

1 个答案:

答案 0 :(得分:0)

对问题有部分答案。 我仍然不知道如何编写形状函数,但我写了这样的检查:

const Tensor& input1 = context->input(0);
const Tensor& input2 = context->input(1);
TensorShape sh1 = input1.shape();
TensorShape sh2 = input2.shape();
OP_REQUIRES(context, sh1.dim_size(1)==sh2.dim_size(0),
 errors::InvalidArgument("Can't multiplicate!"));

清分:

sh1.RemoveDim(1);
sh2.RemoveDim(0);
sh1.AppendShape(sh2);
OP_REQUIRES_OK(context, context->allocate_output(0, sh1, &output_tensor));

但好像我正在重新发明一辆自行车。这样做更容易吗?