Does TensorFlow compute gradients of complex64 operations correctly?

Date: 2018-07-07 09:40:51

Tags: tensorflow gradient complex-numbers

I wrote a test case to exercise the op ops::Square and its gradient function SquareGrad. Here is the code:

TEST(MathTest, SingleComplex64Function) {
  // Prepare the input data: a single complex64 value x = 1 + 2i.
  Scope root = Scope::NewRootScope();
  DataType x_type = DataTypeToEnum<complex64>::v();
  TensorShape shape({1});

  Tensor x_data(x_type, shape);
  auto x_data_flat = x_data.flat<complex64>();
  x_data_flat(0) = {1, 2};

  // Construct the graph: y = x * x, where x and y are complex64.
  auto x = ops::Placeholder(root, x_type, ops::Placeholder::Shape(shape));
  auto y = ops::Square(root, x);

  ClientSession session(root);
  std::vector<Tensor> outputs;

  // Run the forward pass and check y = (1 + 2i)^2 = -3 + 4i.
  Status s = session.Run({{x, x_data}}, {y}, &outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_TRUE(outputs.size() == 1);
  ASSERT_TRUE(outputs[0].dtype() == DT_COMPLEX64);
  ASSERT_TRUE(outputs[0].flat<complex64>()(0) == complex64(-3, 4));

  std::vector<Output> grads;
  s = AddSymbolicGradients(root, {y}, {x}, &grads);
  ASSERT_TRUE(s.ok());

  // Run the backward pass: the gradient of y w.r.t. x should be 2x = 2 + 4i.
  std::vector<Tensor> grad_outputs;
  s = session.Run({{x, x_data}}, grads, &grad_outputs);
  ASSERT_TRUE(s.ok());
  ASSERT_TRUE(grad_outputs.size() == 1);
  ASSERT_TRUE(grad_outputs[0].dtype() == DT_COMPLEX64);
  ASSERT_TRUE(grad_outputs[0].flat<complex64>()(0) == complex64(2, 4));
}

When I run it, I get this failure:

[ RUN      ] MathTest.SingleComplex64Function
tensorflow/cc/gradients/math_grad_test.cc:84: Failure
Value of: grad_outputs[0].flat<complex64>()(0) == complex64(2,4)
  Actual: false
  Expected: true
[  FAILED  ] MathTest.SingleComplex64Function (39 ms)

I found that the computed value of grad_outputs[0] is (2,-4), and I don't understand why.
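(For reference, the actual value can be printed just before the failing assertion with a logging line like the one below; this is a hypothetical debugging aid, not part of the original test, relying on std::complex's built-in stream operator:)

LOG(INFO) << "grad_outputs[0] = " << grad_outputs[0].flat<complex64>()(0);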

As far as I understand, I have defined a function C -> C: y = x * x (x and y are complex64). Its first derivative is dy/dx = 2x, and by the definition of the gradient, grad y = [dy/dx]. So if x = 1 + 2i, then y = (1 + 2i)(1 + 2i) = 1 + 4i + 4i^2 = 1 + 4i - 4 = -3 + 4i, and dy/dx = 2(1 + 2i) = 2 + 4i, so in theory the gradient should be [(2, 4)].
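As a quick sanity check on this arithmetic, here is a minimal standalone sketch using std::complex<float> (the type behind complex64), independent of TensorFlow:

#include <complex>
#include <iostream>

int main() {
  std::complex<float> x(1, 2);           // x = 1 + 2i
  std::complex<float> y = x * x;         // (1 + 2i)^2 = -3 + 4i
  std::complex<float> dydx = 2.0f * x;   // 2x = 2 + 4i
  std::cout << "y = " << y << ", dy/dx = " << dydx << "\n";
  // Prints: y = (-3,4), dy/dx = (2,4)
  return 0;
}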

Is my reasoning correct?

When I started looking into the problem, I found the code of the function SquareGrad, shown below:

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
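Tracing this formula by hand with the test's input x = 1 + 2i and the seed gradient grad(y) = 1 reproduces exactly the value the test observed; here is a minimal sketch of the same computation in plain C++ (my own illustration, not TensorFlow code):

#include <complex>
#include <iostream>

int main() {
  std::complex<float> x(1, 2);
  std::complex<float> grad_y(1, 0);      // seed gradient for the single output y
  std::complex<float> dydx = 2.0f * x;   // dy/dx = 2 + 4i
  std::complex<float> grad_x = grad_y * std::conj(dydx);
  std::cout << "grad_x = " << grad_x << "\n";
  // Prints: grad_x = (2,-4), matching the observed test output
  return 0;
}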

conj(dy/dx) is the complex conjugate of dy/dx, right?

Can someone explain why grad(x) = grad(y) * conj(dy/dx)?

0 Answers:

No answers yet.