假设我有一个像这样的函数:
vector<double> compute(const vector<double>& x, vector<double> dx = boost::none) const{
if (dx)
dx = getDerivative();
// do operations on x and return result
}
,此功能有两个用例,一个用例只需要计算结果,另一个用例需要派生结果和导数,例如:
vector<double> J;
vector<double> y = compute(x);
vector<double> y = compute(x, J);
当我调用第二个版本(传递J)时,尽管dx
在compute
内部被更新,但J的值未更新。我猜这是关于按引用传递的问题,但是当我将函数签名更新为vector<double>& dx
时,出现以下错误:
error: cannot bind non-const lvalue reference of type ‘vector<double>&’ to an rvalue of type ‘vector<double>’
有什么建议可以解决这个问题吗?
答案 0 :(得分:3)
在这里,您可以使用重载代替默认参数。
从您的代码中,最好的方法是:
vector<double> compute(const vector<double>& x, vector<double> &dx) const
{
dx = getDerivative();
// do operations on x and return result
}
vector<double> compute(const vector<double>& x) const
{
vector<double> unused_dx;
compute(x, unused_dx);
}
采用这种方式编写的代码,可以避免不必要的开销,而又不需要派生代码,而仍然可以重复使用通用代码。
在某些其他情况下,如果多余的数据并不难计算,但是您希望在调用者站点使用最简单的代码,则可以采用另一种方法。看起来像:
# Variables (separate version)
W_in = tf.Variable(tf.random_normal([n_input, n_hidden]))
W_rec = tf.Variable(tf.random_normal([n_hidden, n_hidden]))
b_rec = tf.Variable(tf.random_normal([n_hidden]))
W_out = tf.Variable(tf.random_normal([n_hidden, n_output]))
b_out = tf.Variable(tf.random_normal([n_output]))
h_init = tf.zeros([1,n_hidden])
# Manual calculation of RNN output
def RNNoutput(Xinput):
h_state = h_init # initial hidden state
for iX in Xinput:
h_state = tf.nn.tanh(iX @ W_in + (h_state @ W_rec + b_rec))
rnn_output = h_state @ W_out + b_out
return(rnn_output)
但是,即使您不需要向量,在构造向量和计算导数时也会有更多开销。因此,该解决方案仅在多余的代码可忽略不计时才适用。
答案 1 :(得分:0)
vector<double> compute(const vector<double>& x, vector<double> *dx = nullptr) const{
if (dx)
*dx = getDerivative();
// do operations on x and return result
}
vector<double> y = compute(x);
vector<double> J;
vector<double> y = compute(x, &J);
答案 2 :(得分:0)
想要使用指针而不是参考的声音。例如:
vector<double> compute(const vector<double>& x, vector<double> * const dx = nullptr) const{
if (dx != nullptr)
*dx = getDerivative();
// do operations on x and return result
}
从那里,您只需要稍微调整一下通话:
vector<double> J;
vector<double> y = compute(x);
vector<double> y = compute(x, &J);
// ^ Note passing address here.