检查Tensorflow - C ++ API中的广播兼容性

时间:2016-02-06 04:42:56

标签: tensorflow

我在TensorFlow中实现了元素操作。许多TensorFlow操作,例如添加,支持广播(from numpy)。如果遵守以下规则,则可以进行广播:

  

当在两个张量上操作时,它们的形状应该在元素方面进行比较。该过程从尾随维度开始,并向前发展。当两个维度相等或者其中一个为1时,它们是兼容的。如果不满足这些条件,则抛出异常,表明张量具有不兼容的形状。结果张量的大小是输入数组每个维度的最大大小。

TensorFlow C ++ API 是否提供了比较两个张量的兼容性的方法?或者,这是最快的方法吗?

1 个答案:

答案 0 :(得分:1)

所有元素二进制运算' TensorFlow中的内核实现派生自BinaryOpShared类,它通过辅助类BinaryOpState进行兼容性检查。也许,您可以简单地从BinaryOpShared派生您的内核类,并免费获得兼容性检查。