Tensorflow tfcompile:获取渐变

时间:2017-07-27 10:01:16

标签: tensorflow

我创建了一个非常简单的张量流模型,我在其中获取渐变:

# tf Graph Input
X = tf.placeholder(tf.float32, [1, 2], name="X")
Y = tf.placeholder(tf.float32, [1, 2], name="Y")

# Model parameter variables
W = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="weight")
B = tf.Variable([[5.0, 6.0]], name="bias")

# Construct a multivariate linear model
matmul = tf.matmul(X, W, name="matrixMul")
pred = tf.add(matmul, B, name="addition")

# Mean squared error
cost = tf.reduce_sum(tf.pow(pred-Y, 2) / 2 )

# Fetch gradients
grads = tf.gradients(cost, [W, B])

我将此图表导出到protobuf中,现在我使用tfcompile进行AOT编译。我想在C ++程序中使用已编译的图形并获取计算的渐变。 tfcompile的配置文件如下所示:

feed {
  id { node_name: "X" }
      shape {
        dim { size: 1 }
        dim { size: 2 }  
      }
  name: "x"
}
feed {
  id { node_name: "Y" }
  shape {
    dim { size: 1 }
    dim { size: 2 }
  }
  name: "y"
}
feed {
  id { node_name: "weight" }
  shape {
    dim { size: 2 }
    dim { size: 2 }
  }
  name: "w"
}
feed {
  id { node_name: "bias" }
  shape {
    dim { size: 1 }       
    dim { size: 2 }
  }
  name: "b"
}
fetch {
   id { node_name:   "addition"}
   name: "prediction"
}
fetch {
  id { node_name:   "gradients/matrixMul_grad/MatMul_1"}
  name: "weight_grad"
}
fetch {
  id { node_name:   "gradients/addition_grad/Reshape"}
  name: "bias_grad"
}

最后我运行这个C ++代码:

obj.set_arg_x_data(x.data());
obj.set_arg_y_data(y.data());
obj.set_arg_w_data(w.data());
obj.set_arg_b_data(b.data());

obj.Run();

std::cout << "result_prediction =" << std::endl  << obj.result_prediction(0,0) << " " << obj.result_prediction(0,1) << std::endl;
std::cout << "result_weight_grad =" << std::endl << obj.result_weight_grad(0,0) << " " << obj.result_weight_grad(0,1) << " " << obj.result_weight_grad(1,0) << " " << obj.result_weight_grad(1,1) << std::endl;
std::cout << "result_bias_grad =" << std::endl  <<  obj.result_bias_grad(0,0) << " " <<  obj.result_bias_grad(0,1) << std::endl;

对于result_predictionresult_bias_grad,我获得了预期值。 仅针对result_weight_grad我只得到0,0,0,0。

也许我在那里找错了节点:

fetch {
  id { node_name:   "gradients/matrixMul_grad/MatMul_1"}
  name: "weight_grad"
}

有人尝试过取得计算的渐变吗? Tensorflow仅提供使用tfcompile进行预测的示例。

0 个答案:

没有答案