我正在用C ++为tensrflow编写新的操作 是Linal.cc:
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include<vector>
using namespace tensorflow;
REGISTER_OP("Linal")
//.Attr("preserve_index: int")
.Input("x: float")
.Input("a: float")
.Output("y: float");
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class Linl: public OpKernel {
public:
explicit Linal(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
//Get x
const Tensor& input_tensor = context->input(0);
//size_t size_of_x = input_tensor.shape().dim_size(1);
auto input = input_tensor.flat<float>();
size_t size_of_x = input.size();
std::vector<float> x;
for(int i = 0; i < size_of_x; ++i){
x.push_back(input(i));
}
//Get a
const Tensor& a_tensor = context->input(1);
auto a_vector = a_tensor.flat<float>();
float a = a_vector(0);
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output_flat = output_tensor->flat<float>();
for(int i = 0; i < size_of_x; ++i){
output_flat(i) = a * x[i];
}
}
};
REGISTER_KERNEL_BUILDER(Name("Linal").Device(DEVICE_CPU), Linal);
我想检查自己的渐变是否像张量流渐变一样工作。 它是testing.py
import numpy as np
import tensorflow as tf
import modul_gradient
from tensorflow.python.framework import ops
@tf.RegisterGradient("Linal")
def modgrad(op, grad):
x = op.inputs[0] # the first argument (normally you need those to calculate the gradient, like the gradient of x^2 is 2x. )
a = op.inputs[1] # the second argument
return grad * a, x * grad
Linal_module = tf.load_op_library('./LUT.so')
with tf.Session() as sess:
#get x and a
x = tf.constant([0.2, 1.2, 6, 7, 8])
a = tf.constant([3.2])
#describe the calculation graph in tf
product = LUT_module.lut(x, a)
product_tf = tf.multiply(x, a)
#calculate gradient for our op and default gradient
gr = tf.gradients(product, [x, a])
gr_tf = tf.gradients(product_tf, [x, a])
tf.initialize_all_variables().run()
#this radients should be equal
print(gr[1].eval())
print(gr_tf[1].eval())
但是我得到了ValueError:运算输入和计算的输入梯度之间的形状不兼容。正向操作:LUT。输入索引:1.原始输入形状:(1,)。计算出的输入梯度形状:(5,)
请你帮我一下。我真的不明白怎么了