通用MPI代码

时间:2017-02-27 16:04:24

标签: c++ generics mpi

我想创建一个通用的MPI方法,让我们说一个特定对象的bcast。但我需要将原始类型转换为MPI_Data类型? 知道怎么做吗?

template <typename T>
void bcast_data(std::vector<T> vec) 
{
...
}

我需要将MPI_INT用于int,将MPI_DOUBLE用于double,... 所以我需要一个类型转换方法,我想创建一个可以给我MPI_datatypes的数据表枚举,但它需要将类型作为输入参数传递。

任何想法?

由于

4 个答案:

答案 0 :(得分:2)

您可以使用“type traits”惯用法来序列化通用对象T。这为您提供了在不更改实现的情况下添加对新类型的支持的优势。

看看我多年前写的这个MPI包装器:https://github.com/motonacciu/mpp

您想要定义类似以下的类型特征:

template <class T>
struct mpi_type_traits {
    typedef T element_type;
    typedef T* element_addr_type;

    static inline MPI_Datatype get_type(T&& raw);
    static inline size_t get_size(T& raw);
    static inline element_addr_type get_addr(T& raw);
};

并提供具体类型的专业化,例如std::vector<T>如下:

template <class T>
struct mpi_type_traits<std::vector<T>> {

    typedef T element_type;
    typedef T* element_addr_type;

    static inline size_t get_size(std::vector<T>& vec) {
       return vec.size();
    }

    static inline MPI_Datatype get_type(std::vector<T>&& vec) {
        return mpi_type_traits<T>::get_type( T{} );
    }

    static inline element_addr_type get_addr(std::vector<T>& vec) {
        return mpi_type_traits<T>::get_addr( vec.front() );
    }
};

您需要做的最后一件事是实现您的MPI方法并使用类型特征,例如在致电MPI_Send

template <class T>
void send(T &&value, ...) {
   MPI_Send(mpi_type_traits<T>::get_addr(value),
            mpi_type_traits<T>::get_size(value),
            mpi_type_traits<T>::get_type(value), ...);
}

答案 1 :(得分:2)

我认为Boost feature get_mpi_datatype应该提供此功能。

类似地,可以使用constexpr的{​​{1}}函数将apramc's idea扩展到所有当前的MPI data types,使得相应的MPI数据类型为already evaluated at compile time,如下所示( click here for the Gist,需要C ++ 17)

type_traits

您可以先调用MPI命令,例如

#include <cassert>
#include <complex>
#include <type_traits>

#include <mpi.h>


