在下面的代码中,我想使用参数来保存两种不同的返回类型。所以我可以删除多余的代码。但是,我这里没有很好的解决方案。
我的版本:
if (...) {
auto params = gather_quantized_params(_params);
// the following lines are just duplicated in different branches
auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
return results;
} else {
auto params = gather_quantized_params_fp16(_params);
auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
return results
}
=== 相关功能的标题:
static std::vector<QuantizedCellParamsFP16>
gather_quantized_params_fp16(TensorList params) {
...}
static std::vector<QuantizedCellParams>
gather_quantized_params(TensorList params) {
...}
template<template<typename,typename> class LayerT,
template<typename,typename> class BidirLayerT,
typename cell_params, typename io_type>
std::tuple<io_type, Tensor, Tensor> _lstm_impl(
const io_type& input,
const std::vector<cell_params>& params, const Tensor& hx, const Tensor& cx,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) { ...}
=== 当我使用答案中建议的方法(这真的很酷)时,遇到以下错误-“错误:在lambda参数声明中使用'auto'仅适用于-std = c ++ 14或-std = gnu ++ 14”。
似乎我需要另一种解决方案,避免在lambda参数中使用auto。
答案 0 :(得分:7)
我建议这样做:
auto implement_params = [&](auto params) {
auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
return results;
};
if(...) {
return implement_params(gather_quantized_params(_params));
} else {
return implement_params(gather_quantized_params_fp16(_params));
}