我有一个模板函数，只要在其中使用 typename 模板参数（Dtype）就会出错。
我有两个可以调用的函数，分别是：
/// Single-precision AXPY specialization.
/// Forwards directly to CBLAS `cblas_saxpy` (Y := alpha * X + Y over N
/// elements), walking both X and Y with a stride of one element.
template <>
void caffe_axpy<float>(const int N, const float alpha, const float* X,
                       float* Y) {
  const int unit_stride = 1;  // X and Y are traversed contiguously
  cblas_saxpy(N, alpha, X, unit_stride, Y, unit_stride);
}
/// Double-precision AXPY specialization.
/// Forwards directly to CBLAS `cblas_daxpy` (Y := alpha * X + Y over N
/// elements), walking both X and Y with a stride of one element.
template <>
void caffe_axpy<double>(const int N, const double alpha, const double* X,
                        double* Y) {
  const int unit_stride = 1;  // X and Y are traversed contiguously
  cblas_daxpy(N, alpha, X, unit_stride, Y, unit_stride);
}
由于数据类型既可能是 double 也可能是 float，我把调用它们的函数写成了模板：
/// Updates this blob in place from its own data and diff buffers:
///   1. diff += decay * data
///   2. data += lr * diff   (uses the diff produced by step 1)
/// Dispatches on data_->head(): the CPU kernels are used for HEAD_AT_CPU,
/// the GPU kernels for HEAD_AT_GPU and SYNCED; any other head is fatal.
///
/// @param decay  coefficient applied to data before it is added into diff.
/// @param lr     coefficient applied to diff before it is added into data;
///               the update is additive, so a descent step needs lr < 0.
template <typename Dtype>
void Blob<Dtype>::UpdateNew(Dtype decay, Dtype lr) {
  // We will perform update based on where the data is located.
  switch (data_->head()) {
  case SyncedMemory::HEAD_AT_CPU:
    // perform computation on CPU
    // diff += decay * data
    caffe_axpy<Dtype>(count_, decay,
        static_cast<const Dtype*>(data_->cpu_data()),
        static_cast<Dtype*>(diff_->mutable_cpu_data()));
    // data += lr * diff
    caffe_axpy<Dtype>(count_, lr,
        static_cast<const Dtype*>(diff_->cpu_data()),
        static_cast<Dtype*>(data_->mutable_cpu_data()));
    break;
  case SyncedMemory::HEAD_AT_GPU:
  case SyncedMemory::SYNCED:
#ifndef CPU_ONLY
    // perform computation on GPU
    // diff += decay * data
    caffe_gpu_axpy<Dtype>(count_, decay,
        static_cast<const Dtype*>(data_->gpu_data()),
        static_cast<Dtype*>(diff_->mutable_gpu_data()));
    // data += lr * diff; lr stays Dtype (not float) so Blob<double> keeps
    // full precision, matching the other three calls.
    caffe_gpu_axpy<Dtype>(count_, lr,
        static_cast<const Dtype*>(diff_->gpu_data()),
        static_cast<Dtype*>(data_->mutable_gpu_data()));
#else
    NO_GPU;
#endif
    break;
  default:
    LOG(FATAL) << "Syncedmem not initialized.";
  }
}

// Blob is also instantiated for int and unsigned int. Those instantiations are
// what pull in the undefined caffe_axpy<int>, caffe_gpu_axpy<int>,
// caffe_axpy<unsigned int> and caffe_gpu_axpy<unsigned int> symbols, because
// the axpy kernels are only specialized for float and double. UpdateNew is
// meant for floating-point blobs, so the integer instantiations get a body
// that never references the missing kernels.
// NOTE(review): these explicit specializations must appear before the
// explicit instantiations of Blob<int> / Blob<unsigned int> in this
// translation unit (presumably at the end of the file) — confirm. If other
// translation units may call UpdateNew on those types, declare them in the
// header as well.
template <>
void Blob<int>::UpdateNew(int /*decay*/, int /*lr*/) {
  LOG(FATAL) << "Not Implemented Yet";
}

template <>
void Blob<unsigned int>::UpdateNew(unsigned int /*decay*/,
                                   unsigned int /*lr*/) {
  LOG(FATAL) << "Not Implemented Yet";
}
如果把下面这些行中的 Dtype 全部改成 float，就不会出错；如果使用 Dtype，就会出现未定义引用（undefined reference）错误。
// Workaround: every kernel call and every buffer cast is pinned to float.
// NOTE(review): this is only correct when the blob's storage really is float
// (Blob<float>). For any other Dtype the buffers are reinterpreted as float*,
// so the kernels read and write the wrong values.
// CPU path: diff += decay * data
caffe_axpy<float>(count_, static_cast<float>(decay),
    static_cast<const float*>(data_->cpu_data()),
    static_cast<float*>(diff_->mutable_cpu_data()));
// CPU path: this is using diff and lr to update data (data += lr * diff)
caffe_axpy<float>(count_, static_cast<float>(lr),
    static_cast<const float*>(diff_->cpu_data()),
    static_cast<float*>(data_->mutable_cpu_data()));
// GPU path: diff += decay * data
caffe_gpu_axpy<float>(count_, static_cast<float>(decay),
    static_cast<const float*>(data_->gpu_data()),
    static_cast<float*>(diff_->mutable_gpu_data()));
// GPU path: this is using diff and lr to update data (data += lr * diff)
caffe_gpu_axpy<float>(count_, static_cast<float>(lr),
    static_cast<const float*>(diff_->gpu_data()),
    static_cast<float*>(data_->mutable_gpu_data()));
问题可能出在哪里？
编辑：链接（ld）错误如下：
debug_debug/lib/libcaffe.so: undefined reference to `void caffe::caffe_axpy<unsigned int>(int, unsigned int, unsigned int const*, unsigned int*)'
.debug_debug/lib/libcaffe.so: undefined reference to `void caffe::caffe_gpu_axpy<unsigned int>(int, unsigned int, unsigned int const*, unsigned int*)'
.debug_debug/lib/libcaffe.so: undefined reference to `void caffe::caffe_gpu_axpy<int>(int, int, int const*, int*)'
.debug_debug/lib/libcaffe.so: undefined reference to `void caffe::caffe_axpy<int>(int, int, int const*, int*)'
collect2: error: ld returned 1 exit status