template<typename T>
constexpr MPI_Datatype mpi_get_type()
{
    MPI_Datatype mpi_type = MPI_DATATYPE_NULL;
    
    if constexpr (std::is_same<T, char>::value)
    {
        mpi_type = MPI_CHAR;
    }
    else if constexpr (std::is_same<T, signed char>::value)
    {
        mpi_type = MPI_SIGNED_CHAR;
    }
    else if constexpr (std::is_same<T, unsigned char>::value)
    {
        mpi_type = MPI_UNSIGNED_CHAR;
    }
    else if constexpr (std::is_same<T, wchar_t>::value)
    {
        mpi_type = MPI_WCHAR;
    }
    else if constexpr (std::is_same<T, signed short>::value)
    {
        mpi_type = MPI_SHORT;
    }
    else if constexpr (std::is_same<T, unsigned short>::value)
    {
        mpi_type = MPI_UNSIGNED_SHORT;
    }
    else if constexpr (std::is_same<T, signed int>::value)
    {
        mpi_type = MPI_INT;
    }
    else if constexpr (std::is_same<T, unsigned int>::value)
    {
        mpi_type = MPI_UNSIGNED;
    }
    else if constexpr (std::is_same<T, signed long int>::value)
    {
        mpi_type = MPI_LONG;
    }
    else if constexpr (std::is_same<T, unsigned long int>::value)
    {
        mpi_type = MPI_UNSIGNED_LONG;
    }
    else if constexpr (std::is_same<T, signed long long int>::value)
    {
        mpi_type = MPI_LONG_LONG;
    }
    else if constexpr (std::is_same<T, unsigned long long int>::value)
    {
        mpi_type = MPI_UNSIGNED_LONG_LONG;
    }
    else if constexpr (std::is_same<T, float>::value)
    {
        mpi_type = MPI_FLOAT;
    }
    else if constexpr (std::is_same<T, double>::value)
    {
        mpi_type = MPI_DOUBLE;
    }
    else if constexpr (std::is_same<T, long double>::value)
    {
        mpi_type = MPI_LONG_DOUBLE;
    }
    else if constexpr (std::is_same<T, int8_t>::value)
    {
        mpi_type = MPI_INT8_T;
    }
    else if constexpr (std::is_same<T, int16_t>::value)
    {
        mpi_type = MPI_INT16_T;
    }
    else if constexpr (std::is_same<T, int32_t>::value)
    {
        mpi_type = MPI_INT32_T;
    }
    else if constexpr (std::is_same<T, int64_t>::value)
    {
        mpi_type = MPI_INT64_T;
    }
    else if constexpr (std::is_same<T, uint8_t>::value)
    {
        mpi_type = MPI_UINT8_T;
    }
    else if constexpr (std::is_same<T, uint16_t>::value)
    {
        mpi_type = MPI_UINT16_T;
    }
    else if constexpr (std::is_same<T, uint32_t>::value)
    {
        mpi_type = MPI_UINT32_T;
    }
    else if constexpr (std::is_same<T, uint64_t>::value)
    {
        mpi_type = MPI_UINT64_T;
    }
    else if constexpr (std::is_same<T, bool>::value)
    {
        mpi_type = MPI_C_BOOL;
    }
    else if constexpr (std::is_same<T, std::complex<float>>::value)
    {
        mpi_type = MPI_C_COMPLEX;
    }
    else if constexpr (std::is_same<T, std::complex<double>>::value)
    {
        mpi_type = MPI_C_DOUBLE_COMPLEX;
    }
    else if constexpr (std::is_same<T, std::complex<long double>>::value)
    {
        mpi_type = MPI_C_LONG_DOUBLE_COMPLEX;
    }
    
    assert(mpi_type != MPI_DATATYPE_NULL);
    return mpi_type;    
}

答案 2 :(得分:2)

我曾经想出一个与已经显示的解决方案非常相似的解决方案,但是有一点优势。这个想法是使用模板化的函数并添加特殊化以将类型解析为相应的MPI类型:

namespace mpiUtil { // Namespace for convenience

    template <typename T>
    MPI_Datatype resolveType();

    template <>
    MPI_Datatype resolveType<double>()
    {
        return MPI_DOUBLE;
    }

    // ... add a specialization for all other types

    template <typename T>
    int autoSend(const T *items, int count, const int dest, const int tag, 
                 const MPI_Comm comm)
    {
        return MPI_Send(items, count, resolveType<T>(), 0, tag, comm);
    }
    
    // You can repeat this procedure for other MPI functions
}

主要优点是它对于自定义类型很容易扩展。例如,如果您有一个class和一个要传送的所说class数组,则可以add a custom type to the MPI system并为此resolveType()注入一个特殊化的class。进入mpiUtil命名空间。

答案 3 :(得分:1)

我使用了这样的东西,它绝对不是一个完整的答案,因为它留下了一些类型。但它适用于我的情况

template<typename T>
MPI_Datatype get_type()
{
    char name = typeid(T).name()[0];
    switch (name) {
        case 'i':
            return MPI_INT;
            break;
        case 'f':
            return MPI_FLOAT;
            break;
        case 'j':
            return MPI_UNSIGNED;
            break;
        case 'd':
            return MPI_DOUBLE;
            break;
        case 'c':
            return MPI_CHAR;
            break;
        case 's':
            return MPI_SHORT;
            break;
        case 'l':
            return MPI_LONG;
            break;
        case 'm':
            return MPI_UNSIGNED_LONG;
            break;
        case 'b':
            return MPI_BYTE;
            break;
    }
}