My team is developing a new backend for TensorFlow. Normally, TensorFlow OpKernels are passed `Tensor` arguments whose memory was allocated by our architecture:
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
  // Grab the input tensors
  const Tensor& A = context->input(0);
  const Tensor& B = context->input(1);
  // ...input validation...
  const our::Memory_Type *bA = static_cast<const our::Memory_Type *>(DMAHelper::base(&A));
  const our::Memory_Type *bB = static_cast<const our::Memory_Type *>(DMAHelper::base(&B));
  // ...additional preconditioning...
  // Create an output tensor
  Tensor *C = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &C));
  our::Memory_Type *bC = static_cast<our::Memory_Type *>(DMAHelper::base(C));
  // ...and run it
  our_impl(bA, bB, bC, done);
}
However, we've run into trouble porting the "CrossOp" kernel, because part of its preprocessing converts the data to Eigen types:
// in0, in1, and output are all tensorflow::Tensor types, but ConstTensor is an Eigen type
typename TTypes<Type, 2>::ConstTensor in0_data = in0.flat_inner_dims<Type>();
typename TTypes<Type, 2>::ConstTensor in1_data = in1.flat_inner_dims<Type>();
typename TTypes<Type, 2>::Tensor output_data = output->flat_inner_dims<Type>();
`DMAHelper::base()` assumes it operates on a `Tensor` rather than a `ConstTensor`. Is it safe to follow the code above with the calls below, or does `flat_inner_dims()` change the underlying data in a way that makes the result invalid or unreadable by TensorFlow?
const our::Memory_Type *in0_arg = static_cast<const our::Memory_Type *>(DMAHelper::base(&in0));
const our::Memory_Type *in1_arg = static_cast<const our::Memory_Type *>(DMAHelper::base(&in1));
our::Memory_Type *output_arg = static_cast<our::Memory_Type *>(DMAHelper::base(output));
our_cross_impl(in0_arg, in1_arg, output_arg);