我正在使用本教程编写自定义Tensorflow操作系统,而且我无法理解如何读取和写入Tensors。
让我说我的OpKernel中有一个Tensor
const Tensor& values_tensor = context->input(0);
(其中context = OpKernelConstruction*
)
如果Tensor有形状,比如说[2,10,20],我怎样才能将其编入索引(例如auto x = values_tensor[1, 4, 12]
等)?
等价,如果我有
Tensor *output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0,
{batch_size, value_len - window_size, window_size},
&output_tensor
));
如何分配到output_tensor
,如output_tensor[1, 2, 3] = 11
等?
对不起这个愚蠢的问题,但是文档真的让我吵架了,内置操作的Tensorflow内核代码中的示例以某种方式混淆了这一点,我感到非常困惑:)
谢谢你!答案 0 :(得分:1)
读取和写入tensorflow::Tensor
个对象的最简单方法是使用Eigen tensor方法将它们转换为tensorflow::Tensor::tensor<T, NDIMS>()
。请注意,您必须将张量中的(C ++)元素类型指定为模板参数T
。
例如,要从DT_FLOAT32
张量中读取特定值:
const Tensor& values_tensor = context->input(0);
auto x = value_tensor.tensor<float, 3>()(1, 4, 12);
要将特定值写入DT_FLOAT32
张量:
Tensor* output_tensor = ...;
output_tensor->tensor<float, 3>()(1, 2, 3) = 11.0;