1 #ifndef PARGEMSLR_FGMRES_H
2 #define PARGEMSLR_FGMRES_H
9 #include "../utils/memory.hpp"
10 #include "../utils/parallel.hpp"
11 #include "../utils/utils.hpp"
12 #include "../vectors/vector.hpp"
13 #include "../vectors/sequential_vector.hpp"
14 #include "../matrices/matrix.hpp"
15 #include "../matrices/matrixops.hpp"
16 #include "../matrices/dense_matrix.hpp"
27 template <
class MatrixType,
class VectorType,
typename DataType>
36 typename std::conditional<PargemslrIsDoublePrecision<DataType>::value,
62 typename std::conditional<PargemslrIsDoublePrecision<DataType>::value,
64 float>::type _rel_res;
76 typename std::conditional<PargemslrIsDoublePrecision<DataType>::value,
95 this->_location = kMemoryHost;
99 this->_absolute_tol =
false;
100 this->_rel_res = 0.0;
111 this->_location = solver._location;
112 this->_kdim = solver._kdim;
113 this->_maxits = solver._maxits;
114 this->_tol = solver._tol;
115 this->_absolute_tol = solver._absolute_tol;
116 this->_rel_res = solver._rel_res;
117 this->_iter = solver._iter;
118 this->_rel_res_vector = solver._rel_res_vector;
128 this->_location = solver._location;
129 solver._location = kMemoryHost;
130 this->_kdim = solver._kdim;
132 this->_maxits = solver._maxits;
134 this->_tol = solver._tol;
136 this->_absolute_tol = solver._absolute_tol;
137 solver._absolute_tol =
false;
138 this->_rel_res = solver._rel_res;
140 this->_iter = solver._iter;
142 this->_rel_res_vector = std::move(solver._rel_res_vector);
155 this->_location = solver._location;
156 this->_kdim = solver._kdim;
157 this->_maxits = solver._maxits;
158 this->_tol = solver._tol;
159 this->_absolute_tol = solver._absolute_tol;
160 this->_rel_res = solver._rel_res;
161 this->_iter = solver._iter;
162 this->_rel_res_vector = solver._rel_res_vector;
176 this->_location = solver._location;
177 solver._location = kMemoryHost;
178 this->_kdim = solver._kdim;
180 this->_maxits = solver._maxits;
182 this->_tol = solver._tol;
184 this->_absolute_tol = solver._absolute_tol;
185 solver._absolute_tol =
false;
186 this->_rel_res = solver._rel_res;
188 this->_iter = solver._iter;
190 this->_rel_res_vector = std::move(solver._rel_res_vector);
204 this->_location = kMemoryHost;
208 _rel_res_vector.Clear();
210 return PARGEMSLR_SUCCESS;
230 virtual int Setup( VectorType &x, VectorType &rhs)
236 return PARGEMSLR_SUCCESS;
253 return PARGEMSLR_SUCCESS;
263 virtual int Solve( VectorType &x, VectorType &rhs)
267 PARGEMSLR_ERROR(
"Solve without setup.");
268 return PARGEMSLR_ERROR_FUNCTION_CALL_ERR;
272 typedef typename std::conditional<PargemslrIsDoublePrecision<DataType>::value, double,
float>::type RealDataType;
276 int n_local, i, j, k;
278 RealDataType normb, EPSILON, normr, tolr, t, gam;
281 #ifdef PARGEMSLR_TIMING
285 this->
_matrix->GetMpiInfo(np, myid, comm);
292 EPSILON = std::numeric_limits<RealDataType>::epsilon();
299 n_local = x.GetLengthLocal();
302 V.
Setup(n_local, this->_kdim+1, this->_location,
true);
303 Z.
Setup(n_local, this->_kdim+1, this->_location,
true);
304 H.
Setup(this->_kdim+1, this->_kdim, kMemoryHost,
true);
307 this->
_matrix->SetupVectorPtrStr(v);
308 this->
_matrix->SetupVectorPtrStr(z);
309 this->
_matrix->SetupVectorPtrStr(w);
312 c.
Setup(this->_kdim,
true);
313 s.
Setup(this->_kdim,
true);
314 rs.
Setup(this->_kdim+1,
true);
324 this->_rel_res = 0.0;
326 this->_rel_res_vector.Setup(1,
true);
327 return PARGEMSLR_SUCCESS;
337 PARGEMSLR_MEMCPY(v.GetData(), rhs.GetData(), n_local, v.GetDataLocation(), rhs.GetDataLocation(), T);
338 PARGEMSLR_TIME_CALL( comm, PARGEMSLR_ITERTIME_AMV, (this->
_matrix->MatVec(
'N', mone, x, one, v)));
346 this->_rel_res = 0.0;
348 this->_rel_res_vector.Setup(1,
true);
349 return PARGEMSLR_SUCCESS;
352 if(this->_absolute_tol)
358 tolr = this->_tol*normb;
362 this->_rel_res_vector.Setup(this->_maxits+1,
true);
364 this->_rel_res_vector[0] = normr/normb;
369 PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);
370 PARGEMSLR_PRINT(
"Start FlexGMRES(%d)\n",this->_kdim);
371 PARGEMSLR_PRINT(
"Residual Tol: %e\nMax number of inner iterations: %d\n", tolr, this->_maxits);
372 PargemslrPrintDashLine(pargemslr::pargemslr_global::_dash_line_width);
373 PARGEMSLR_PRINT(
"Step Residual norm Relative res. Convergence Rate\n");
375 PARGEMSLR_PRINT(
"%5d %8e %8e N/A\n", 0, normr, this->_rel_res_vector[0]);
379 while (this->_iter < this->_maxits)
390 while (i < this->_kdim && this->_iter < this->_maxits)
404 #ifdef PARGEMSLR_TIMING
405 PARGEMSLR_TIME_CALL( comm, PARGEMSLR_PRECTIME_PRECOND, (this->
_preconditioner->Solve(z, v)));
406 PARGEMSLR_TIME_CALL( comm, PARGEMSLR_ITERTIME_AMV, (this->
_matrix->MatVec(
'N', one, z, zero, w)));
409 this->
_matrix->MatVec(
'N', one, z, zero, w);
414 PARGEMSLR_MEMCPY(z.GetData(), v.GetData(), n_local, z.GetDataLocation(), v.GetDataLocation(), T);
415 #ifdef PARGEMSLR_TIMING
416 PARGEMSLR_TIME_CALL( comm, PARGEMSLR_ITERTIME_AMV, (this->
_matrix->MatVec(
'N', one, v, zero, w)));
418 this->
_matrix->MatVec(
'N', one, v, zero, w);
423 #ifdef PARGEMSLR_TIMING
424 PARGEMSLR_TIME_CALL( comm, PARGEMSLR_ITERTIME_MGS, (PargemslrMgs( w, V, H, t, i-1, RealDataType(1e-12), RealDataType(-1.0))));
426 PargemslrMgs( w, V, H, t, i-1, RealDataType(1e-12), RealDataType(-1.0));
429 if (PargemslrAbs(t) < EPSILON)
433 PARGEMSLR_PRINT(
"Break down in the current cycle\n");
441 for (j = 1; j < i; j++)
444 H(j-1,i-1) = PargemslrConj(c[j-1])*hii + s[j-1]*H(j,i-1);
445 H(j,i-1) = -s[j-1]*hii + c[j-1]*H(j,i-1);
450 gam = sqrt(PargemslrAbs(hii)*PargemslrAbs(hii) + PargemslrAbs(hii1)*PargemslrAbs(hii1));
451 if (PargemslrAbs(gam) < EPSILON)
457 rs[i] = -s[i-1] * rs[i-1];
458 rs[i-1] = PargemslrConj(c[i-1]) * rs[i-1];
460 H(i-1,i-1) = PargemslrConj(c[i-1])*hii + s[i-1]*hii1;
461 normr = PargemslrAbs(rs[i]);
463 this->_rel_res_vector[this->_iter] = normr/normb;
467 PARGEMSLR_PRINT(
"%5d %8e %8e %8.6f\n", this->_iter, normr, this->_rel_res_vector[this->_iter], this->_rel_res_vector[this->_iter] / this->_rel_res_vector[this->_iter-1]);
479 PARGEMSLR_PRINT(
"Rel. residual at the end of current cycle (# of steps per cycle: %d): %e \n", this->_kdim, this->_rel_res_vector[0]);
483 rs[i-1] /= H(i-1,i-1);
484 for ( k = i-2; k >= 0; k--)
486 for ( j = k+1; j < i; j++)
488 rs[k] -= H(k,j)*rs[j];
494 for ( j = 0; j < i; j++)
506 this->_rel_res = normr;
510 PARGEMSLR_CHKERR(i==0 && this->_iter != this->_maxits);
518 PARGEMSLR_MEMCPY(v.GetData(), rhs.GetData(), n_local, v.GetDataLocation(), rhs.GetDataLocation(), T);
519 #ifdef PARGEMSLR_TIMING
520 PARGEMSLR_TIME_CALL( comm, PARGEMSLR_ITERTIME_AMV, (this->
_matrix->MatVec(
'N', one, x, zero, w)) );
522 this->
_matrix->MatVec(
'N', one, x, zero, w);
553 this->_rel_res = normr / normb;
555 this->_rel_res_vector.Resize( this->_iter+1,
true,
false);
568 return PARGEMSLR_SUCCESS;
582 this->_location = location;
583 return PARGEMSLR_SUCCESS;
586 if(this->_location == location)
588 return PARGEMSLR_SUCCESS;
591 this->_location = location;
593 return PARGEMSLR_SUCCESS;
608 this->_tol = params[PARGEMSLR_IO_SOLVER_TOL];
609 this->_maxits = params[PARGEMSLR_IO_SOLVER_MAXITS];
610 this->_kdim = params[PARGEMSLR_IO_SOLVER_KDIM];
611 this->_absolute_tol = params[PARGEMSLR_IO_SOLVER_ATOL] != 0.0;
612 this->
_print_option = params[PARGEMSLR_IO_GENERAL_PRINT_LEVEL];
613 return PARGEMSLR_SUCCESS;
622 template <
typename T>
626 return PARGEMSLR_SUCCESS;
637 this->_maxits = maxits;
638 return PARGEMSLR_SUCCESS;
650 return PARGEMSLR_SUCCESS;
661 this->_absolute_tol = option;
662 return PARGEMSLR_SUCCESS;
670 typename std::conditional<PargemslrIsDoublePrecision<DataType>::value,
674 return this->_rel_res;
692 typename std::conditional<PargemslrIsDoublePrecision<DataType>::value,
696 return this->_rel_res_vector;
// Shorthand aliases for FlexGmresClass instantiated with CsrMatrixClass as the
// matrix type and SequentialVectorClass as the vector type. One alias is
// provided per scalar type: float, double, complexs and complexd. In every
// alias the matrix, vector and DataType template arguments use the same scalar.
// NOTE(review): the leading numbers (701-704) look like line-number residue
// from the listing this text was extracted from - confirm before compiling.
701 typedef FlexGmresClass<CsrMatrixClass<float>, SequentialVectorClass<float>,
float> fgmres_csr_seq_float;
702 typedef FlexGmresClass<CsrMatrixClass<double>, SequentialVectorClass<double>,
double> fgmres_csr_seq_double;
703 typedef FlexGmresClass<CsrMatrixClass<complexs>, SequentialVectorClass<complexs>, complexs> fgmres_csr_seq_complexs;
704 typedef FlexGmresClass<CsrMatrixClass<complexd>, SequentialVectorClass<complexd>, complexd> fgmres_csr_seq_complexd;
// Shorthand aliases for FlexGmresClass instantiated with ParallelCsrMatrixClass
// as the matrix type and ParallelVectorClass as the vector type. One alias is
// provided per scalar type: float, double, complexs and complexd. In every
// alias the matrix, vector and DataType template arguments use the same scalar.
// NOTE(review): the leading numbers (705-708) look like line-number residue
// from the listing this text was extracted from - confirm before compiling.
705 typedef FlexGmresClass<ParallelCsrMatrixClass<float>, ParallelVectorClass<float>,
float> fgmres_csr_par_float;
706 typedef FlexGmresClass<ParallelCsrMatrixClass<double>, ParallelVectorClass<double>,
double> fgmres_csr_par_double;
707 typedef FlexGmresClass<ParallelCsrMatrixClass<complexs>, ParallelVectorClass<complexs>, complexs> fgmres_csr_par_complexs;
708 typedef FlexGmresClass<ParallelCsrMatrixClass<complexd>, ParallelVectorClass<complexd>, complexd> fgmres_csr_par_complexd;