CHROMA
reliable_bicgstab.cc
Go to the documentation of this file.
1 /*! \file
2  * \brief BiCGStab algorithm with reliable updates for a generic Linear Operator
3  */
4 
5 #include "chromabase.h"
7 
9 
10 namespace Chroma {
11 
12  using namespace BiCGStabKernels;
13 
14  template<typename T, typename TF, typename CF>
15 SystemSolverResults_t
17  const LinearOperator<TF>& AF,
18  const T& chi,
19  T& psi,
20  const Real& RsdBiCGStab,
21  const Real& Delta,
22  int MaxBiCGStab,
23  enum PlusMinus isign)
24  {
26 
28 
29  const Subset& s = A.subset();
30 
31  bool convP = false;
32 
33  // These are all the vectors. There should be
34  // None declared later on. These declarations do 'mallocs'
35  // under the hood. Want those out of the main loop.
36  T b;
37  T tmp;
38  T r_dble;
39  T x_dble;
40 
41  TF r;
42  TF r0;
43  TF x;
44  TF p;
45  TF v;
46  TF t;
47 
48  int k;
49 
50  StopWatch swatch;
51  FlopCounter flopcount;
52  flopcount.reset();
53  swatch.reset();
54  swatch.start();
55 
56  x[s]=zero;
57  p[s] = zero;
58  v[s] = zero;
59 
60  Double rsd_sq = Double(RsdBiCGStab)*Double(RsdBiCGStab)*norm2(chi,s);
61  Double b_sq;
62 
63 
64  A(tmp, psi, isign);
65 
66  // We could do all this in a onner
67  // b_sq = minusTmpB(tmp, b, r, r0,s)
68  //
69  //b[s] = chi-tmp;
70  //b_sq = norm2(b,s);
71  // r[s] = b;
72 
73  xymz_normx(b,chi,tmp,b_sq,s);
74  r[s] = b;
75  r0[s] = b;
76  Double r_sq = b_sq;
77  QDPIO::cout << "r0 = " << b_sq << std::endl;;
78 
79  flopcount.addFlops(A.nFlops());
80  flopcount.addSiteFlops(2*Nc*Ns,s);
81  flopcount.addSiteFlops(4*Nc*Ns,s);
82 
83  Double rNorm = sqrt(r_sq);
84  Double r0Norm = rNorm;
85  Double maxrx = rNorm;
86  Double maxrr = rNorm;
87  bool updateR = false;
88  bool updateX = false;
89  int rupdates = 0;
90  int xupdates = 0;
91 
92 
93  DComplex rho, rho_prev, alpha, omega;
94 
95  DComplex ctmp;
96  Double t_norm;
97 
98  CF rho_r, alpha_r, omega_r;
99  // rho_0 := alpha := omega = 1
100  // Iterations start at k=1, so rho_0 is in rho_prev
101  rho = Double(1);
102  rho_prev = Double(1);
103  alpha = Double(1);
104  omega = Double(1);
105 
106  // The iterations
107  for(k = 0; k < MaxBiCGStab && !convP ; k++) {
108 
109  if( k == 0 ) {
110  // I know that r_0 = r so <r_0|r>=norm2(r) = r_sq
111  // rho = innerProduct(r0,r,s);
112  rho = r_sq;
113  p[s] = r;
114  }
115  else {
116  DComplex beta =(rho / rho_prev) * (alpha/omega);
117  CF beta_r = beta;
118  omega_r = omega;
119 
120  // NB: This could be done in a onner
121  // rPlusBetaPMinusBetaOmegav(p, r, v, beta, omega, s)
122 
123  // p = r + beta(p - omega v)
124  // first work out p - omega v
125  // into tmp
126  // then do p = r + beta tmp
127 
128 
129  // tmp[s] = p - omega_r*v;
130  // p[s] = r + beta_r*tmp;
131  yxpaymabz(r, p, v, beta_r, omega_r, s);
132 
133  }
134 
135  // v = Ap
136  AF(v,p,isign);
137 
138  // alpha = rho_{k+1} / < r_0 | v >
139  // put <r_0 | v > into tmp
140  ctmp = innerProduct(r0,v,s);
141 
142  if( toBool( real(ctmp) == 0 ) && toBool( imag(ctmp) == 0 ) ) {
143  QDPIO::cout << "BiCGStab breakdown: <r_0|v> = 0" << std::endl;
144  QDP_abort(1);
145  }
146 
147  alpha = rho / ctmp;
148 
149  // Done with rho now, so save it into rho_prev
150  rho_prev = rho;
151 
152  // s = r - alpha v
153  // I can overlap s with r, because I recompute it at the end.
154  alpha_r = alpha;
155  // r[s] -= alpha_r*v;
156  cxmay(r,v,alpha_r,s);
157 
158 
159  // t = As = Ar
160  AF(t,r,isign);
161 
162 
163  // omega = < t | s > / < t | t > = < t | r > / norm2(t);
164  // accumulate <t | s > = <t | r> into omega
165 
166  // As Mike tells me, I could do these together.
167  // I can probably reduce these to a single ALLREDUCE/QMP_sum_double_array()
168  //
169  // some routine like: t_norm = normXCdotXY(t,r,s, iprod_r, iprod_i)
170  // Double t_norm = norm2(t,s);
171  // omega = innerProduct(t,r,s);
172 
173 
174  norm2x_cdotxy(t,r, t_norm, omega, s);
175 
176  omega /= t_norm;
177 
178  // again
179  // This is a simple xPlusAYPlusBz(x,r,p,omega,alpha)
180  // psi = psi + omega s + alpha p
181  // = psi + omega r + alpha p
182  //
183  // use tmp to compute psi + omega r
184  // then add in the alpha p
185  omega_r = omega;
186  // tmp[s] = x + omega_r*r;
187  // x[s] = tmp + alpha_r*p;
188 
189 
190  xpaypbz(x,r,p,omega_r, alpha_r,s);
191 
192  // r = s - omega t = r - omega t1G
193 
194  // I can roll this all together
195  // r_sq = XMinusAYNormXCDotZX(r,t,r0,omega_r, omega_i, rho_r, rho_i, s),
196  // r[s] -= omega_r*t;
197  // r_sq = norm2(r,s);
198  // rho = innerProduct(r0,r,s);
199 
200  xmay_normx_cdotzx(r, t, r0, omega_r, r_sq, rho,s);
201 
202  // Flops so far: Standard BiCGStab Flops
203  // -----------------------------------------
204  flopcount.addSiteFlops(80*Nc*Ns,s);
205  flopcount.addFlops(2*A.nFlops());
206  // ------------------------------------------
207 
208  rNorm = sqrt(r_sq);
209 
210  if( toBool( rNorm > maxrx) ) maxrx = rNorm;
211  if( toBool( rNorm > maxrr) ) maxrr = rNorm;
212 
213  updateX = toBool ( rNorm < Delta*r0Norm && r0Norm <= maxrx );
214  updateR = toBool ( rNorm < Delta*maxrr && r0Norm <= maxrr ) || updateX;
215 
216  if( updateR ) {
217  // QDPIO::cout << "Iter " << k << ": updating r " << std::endl;
218  rupdates++;
219 
220  x_dble[s] = x;
221 
222  A(tmp, x_dble, isign); // Use full solution so far
223 
224  // Roll this together - can eliminate r_dble which is an intermediary
225 
226  // r_dble[s] = b - tmp2;
227  // r_sq = norm2(r_dble,s);
228  // r[s] = r_dble;
229  xymz_normx(r_dble, b,tmp, r_sq,s);
230  r[s]=r_dble;
231 
232  flopcount.addSiteFlops(6*Nc*Ns,s);
233  flopcount.addFlops(A.nFlops());
234 
235  rNorm = sqrt(r_sq);
236  maxrr = rNorm;
237 
238 
239  if( updateX ) {
240  xupdates++;
241  //QDPIO::cout << "Iter " << k << ": updating x " << std::endl;
242  if( ! updateR ) { x_dble[s]=x; } // if updateR then this is done already
243  psi[s] += x_dble; // Add on group accumulated solution in y
244  flopcount.addSiteFlops(2*Nc*Ns,s);
245 
246  x[s] = zero; // zero y
247  b[s] = r_dble;
248  r0Norm = rNorm;
249  maxrx = rNorm;
250  }
251 
252  }
253 
254 
255  if( toBool(r_sq < rsd_sq ) ) {
256 
257  convP = true;
258 
259  // if updateX true, then we have just updated psi
260  // strictly x[s] should be zero, so it should be OK to add it
261  // but why do the work if you don't need to
262  x_dble[s] = x;
263  psi[s]+=x_dble;
264  flopcount.addSiteFlops(2*Nc*Ns,s);
265  ret.resid = rNorm;
266  ret.n_count = k;
267  }
268  else {
269  convP = false;
270  }
271 
272 
273 
274  }
275  swatch.stop();
276  if( k >= MaxBiCGStab ) {
277  QDPIO::cerr << "Nonconvergence of reliable BiCGStab. MaxIters = " << MaxBiCGStab << " exceeded" << std::endl;
278  QDP_abort(1);
279  }
280  else {
281  QDPIO::cout << "reliable_bicgstab: n_count " << ret.n_count << " r-updates: " << rupdates << " xr-updates: " << xupdates << std::endl;
282  flopcount.report("reliable_bicgstab", swatch.getTimeInSeconds());
283  }
284 
286  return ret;
287 
288 }
289 
290 
291 
292 
293 
294 SystemSolverResults_t
296  const LatticeFermionF& chi,
297  LatticeFermionF& psi,
298  const Real& RsdBiCGStab,
299  const Real& Delta,
300  int MaxBiCGStab,
301  enum PlusMinus isign)
302 
303 {
304  return RelInvBiCGStab_a<LatticeFermionF,LatticeFermionF, ComplexF>(A,A, chi, psi, RsdBiCGStab, Delta, MaxBiCGStab, isign);
305 }
306 
307  // Pure double
308 SystemSolverResults_t
310  const LatticeFermionD& chi,
311  LatticeFermionD& psi,
312  const Real& RsdBiCGStab,
313  const Real& Delta,
314  int MaxBiCGStab,
315  enum PlusMinus isign)
316 
317 {
318  return RelInvBiCGStab_a<LatticeFermionD, LatticeFermionD, ComplexD>(A,A, chi, psi, RsdBiCGStab, Delta, MaxBiCGStab, isign);
319 }
320 
321  // single double
322 SystemSolverResults_t
325  const LatticeFermionD& chi,
326  LatticeFermionD& psi,
327  const Real& RsdBiCGStab,
328  const Real& Delta,
329  int MaxBiCGStab,
330  enum PlusMinus isign)
331 
332 {
333  return RelInvBiCGStab_a<LatticeFermionD, LatticeFermionF, ComplexF>(A,AF, chi, psi, RsdBiCGStab, Delta, MaxBiCGStab, isign);
334 }
335 
336 
337 #if 0
338 
339 #endif
340 
341 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
SystemSolverResults_t InvBiCGStabReliable(const LinearOperator< LatticeFermionF > &A, const LatticeFermionF &chi, LatticeFermionF &psi, const Real &RsdBiCGStab, const Real &Delta, int MaxBiCGStab, enum PlusMinus isign)
Bi-CG stabilized.
int x
Definition: meslate.cc:34
int t
Definition: meslate.cc:37
void xpaypbz(T &x, T &y, T &z, C &a, C &b, const Subset &s)
void xmay_normx_cdotzx(T &x, const T &y, const T &z, C &a, Double &normx, DComplex &cdotzx, const Subset &s)
void xymz_normx(T &x, const T &y, const T &z, Double &x_norm, const Subset &s)
void norm2x_cdotxy(const T &x, const T &y, Double &norm2x, DComplex &cdotxy, const Subset &s)
void cxmay(T &x, const T &y, const C &a, const Subset &s)
void yxpaymabz(T &x, T &y, T &z, const C &a, const C &b, const Subset &s)
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & beta(const int dim)
Definition: stag_phases_s.h:47
static const LatticeInteger & alpha(const int dim)
Definition: stag_phases_s.h:43
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
LatticeFermion tmp
Definition: mespbg5p_w.cc:36
LinOpSysSolverMGProtoClover::T T
Real rsd_sq
Definition: invbicg.cc:121
multi1d< LatticeFermion > chi(Ncb)
SystemSolverResults_t RelInvBiCGStab_a(const LinearOperator< T > &A, const LinearOperator< TF > &AF, const T &chi, T &psi, const Real &RsdBiCGStab, const Real &Delta, int MaxBiCGStab, enum PlusMinus isign)
Complex omega
Definition: invbicg.cc:97
LatticeFermion psi
Definition: mespbg5p_w.cc:35
A(A, psi, r, Ncb, PLUS)
Complex b
Definition: invbicg.cc:96
Double zero
Definition: invbicg.cc:106
int k
Definition: invbicg.cc:119
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Definition: gtest.h:7351
int r0
Definition: qtopcor.cc:41
BiCGStab Solver with reliable updates.
Holds return info from SystemSolver call.
Definition: syssolver.h:17
LatticeFermionF TF
Definition: t_quda_tprec.cc:17