I wrote a test case for the op ops::Square and its gradient function SquareGrad. Here is the code:
TEST(MathTest, SingleComplex64Function) {
  // Prepare the input data.
  Scope root = Scope::NewRootScope();
  DataType x_type = DataTypeToEnum<complex64>::v();
  TensorShape shape({1});
  Tensor x_data(x_type, shape);
  auto x_data_flat = x_data.flat<complex64>();
  x_data_flat(0) = {1, 2};

  // Construct the graph: y = x * x, where x and y are complex64.
  auto x = ops::Placeholder(root, x_type, Placeholder::Shape(shape));
  auto y = ops::Square(root, x);
  ClientSession session(root);
  std::vector<Tensor> outputs;

  // Compute y.
  Status s = session.Run({{x, x_data}}, {y}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_TRUE(outputs.size() == 1);
  ASSERT_TRUE(outputs[0].dtype() == DT_COMPLEX64);
  ASSERT_TRUE(outputs[0].flat<complex64>()(0) == complex64(-3, 4));

  std::vector<Output> grads;
  s = AddSymbolicGradients(root, {y}, {x}, &grads);
  ASSERT_TRUE(s.ok());

  // Compute the gradient of y w.r.t. x.
  std::vector<Tensor> grad_outputs;
  s = session.Run({{x, x_data}}, grads, &grad_outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_TRUE(grad_outputs.size() == 1);
  ASSERT_TRUE(grad_outputs[0].dtype() == DT_COMPLEX64);
  ASSERT_TRUE(grad_outputs[0].flat<complex64>()(0) == complex64(2, 4));
}
When I run it, I get this failure:
[ RUN ] MathTest.SingleComplex64Function
tensorflow/cc/gradients/math_grad_test.cc:84: Failure
Value of: grad_outputs[0].flat<complex64>()(0) == complex64(2,4)
Actual: false
Expected: true
[ FAILED ] MathTest.SingleComplex64Function (39 ms)
I found that the computed value of grad_outputs[0] is (2,-4), and I don't understand why.
As far as I understand, I have defined a function C → C: y = x·x (x and y are complex64). Its first derivative should be dy/dx = 2x, and by the definition of the gradient, grad y = [dy/dx]. So if x = 1 + 2i, then y = (1 + 2i)(1 + 2i) = 1 + 4i + 4i·i = 1 + 4i − 4 = −3 + 4i, and dy/dx = 2(1 + 2i) = 2 + 4i, so in theory the gradient should be [(2,4)].
Is my reasoning correct so far?
When I started digging into the problem, I found the SquareGrad function, which looks like this:
Status SquareGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
// dy/dx = (2 * x)
auto two = Cast(scope, Const(scope, 2), op.input(0).type());
auto dydx = Mul(scope, two, op.input(0));
// grad(x) = grad(y) * conj(dy/dx)
grad_outputs->push_back(
Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
return scope.status();
}
conj(dy/dx) is the complex conjugate of dy/dx, right?
Can someone explain why grad(x) = grad(y) * conj(dy/dx)?