ParGeMSLR
parallel.hpp
1 #ifndef PARGEMSLR_PARALLEL_H
2 #define PARGEMSLR_PARALLEL_H
3 
9 #include <assert.h>
10 #include <mpi.h>
11 #include <vector>
12 #ifdef PARGEMSLR_OPENMP
13 #include "omp.h"
14 #endif
15 #ifdef PARGEMSLR_CUDA
16 #include <cuda_runtime.h>
17 #include <curand.h>
18 #include "cublas_v2.h"
19 #include "cusparse.h"
20 #endif
21 
22 #include "utils.hpp"
23 
24 using namespace std;
25 
26 /*- - - - - - - - - Timing information */
27 
28 #define PARGEMSLR_TIMES_NUM 60
29 
30 #define PARGEMSLR_BUILDTIME_PARTITION 0 // time for the entire partitioning
31 #define PARGEMSLR_BUILDTIME_IE 1 // time for splitting the interior and external nodes
32 #define PARGEMSLR_BUILDTIME_METIS 2 // time for calling METIS
33 #define PARGEMSLR_BUILDTIME_STRUCTURE 3 // time for building the structure once the number of domains is known
34 #define PARGEMSLR_BUILDTIME_RCM 4 // time for applying the RCM ordering
35 #define PARGEMSLR_BUILDTIME_ILUT 5 // time for the ILUT factorization
36 #define PARGEMSLR_BUILDTIME_LRC 6 // time for building the low-rank correction
37 #define PARGEMSLR_BUILDTIME_ARNOLDI 7 // time for the standard Arnoldi
38 #define PARGEMSLR_BUILDTIME_BUILD_RES 8 // time for building the result of the Arnoldi process
39 #define PARGEMSLR_BUILDTIME_SOLVELU 9 // time for the ILU solve in the setup phase
40 #define PARGEMSLR_BUILDTIME_SOLVELU_L 10 // time for the ILU solve in the setup phase on the last level
41 #define PARGEMSLR_BUILDTIME_SOLVELR 11 // time for applying the low-rank correction in the setup phase
42 #define PARGEMSLR_BUILDTIME_SOLVEEBFC 12 // time for the solve with EB^{-1}FC^{-1}
43 #define PARGEMSLR_BUILDTIME_EXTRACTMAT 13 // time for extracting E, B, F, and C on the first level
44 #define PARGEMSLR_BUILDTIME_MOVEDATA 14 // time for moving data between levels
45 #define PARGEMSLR_BUILDTIME_LOCALPERM 15 // time for local permutation
46 #define PARGEMSLR_BUILDTIME_EMV 16 // time for matvec with E on all levels
47 #define PARGEMSLR_BUILDTIME_FMV 17 // time for matvec with F on all levels
48 #define PARGEMSLR_BUILDTIME_GEN_MAT 18 // time for generating the matrix
49 #define PARGEMSLR_BUILDTIME_DECOMP 19 // time for LAPACK decompositions (Hessenberg, Schur, eig, ordschur, ...)
50 #define PARGEMSLR_BUILDTIME_MGS 20 // time for MGS in the Arnoldi
51 #define PARGEMSLR_BUILDTIME_EBFC 21 // time for EBFC in the Arnoldi
52 #define PARGEMSLR_PRECTIME_PRECOND 30 // time for applying the preconditioner
53 #define PARGEMSLR_PRECTIME_ILUT 31 // time for the ILU solve in the solve phase
54 #define PARGEMSLR_PRECTIME_ILUT_L 32 // time for the ILU solve in the solve phase on the last level
55 #define PARGEMSLR_PRECTIME_LRC 33 // time for applying the low-rank correction
56 #define PARGEMSLR_PRECTIME_EMV 34 // time for matvec with E on all levels
57 #define PARGEMSLR_PRECTIME_FMV 35 // time for matvec with F on all levels
58 #define PARGEMSLR_PRECTIME_INNER 36 // time for the inner iteration
59 #define PARGEMSLR_PRECTIME_MOVEDATA 37 // time for moving data between levels
60 #define PARGEMSLR_PRECTIME_LOCALPERM 38 // time for local permutation
61 #define PARGEMSLR_ITERTIME_AMV 40 // time for matvec with A
62 #define PARGEMSLR_ITERTIME_MGS 41 // time for MGS during solve
63 #define PARGEMSLR_TOTAL_GEN_MAT_TIME 50 // time for transferring data to the device
64 #define PARGEMSLR_TOTAL_SETUP_TIME 51 // time for the setup phase
65 #define PARGEMSLR_TOTAL_SOLVE_TIME 52 // time for the solve phase
66 #define PARGEMSLR_BUILDTIME_BMV 53 // time for matvec with B on all levels
67 #define PARGEMSLR_BUILDTIME_CMV 54 // time for matvec with C on all levels
68 #define PARGEMSLR_BUILDTIME_SMV 55 // time for matvec with S on all levels
69 #define PARGEMSLR_PRECTIME_BMV 56 // time for matvec with B on all levels
70 #define PARGEMSLR_PRECTIME_CMV 57 // time for matvec with C on all levels
71 #define PARGEMSLR_PRECTIME_SMV 58 // time for matvec with S on all levels
72 
73 #define PARGEMSLR_GLOBAL_FIRM_TIME_CALL(num, ...) {\
74  pargemslr::PargemslrMpiTime( (*(pargemslr::ParallelLogClass::_gcomm)), pargemslr::ParallelLogClass::_times_buffer_start[num]);\
75  (__VA_ARGS__);\
76  pargemslr::PargemslrMpiTime( (*(pargemslr::ParallelLogClass::_gcomm)), pargemslr::ParallelLogClass::_times_buffer_end[num]);\
77  pargemslr::ParallelLogClass::_times[num] += pargemslr::ParallelLogClass::_times_buffer_end[num] - pargemslr::ParallelLogClass::_times_buffer_start[num];\
78 }
79 
80 #define PARGEMSLR_LOCAL_FIRM_TIME_CALL(num, ...) {\
81  PARGEMSLR_CUDA_SYNCHRONIZE;\
82  pargemslr::ParallelLogClass::_times_buffer_start[num] = MPI_Wtime();\
83  (__VA_ARGS__);\
84  PARGEMSLR_CUDA_SYNCHRONIZE;\
85  pargemslr::ParallelLogClass::_times_buffer_end[num] = MPI_Wtime();\
86  pargemslr::ParallelLogClass::_times[num] += pargemslr::ParallelLogClass::_times_buffer_end[num] - pargemslr::ParallelLogClass::_times_buffer_start[num];\
87 }
88 
89 #define PARGEMSLR_FIRM_TIME_CALL(comm,num, ...) {\
90  pargemslr::PargemslrMpiTime( (comm), pargemslr::ParallelLogClass::_times_buffer_start[num]);\
91  (__VA_ARGS__);\
92  pargemslr::PargemslrMpiTime( (comm), pargemslr::ParallelLogClass::_times_buffer_end[num]);\
93  pargemslr::ParallelLogClass::_times[num] += pargemslr::ParallelLogClass::_times_buffer_end[num] - pargemslr::ParallelLogClass::_times_buffer_start[num];\
94 }
95 
96 #ifdef PARGEMSLR_TIMING
97 
98 #define PARGEMSLR_PRINT_TIMING_RESULT(print_level, ...) {\
99  if(__VA_ARGS__)\
100  {\
101  PARGEMSLR_PRINT("\n");\
102  PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);\
103  PARGEMSLR_PRINT("Time info:\n");\
104  PARGEMSLR_PRINT("\tLoad matrix time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_GEN_MAT_TIME]);\
105  PARGEMSLR_PRINT("\tPartition time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_PARTITION]+pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_STRUCTURE]);\
106  PARGEMSLR_PRINT("\tSetup time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SETUP_TIME]-pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_PARTITION]-pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_STRUCTURE]);\
107  PARGEMSLR_PRINT("\tSolve time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SOLVE_TIME]);\
108  PARGEMSLR_PRINT("\tTotal time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SETUP_TIME]+pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SOLVE_TIME]);\
109  PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);\
110  PARGEMSLR_PRINT("\n");\
111  if(print_level > 0)\
112  {\
113  PARGEMSLR_PRINT("\n");\
114  PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);\
115  PARGEMSLR_PRINT("Time detail:\n");\
116  PARGEMSLR_PRINT("\tMatvec with A time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_ITERTIME_AMV]);\
117  PARGEMSLR_PRINT("\tPrecond setup time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SETUP_TIME]);\
118  PARGEMSLR_PRINT("\t-GeMSLR reordering time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_PARTITION]);\
119  PARGEMSLR_PRINT("\t-GeMSLR Setup Structure time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_STRUCTURE]);\
120  PARGEMSLR_PRINT("\t-GeMSLR ILU setup time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_ILUT]);\
121  PARGEMSLR_PRINT("\t--GeMSLR ILU reordering time: %fs - (note: this is the time on p0.)\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_RCM]);\
122  PARGEMSLR_PRINT("\t-GeMSLR low-rank setup time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_LRC]);\
123  PARGEMSLR_PRINT("\t--GeMSLR arnoldi iter time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_ARNOLDI]);\
124  PARGEMSLR_PRINT("\t---GeMSLR MGS time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_MGS]);\
125  PARGEMSLR_PRINT("\t---GeMSLR EB^{-1}FC^{-1} time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_EBFC]);\
126  PARGEMSLR_PRINT("\t---GeMSLR setup ILU solve time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_SOLVELU]);\
127  PARGEMSLR_PRINT("\t---GeMSLR setup ILU solve last lev: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_SOLVELU_L]);\
128  PARGEMSLR_PRINT("\t---GeMSLR setup LRC apply time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_SOLVELR]);\
129  PARGEMSLR_PRINT("\t---GeMSLR setup sparse matvec time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_EMV]+pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_FMV]);\
130  PARGEMSLR_PRINT("\t--GeMSLR build result time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_BUILD_RES]);\
131  PARGEMSLR_PRINT("\t--GeMSLR Lapack Dcomp time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_DECOMP]);\
132  PARGEMSLR_PRINT("\tPrecond applying time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_PRECOND]);\
133  PARGEMSLR_PRINT("\t-GeMSLR ILU solve time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_ILUT]);\
134  PARGEMSLR_PRINT("\t-GeMSLR ILU solve last lev: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_ILUT_L]);\
135  PARGEMSLR_PRINT("\t-GeMSLR sparse matvec time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_EMV]+pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_FMV]\
136  +pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_BMV]+pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_SMV]+pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_CMV]);\
137  PARGEMSLR_PRINT("\t-GeMSLR LRC apply time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_PRECTIME_LRC]);\
138  PARGEMSLR_PRINT("\tIterative solve MGS time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_ITERTIME_MGS]);\
139  PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);\
140  PARGEMSLR_PRINT("\n");\
141  }\
142  }\
143 }
144 
145 #define PARGEMSLR_GLOBAL_TIME_CALL(num, ...) {\
146  pargemslr::PargemslrMpiTime( (*(pargemslr::ParallelLogClass::_gcomm)), pargemslr::ParallelLogClass::_times_buffer_start[num]);\
147  (__VA_ARGS__);\
148  pargemslr::PargemslrMpiTime( (*(pargemslr::ParallelLogClass::_gcomm)), pargemslr::ParallelLogClass::_times_buffer_end[num]);\
149  pargemslr::ParallelLogClass::_times[num] += pargemslr::ParallelLogClass::_times_buffer_end[num] - pargemslr::ParallelLogClass::_times_buffer_start[num];\
150 }
151 
152 #define PARGEMSLR_LOCAL_TIME_CALL(num, ...) {\
153  PARGEMSLR_CUDA_SYNCHRONIZE;\
154  pargemslr::ParallelLogClass::_times_buffer_start[num] = MPI_Wtime();\
155  (__VA_ARGS__);\
156  PARGEMSLR_CUDA_SYNCHRONIZE;\
157  pargemslr::ParallelLogClass::_times_buffer_end[num] = MPI_Wtime();\
158  pargemslr::ParallelLogClass::_times[num] += pargemslr::ParallelLogClass::_times_buffer_end[num] - pargemslr::ParallelLogClass::_times_buffer_start[num];\
159 }
160 
161 #define PARGEMSLR_TIME_CALL(comm,num, ...) {\
162  pargemslr::PargemslrMpiTime( (comm), pargemslr::ParallelLogClass::_times_buffer_start[num]);\
163  (__VA_ARGS__);\
164  pargemslr::PargemslrMpiTime( (comm), pargemslr::ParallelLogClass::_times_buffer_end[num]);\
165  pargemslr::ParallelLogClass::_times[num] += pargemslr::ParallelLogClass::_times_buffer_end[num] - pargemslr::ParallelLogClass::_times_buffer_start[num];\
166 }
167 
168 #define PARGEMSLR_RESET_TIME std::fill(pargemslr::ParallelLogClass::_times.begin(), pargemslr::ParallelLogClass::_times.end(), 0.0);
169 
170 #else
171 
172 #define PARGEMSLR_GLOBAL_TIME_CALL(num, ...) {\
173  (__VA_ARGS__);\
174 }
175 
176 #define PARGEMSLR_LOCAL_TIME_CALL(num, ...) {\
177  (__VA_ARGS__);\
178 }
179 
180 #define PARGEMSLR_TIME_CALL(comm,num, ...) {\
181  (__VA_ARGS__);\
182 }
183 
184 #define PARGEMSLR_PRINT_TIMING_RESULT(print_level, ...) {\
185  if(__VA_ARGS__)\
186  {\
187  PARGEMSLR_PRINT("\n");\
188  PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);\
189  PARGEMSLR_PRINT("Time info:\n");\
190  PARGEMSLR_PRINT("\tLoad matrix time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_GEN_MAT_TIME]);\
191  PARGEMSLR_PRINT("\tPartition time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_PARTITION]);\
192  PARGEMSLR_PRINT("\tSetup time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SETUP_TIME]-pargemslr::ParallelLogClass::_times[PARGEMSLR_BUILDTIME_PARTITION]);\
193  PARGEMSLR_PRINT("\tSolve time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SOLVE_TIME]);\
194  PARGEMSLR_PRINT("\tTotal time: %fs\n",pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SETUP_TIME]+pargemslr::ParallelLogClass::_times[PARGEMSLR_TOTAL_SOLVE_TIME]);\
195  PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);\
196  PARGEMSLR_PRINT("\n");\
197  }\
198 }
199 
200 #define PARGEMSLR_RESET_TIME
201 
202 #endif
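A minimal usage sketch of the timing macros above, assuming PARGEMSLR_TIMING is defined and PargemslrInit (declared below) has already been called; the precond, solver, mat, rhs, x, and myid names are illustrative and not part of this header:

   PARGEMSLR_RESET_TIME;
   PARGEMSLR_GLOBAL_TIME_CALL(PARGEMSLR_TOTAL_SETUP_TIME, precond.Setup(mat));
   PARGEMSLR_GLOBAL_TIME_CALL(PARGEMSLR_TOTAL_SOLVE_TIME, solver.Solve(x, rhs));
   /* print the summary, plus the detailed breakdown since print_level > 0, on rank 0 only */
   PARGEMSLR_PRINT_TIMING_RESULT(1, myid == 0);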
203 
204 namespace pargemslr
205 {
211  typedef class ParallelLogClass
212  {
213  public:
214  /* variables */
215 #ifdef PARGEMSLR_CUDA
216 
221  static curandGenerator_t _curand_gen;
222 
227  static cublasHandle_t _cublas_handle;
228 
233  static cusparseHandle_t _cusparse_handle;
234 
239  static cudaStream_t _stream;
240 
245  static cusparseIndexBase_t _cusparse_idx_base;
246 
251  static cusparseMatDescr_t _mat_des;
252 
257  static cusparseMatDescr_t _matL_des;
258 
263  static cusparseMatDescr_t _matU_des;
264 
269  static cusparseSolvePolicy_t _ilu_solve_policy;
270 
275  static void *_cusparse_buffer;
276 
281  static size_t _cusparse_buffer_length;
282 
283 #if (PARGEMSLR_CUDA_VERSION == 11)
284 
288  static cusparseIndexType_t _cusparse_idx_type;
289 
294  static cusparseSpMVAlg_t _cusparse_spmv_algorithm;
295 #endif
296 
297 #endif
298 
303  static int _working_location;
304 
309  static int _gsize;
310 
315  static int _grank;
316 
321  static MPI_Comm *_gcomm;
322 
327  static MPI_Comm *_lcomm;
328 
333  int _size;
334 
339  int _rank;
340 
345  MPI_Comm _commref;
346 
351  MPI_Comm *_comm;
352 
357  static vector<double> _times;
358 
363  static vector<double> _times_buffer_start;
364 
369  static vector<double> _times_buffer_end;
370 
375  int Clear();
376 
382 
388 
395 
402  ParallelLogClass& operator= (const ParallelLogClass &parlog);
403 
410  ParallelLogClass& operator= ( ParallelLogClass &&parlog);
411 
417  ParallelLogClass(MPI_Comm comm_in);
418 
424 
433  int GetMpiInfo(int &np, int &myid, MPI_Comm &comm) const;
434 
440  MPI_Comm GetComm() const;
441 
442  }parallel_log, *parallel_logp;
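A brief sketch of how the parallel_log typedef above might be used once the library is initialized; the np, myid, and comm names are illustrative:

   pargemslr::parallel_log parlog(MPI_COMM_WORLD);  /* sets up a new local comm from MPI_COMM_WORLD */
   int np, myid;
   MPI_Comm comm;
   parlog.GetMpiInfo(np, myid, comm);               /* size, rank, and comm of this object */
   /* ... use np, myid, and comm ... */
   parlog.Clear();                                  /* free the parallel_log */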
443 
450  int PargemslrSetOpenmpNumThreads(int nthreads);
451 
452 #ifdef PARGEMSLR_OPENMP
453 
459  int PargemslrGetOpenmpThreadNum();
460 
461 
467  int PargemslrGetOpenmpNumThreads();
468 
474  int PargemslrGetOpenmpMaxNumThreads();
475 
481  int PargemslrGetOpenmpGlobalMaxNumThreads();
482 
483 #endif
484 
494  int PargemslrNLocalToNGlobal( int n_local, long int &n_start, long int &n_global, MPI_Comm &comm);
495 
508  int PargemslrNLocalToNGlobal( int nrow_local, int ncol_local, long int &nrow_start, long int &ncol_start, long int &nrow_global, long int &ncol_global, MPI_Comm &comm);
509 
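A sketch of the intended use of PargemslrNLocalToNGlobal, assuming n_start receives this rank's starting global index and n_global the total size over comm; n_local and comm are illustrative:

   int      n_local = 1000;            /* rows owned by this rank */
   long int n_start, n_global;
   pargemslr::PargemslrNLocalToNGlobal(n_local, n_start, n_global, comm);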
517  int PargemslrInit(int *argc, char ***argv);
518 
525  int PargemslrInitMpi(MPI_Comm comm);
526 
533  int PargemslrInitOpenMP(int nthreads);
534 
540  int PargemslrInitCUDA();
541 
547  int PargemslrPrintParallelInfo();
548 
554  int PargemslrFinalize();
555 
561  int PargemslrFinalizeMpi();
562 
568  int PargemslrFinalizeOpenMP();
569 
575  int PargemslrFinalizeCUDA();
576 
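A minimal driver sketch using the init/finalize routines declared above; everything in between is omitted:

   int main(int argc, char **argv)
   {
      pargemslr::PargemslrInit(&argc, &argv);    /* MPI, OpenMP, and (if enabled) CUDA setup */
      pargemslr::PargemslrPrintParallelInfo();
      /* ... build the matrix, set up the preconditioner, solve ... */
      pargemslr::PargemslrFinalize();
      return 0;
   }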
584  int PargemslrMpiTime(MPI_Comm comm, double &t);
585 
586 #ifdef MPI_C_FLOAT_COMPLEX
587 
599  template <typename T>
600  int PargemslrMpiIsend(T *buf, int count, int dest, int tag, MPI_Comm comm, MPI_Request *request);
601 
613  template <typename T>
614  int PargemslrMpiIrecv(T *buf, int count, int source, int tag, MPI_Comm comm, MPI_Request *request);
615 
626  template <typename T>
627  int PargemslrMpiSend(T *buf, int count, int dest, int tag, MPI_Comm comm);
628 
640  template <typename T>
641  int PargemslrMpiRecv(T *buf, int count, int source, int tag, MPI_Comm comm, MPI_Status * status);
642 
643 #else
644 
656  template <typename T>
657  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
658  PargemslrMpiIsend(T *buf, int count, int dest, int tag, MPI_Comm comm, MPI_Request *request);
659 
671  template <typename T>
672  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
673  PargemslrMpiIsend(T *buf, int count, int dest, int tag, MPI_Comm comm, MPI_Request *request);
674 
686  template <typename T>
687  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
688  PargemslrMpiIrecv(T *buf, int count, int source, int tag, MPI_Comm comm, MPI_Request *request);
689 
701  template <typename T>
702  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
703  PargemslrMpiIrecv(T *buf, int count, int source, int tag, MPI_Comm comm, MPI_Request *request);
704 
715  template <typename T>
716  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
717  PargemslrMpiSend(T *buf, int count, int dest, int tag, MPI_Comm comm);
718 
730  template <typename T>
731  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
732  PargemslrMpiSend(T *buf, int count, int dest, int tag, MPI_Comm comm);
733 
745  template <typename T>
746  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
747  PargemslrMpiRecv(T *buf, int count, int source, int tag, MPI_Comm comm, MPI_Status * status);
748 
760  template <typename T>
761  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
762  PargemslrMpiRecv(T *buf, int count, int source, int tag, MPI_Comm comm, MPI_Status * status);
763 
764 #endif
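A sketch of a nonblocking exchange using the wrappers above; the neighbor ranks left and right, the length n, and comm are illustrative. The complex and non-complex overloads are selected automatically, so the call sites look the same in either branch:

   std::vector<double> sendbuf(n, 1.0), recvbuf(n);
   MPI_Request requests[2];
   pargemslr::PargemslrMpiIsend(sendbuf.data(), n, right, 0, comm, &requests[0]);
   pargemslr::PargemslrMpiIrecv(recvbuf.data(), n, left,  0, comm, &requests[1]);
   MPI_Waitall(2, requests, MPI_STATUSES_IGNORE);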
765 
766 #ifdef MPI_C_FLOAT_COMPLEX
767 
777  template <typename T>
778  int PargemslrMpiBcast(T *buf, int count, int root, MPI_Comm comm);
779 
780 #else
781 
791  template <typename T>
792  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
793  PargemslrMpiBcast(T *buf, int count, int root, MPI_Comm comm);
794 
804  template <typename T>
805  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
806  PargemslrMpiBcast(T *buf, int count, int root, MPI_Comm comm);
807 
808 #endif
809 
810 #ifdef MPI_C_FLOAT_COMPLEX
811 
822  template <typename T>
823  int PargemslrMpiScan(T *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm);
824 
825 #else
826 
837  template <typename T>
838  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
839  PargemslrMpiScan(T *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm);
840 
851  template <typename T>
852  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
853  PargemslrMpiScan(T *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm);
854 
855 #endif
856 
857 #ifdef MPI_C_FLOAT_COMPLEX
858 
870  template <typename T>
871  int PargemslrMpiReduce(T *sendbuf, T *recvbuf, int count, MPI_Op op, int root, MPI_Comm comm);
872 
873 #else
874 
886  template <typename T>
887  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
888  PargemslrMpiReduce(T *sendbuf, T *recvbuf, int count, MPI_Op op, int root, MPI_Comm comm);
889 
901  template <typename T>
902  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
903  PargemslrMpiReduce(T *sendbuf, T *recvbuf, int count, MPI_Op op, int root, MPI_Comm comm);
904 
905 #endif
906 
907 #ifdef MPI_C_FLOAT_COMPLEX
908 
919  template <typename T>
920  int PargemslrMpiAllreduce(T *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm);
921 
922 #else
923 
934  template <typename T>
935  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
936  PargemslrMpiAllreduce(T *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm);
937 
948  template <typename T>
949  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
950  PargemslrMpiAllreduce(T *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm);
951 
952 #endif
953 
954 #ifdef MPI_C_FLOAT_COMPLEX
955 
965  template <typename T>
966  int PargemslrMpiAllreduceInplace(T *buf, int count, MPI_Op op, MPI_Comm comm);
967 
968 #else
969 
979  template <typename T>
980  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
981  PargemslrMpiAllreduceInplace(T *buf, int count, MPI_Op op, MPI_Comm comm);
982 
992  template <typename T>
993  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
994  PargemslrMpiAllreduceInplace(T *buf, int count, MPI_Op op, MPI_Comm comm);
995 
996 #endif
997 
998 #ifdef MPI_C_FLOAT_COMPLEX
999 
1010  template <typename T>
1011  int PargemslrMpiGather(T *sendbuf, int count, T *recvbuf, int root, MPI_Comm comm);
1012 
1013 #else
1014 
1025  template <typename T>
1026  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
1027  PargemslrMpiGather(T *sendbuf, int count, T *recvbuf, int root, MPI_Comm comm);
1028 
1039  template <typename T>
1040  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
1041  PargemslrMpiGather(T *sendbuf, int count, T *recvbuf, int root, MPI_Comm comm);
1042 
1043 #endif
1044 
1045 #ifdef MPI_C_FLOAT_COMPLEX
1046 
1056  template <typename T>
1057  int PargemslrMpiAllgather(T *sendbuf, int count, T *recvbuf, MPI_Comm comm);
1058 
1059 #else
1060 
1070  template <typename T>
1071  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
1072  PargemslrMpiAllgather(T *sendbuf, int count, T *recvbuf, MPI_Comm comm);
1073 
1083  template <typename T>
1084  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
1085  PargemslrMpiAllgather(T *sendbuf, int count, T *recvbuf, MPI_Comm comm);
1086 
1087 #endif
1088 
1089 #ifdef MPI_C_FLOAT_COMPLEX
1090 
1102  template <typename T>
1103  int PargemslrMpiAllgatherv(T *sendbuf, int count, T *recvbuf, int *recvcounts, int *recvdisps, MPI_Comm comm);
1104 
1105 #else
1106 
1118  template <typename T>
1119  typename std::enable_if<!PargemslrIsComplex<T>::value, int>::type
1120  PargemslrMpiAllgatherv(T *sendbuf, int count, T *recvbuf, int *recvcounts, int *recvdisps, MPI_Comm comm);
1121 
1133  template <typename T>
1134  typename std::enable_if<PargemslrIsComplex<T>::value, int>::type
1135  PargemslrMpiAllgatherv(T *sendbuf, int count, T *recvbuf, int *recvcounts, int *recvdisps, MPI_Comm comm);
1136 
1137 #endif
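A sketch of the collective wrappers; nnz_local, n_local, np, and comm are illustrative:

   long int nnz_local = 12345, nnz_global;
   pargemslr::PargemslrMpiAllreduce(&nnz_local, &nnz_global, 1, MPI_SUM, comm);

   int n_local = 1000;
   std::vector<int> counts(np);
   pargemslr::PargemslrMpiAllgather(&n_local, 1, counts.data(), comm);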
1138 
1139 #ifdef PARGEMSLR_CUDA
1140 
1146  int PargemslrCudaSynchronize();
1147 #endif
1148 
1154  template <typename T>
1155  MPI_Datatype PargemslrMpiDataType();
1156 
1162  template<>
1163  MPI_Datatype PargemslrMpiDataType<int>();
1164 
1170  template<>
1171  MPI_Datatype PargemslrMpiDataType<long int>();
1172 
1178  template<>
1179  MPI_Datatype PargemslrMpiDataType<float>();
1180 
1186  template<>
1187  MPI_Datatype PargemslrMpiDataType<double>();
1188 
1194  template<>
1195  MPI_Datatype PargemslrMpiDataType<complexs>();
1196 
1202  template<>
1203  MPI_Datatype PargemslrMpiDataType<complexd>();
1204 
1205 }
1206 
1207 /*- - - - - - - - - OPENMP default schedule */
1208 #ifdef PARGEMSLR_OPENMP
1209 
1210 /* some implementations require the same operation order; use PARGEMSLR_OPENMP_SCHEDULE_STATIC to guarantee it */
1211 //#define PARGEMSLR_OPENMP_SCHEDULE_DEFAULT schedule(dynamic)
1212 #define PARGEMSLR_OPENMP_SCHEDULE_DEFAULT schedule(static)
1213 #define PARGEMSLR_OPENMP_SCHEDULE_STATIC schedule(static)
1214 
1215 #endif
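A sketch of the default-schedule macro inside an OpenMP pragma; the loop and the y, x, alpha, and n names are illustrative:

   #pragma omp parallel for PARGEMSLR_OPENMP_SCHEDULE_DEFAULT
   for(int i = 0 ; i < n ; i ++)
   {
      y[i] += alpha * x[i];
   }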
1216 
1217 /*- - - - - - - - - MPI calls */
1218 
1219 #ifdef PARGEMSLR_DEBUG
1220 
1221 #define PARGEMSLR_MPI_CALL(...) {\
1222  assert( (__VA_ARGS__) == MPI_SUCCESS);\
1223 }
1224 
1225 #else
1226 
1227 #define PARGEMSLR_MPI_CALL(...) {\
1228  (__VA_ARGS__);\
1229 }
1230 
1231 #endif
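A sketch of wrapping a raw MPI call; in a PARGEMSLR_DEBUG build the return code is asserted against MPI_SUCCESS, otherwise the call runs unchecked. The comm name is illustrative:

   PARGEMSLR_MPI_CALL( MPI_Barrier(comm) );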
1232 
1233 /*- - - - - - - - - CUDA calls */
1234 
1235 #ifdef PARGEMSLR_CUDA
1236 
1237 #ifndef PARGEMSLR_CUDA_VERSION
1238 
1239 /* the default CUDA version is 11; note that only CUDA 10 and CUDA 11 are supported so far */
1240 #define PARGEMSLR_CUDA_VERSION 11
1241 
1242 #endif
1243 
1244 #define PARGEMSLR_CUDA_SYNCHRONIZE PargemslrCudaSynchronize();
1245 
1246 #ifdef PARGEMSLR_DEBUG
1247 
1248 #define PARGEMSLR_CUDA_CALL(...) {\
1249  assert((__VA_ARGS__) == cudaSuccess);\
1250 }
1251 
1252 #define PARGEMSLR_CURAND_CALL(...) {\
1253  assert((__VA_ARGS__) == CURAND_STATUS_SUCCESS);\
1254 }
1255 
1256 #define PARGEMSLR_CUBLAS_CALL(...) {\
1257  assert( (__VA_ARGS__) == CUBLAS_STATUS_SUCCESS);\
1258 }
1259 
1260 #define PARGEMSLR_CUSPARSE_CALL(...) {\
1261  assert((__VA_ARGS__) == CUSPARSE_STATUS_SUCCESS);\
1262 }
1263 
1264 #else
1265 
1266 #define PARGEMSLR_CUDA_CALL(...) {\
1267  (__VA_ARGS__);\
1268 }
1269 
1270 #define PARGEMSLR_CURAND_CALL(...) {\
1271  (__VA_ARGS__);\
1272 }
1273 
1274 #define PARGEMSLR_CUBLAS_CALL(...) {\
1275  (__VA_ARGS__);\
1276 }
1277 
1278 #define PARGEMSLR_CUSPARSE_CALL(...) {\
1279  (__VA_ARGS__);\
1280 }
1281 
1282 #endif
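A sketch of the CUDA error-checking macros, assuming a CUDA build; d_data and n are illustrative:

   double *d_data = NULL;
   PARGEMSLR_CUDA_CALL( cudaMalloc((void**)&d_data, n * sizeof(double)) );
   /* ... launch kernels or call cuBLAS/cuSPARSE through the handles in ParallelLogClass ... */
   PARGEMSLR_CUDA_SYNCHRONIZE;
   PARGEMSLR_CUDA_CALL( cudaFree(d_data) );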
1283 
1284 //#define PARGEMSLR_THRUST_CALL(thrust_function, ...) thrust::thrust_function(thrust::cuda::par.on(pargemslr::ParallelLogClass::_stream), __VA_ARGS__)
1285 
1286 #define PARGEMSLR_THRUST_CALL(thrust_function, ...) thrust::thrust_function( __VA_ARGS__)
1287 
1288 #else
1289 
1290 #define PARGEMSLR_CUDA_SYNCHRONIZE
1291 
1292 #endif
1293 
1294 #define PARGEMSLR_GLOBAL_SEQUENTIAL_RUN(...) {\
1295  for(int pgsri = 0 ; pgsri < pargemslr::parallel_log::_gsize ; pgsri++)\
1296  {\
1297  if( pargemslr::parallel_log::_grank == pgsri)\
1298  {\
1299  (__VA_ARGS__);\
1300  }\
1301  MPI_Barrier(*(pargemslr::parallel_log::_gcomm));\
1302  }\
1303 }
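A sketch of the sequential-run helper: the wrapped statement executes on one global rank at a time with a barrier in between. The statement is passed without a trailing semicolon because the macro adds one; the message itself is illustrative:

   PARGEMSLR_GLOBAL_SEQUENTIAL_RUN(
      PARGEMSLR_PRINT("rank %d of %d\n", pargemslr::parallel_log::_grank, pargemslr::parallel_log::_gsize)
   );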
1304 
1305 #endif
pargemslr::ParallelLogClass::ParallelLogClass
ParallelLogClass(ParallelLogClass &&parlog)
The move constructor of parallel_log.
pargemslr::ParallelLogClass::_gsize
static int _gsize
The total number of global MPI ranks.
Definition: parallel.hpp:309
pargemslr::ParallelLogClass::_comm
MPI_Comm * _comm
The local MPI comm.
Definition: parallel.hpp:351
utils.hpp
Basic utility functions.
pargemslr::ParallelLogClass::Clear
int Clear()
Free the parallel_log.
pargemslr::ParallelLogClass::_gcomm
static MPI_Comm * _gcomm
The global MPI comm.
Definition: parallel.hpp:321
pargemslr::ParallelLogClass::ParallelLogClass
ParallelLogClass(MPI_Comm comm_in)
The constructor of parallel_log; sets up a new local comm.
pargemslr::ParallelLogClass::~ParallelLogClass
~ParallelLogClass()
The destructor of parallel_log.
pargemslr::ParallelLogClass::_lcomm
static MPI_Comm * _lcomm
The local MPI comm (a single process only, kept for consistency).
Definition: parallel.hpp:327
pargemslr::ParallelLogClass::ParallelLogClass
ParallelLogClass()
The default constructor of parallel_log.
pargemslr::ParallelLogClass::GetComm
MPI_Comm GetComm() const
Get the MPI_Comm. When _comm is NULL, return the global one; otherwise return the local one.
pargemslr::ParallelLogClass::_grank
static int _grank
The global MPI rank of this process.
Definition: parallel.hpp:315
pargemslr::ParallelLogClass
The data structure for parallel computing, including data structures for MPI and CUDA.
Definition: parallel.hpp:212
pargemslr::ParallelLogClass::_times_buffer_start
static vector< double > _times_buffer_start
The std::vector stores the start time of each section.
Definition: parallel.hpp:363
pargemslr::ParallelLogClass::_times_buffer_end
static vector< double > _times_buffer_end
The std::vector stores the end time of each section.
Definition: parallel.hpp:369
pargemslr::ParallelLogClass::_commref
MPI_Comm _commref
The MPI comm that doesn't need to be freed.
Definition: parallel.hpp:345
pargemslr::ParallelLogClass::_rank
int _rank
The local MPI rank of this process.
Definition: parallel.hpp:339
pargemslr::ParallelLogClass::ParallelLogClass
ParallelLogClass(const ParallelLogClass &parlog)
The copy constructor of parallel_log.
pargemslr::ParallelLogClass::_working_location
static int _working_location
The working location of the code (device/host).
Definition: parallel.hpp:303
pargemslr::ParallelLogClass::GetMpiInfo
int GetMpiInfo(int &np, int &myid, MPI_Comm &comm) const
Get comm, np, and myid. When _comm is NULL, return the global one; otherwise return the local one.
pargemslr::ParallelLogClass::_size
int _size
The total number of local MPI ranks.
Definition: parallel.hpp:333
pargemslr::ParallelLogClass::_times
static vector< double > _times
The std::vector stores the timing information.
Definition: parallel.hpp:357