CHROMA
reliable_ibicgstab.cc
Go to the documentation of this file.
1 /*! \file
2  * \brief Improved BiCGStab (IBiCGStab) solver with reliable updates for a generic Linear Operator
3  */
4 
5 #include "chromabase.h"
7 
9 
10 namespace Chroma {
11 
12  using namespace BiCGStabKernels;
13 
14  template<typename T, typename TF, typename CF>
15 SystemSolverResults_t
17  const LinearOperator<TF>& AF,
18  const T& chi,
19  T& psi,
20  const Real& RsdBiCGStab,
21  const Real& Delta,
22  int MaxBiCGStab,
23  enum PlusMinus isign)
24  {
26 
28  StopWatch swatch;
29  FlopCounter flopcount;
30  flopcount.reset();
31  const Subset& sub = A.subset();
32  bool convP = false;
33 
34  T b;
35  T dtmp;
36  T r_dble;
37  T x_dble;
38 
39 
40  TF r;
41  TF r0;
42  TF v;
43  TF tmp;
44  TF t;
45  TF s;
46  TF u;
47  TF q;
48  TF f0;
49  TF z;
50  TF x;
51 
52  TF vn_1, zn_1,qn_1;
53 
54  ComplexD rhon, rhon_1; // rho_n, rho_{n-1}
55  ComplexD alphan, alphan_1; // alpha_{n}, alpha_{n-1}
56  ComplexD omegan, omegan_1; // omega_{n} AND omega_{n-1} (before omega update)
57  ComplexD taun,taun_1; // tau_{n}, tau_{n-1}
58  RealD kappa;
59  RealD rnorm_prev; // lagged residuum
60  ComplexD theta;
61  ComplexD sigman_2, sigman_1;
62  ComplexD phin, phin_1; // phi_{n} AND phi_{n-1} before phi update
63  ComplexD pin, pin_1; // pi_{n} and pi_{n-1} before pi update
64  ComplexD gamma, eta;
65  swatch.reset();
66  swatch.start();
67 
68 
69  Double rsd_sq = RsdBiCGStab*RsdBiCGStab*norm2(chi,sub);
70  Double b_sq;
71  flopcount.addSiteFlops(4*Nc*Ns,sub);
72 
73  // Now initialise x_0 = v_0 = q_0 = z_0 = 0
74  x[sub] = zero;
75  vn_1[sub] = zero;
76  qn_1[sub] = zero;
77  zn_1[sub] = zero;
78 
79 
80 
81  // r = r0 = chi - A psi_0
82  A(dtmp, psi, isign);
83  xymz_normx(b,chi,dtmp,b_sq,sub);
84  r[sub] = b;
85  r0[sub] = b;
86 
87  QDPIO::cout << "r0 = " << b_sq << std::endl;;
88 
89  flopcount.addFlops(A.nFlops());
90  flopcount.addSiteFlops(2*Nc*Ns,sub);
91  flopcount.addSiteFlops(4*Nc*Ns,sub);
92 
93  Double rNorm = sqrt(b_sq);
94  Double r0Norm = rNorm;
95  Double maxrx = rNorm;
96  Double maxrr = rNorm;
97  bool updateR = false;
98  bool updateX = false;
99  int xupdates = 0;
100  int rupdates = 0;
101 
102  // f0 = A^\dag r_0 and plays a role like r_0 in inner products
103  if(isign == PLUS) {
104  AF(f0,r0,MINUS);
105  }
106  else {
107  AF(f0,r0, PLUS);
108  }
109  flopcount.addFlops(A.nFlops());
110 
111 
112  // u_0 = A r_0
113  AF(u,r,isign);
114  flopcount.addFlops(A.nFlops());
115 
116 
117  // rho_0 := alpha_0 := omega_0 = 1
118  rhon_1 = Double(1);
119  alphan_1 = Double(1);
120  omegan_1 = Double(1);
121 
122  // tau_0 = <r0,v0> = 0 since v0 = 0
123  taun_1 = Double(0);
124 
125  // sigma_{-1} = 0
126  sigman_2 = Double(0);
127 
128  // pi_0 = <r0,q_0> = 0
129  pin_1 = Double(0);
130 
131  // sigma_{0} = <r0, Ar_0> = <r0, u_0>
132  sigman_1 = innerProduct(r0,u,sub);
133 
134  //**** NB: Paper says phi_0 = 0.
135  //**** PETSc code claims it is <r0,r0>
136  //**** Using phi_0 = 0 leads to rho_1 = 0
137  //**** which leads to breakdown in iteration 2
138  //**** So I will use the PETSc version which appears to work
139 
140  // phi_0 = <r0,r0>
141  phin_1 = innerProduct(r0,r0,sub);
142  flopcount.addSiteFlops(16*Nc*Ns,sub); // (for sigma & phi)
143 
144  /********** ITERATION STARTS HERE ***************/
145 
146  // OK The big for loop
147  for(int n = 1; n <= MaxBiCGStab && !convP ; n++) {
148 
149  // Regular BiCGStab: rho = <r0,r>
150  // For IBiCGStab this has been unwound into
151  // the recurrencce:
152  //
153  // rho = phi_{n-1} - omega_{n-1}*( sigma_{n-2} - alpha_{n-1}*pi_{n-1} )
154  //
155  // Check for rho=0. If it is it will lead to breakdown in next iteration
156  rhon = phin_1 - omegan_1*(sigman_2 - alphan_1*pin_1); // 16 flops
157 
158 
159  if( toBool( real(rhon) == 0 ) && toBool( imag(rhon) == 0 ) ) {
160  QDPIO::cout << "BiCGStab breakdown: rho = 0" << std::endl;
161  QDP_abort(1);
162  }
163 
164 
165  // Regular BiCGStab: beta = ( rho_{n}/rho_{n-1})(alpha_{n-1}/omega_{n-1})
166  //
167  // For IBiCGStab where one can use delta_n = beta*omega_{n-1}
168  // it is useful to compute:
169  //
170  // delta_n = (rho_{n}/rho_{n-1})*alpha_{n-1}
171  // beta_n = delta_n/ omega_{n-1}
172  ComplexD beta;
173  ComplexD delta;
174  delta =( rhon / rhon_1 ) * alphan_1; // 15 flops
175  beta = delta/omegan_1; // 9 flops
176 
177 
178  // tau_n = <r0, v> needed for denominator of alpha
179  // but can be updated by recurrance
180  taun = sigman_1 + beta*(taun_1- omegan_1*pin_1); // 16 flops
181 
182  if( toBool( real(taun) == 0 ) && toBool( imag(taun) == 0 ) ) {
183  QDPIO::cout << "BiCGStab breakdown: n="<<n<<" <r_0|v> = 0" << std::endl;
184  QDPIO::cout << "sigman_1 = " << sigman_1 << std::endl;
185  QDPIO::cout << "beta= " << beta << std::endl;
186  QDPIO::cout << "taun_1 = " << taun_1 << std::endl;
187  QDPIO::cout << "ometan_1 = " << omegan_1 << std::endl;
188  QDPIO::cout << "pin_1 = " << pin_1 << std::endl;
189 
190  QDP_abort(1);
191  }
192 
193  // form alpha = rho/tau
194  alphan = rhon / taun; // 9 flops
195 
196 
197  // z_n plays role of alpha_n p_n in normal BiCGstab
198  // it is only used to update the solution.
199  //
200  // NB one needs alpha_n p_{-1}
201  // = (alpha_n/alpha_{n-1}) (alpha_{n-1} p_{n-1})
202  // = (alpha_n/alpha_{n-1}) z_{n-1}
203  //
204  // The Paper in line (12) of the algorithm leaves out this
205  // (alpha_n/alpha_{n-1}) factor.
206  //
207  // Also z update needs to be pulled before the v update (line 8) of paper
208  // otherwise a shadow copy of v_{n-1} needs to be kept,
209  //
210  // z = alphan r_n-1 + (beta*alphan/alpha_{n-1})*z_{n-1}
211  // - (beta*alphan*omegan_1)*v_{n-1}
212  ComplexD bar = beta*alphan/alphan_1; // 15 flops
213  ComplexD alphadelta = alphan*delta; // 6 flops
214 #if 0
215  tmp[sub] = bar*zn_1;
216  z[sub] = alphan*r+tmp;
217  z[sub] -= alphadelta*vn_1; // 22 Nc*Ns flops/site
218 
219 
220  // v = u_{n-1} + beta*v_{n-1} - beta*omegan_1*q_n_1
221  tmp[sub] = beta*vn_1; // 6Nc Ns
222  v[sub] = u + tmp; // 2Nc Ns
223  v[sub] -= delta*qn_1; // 8Nc Ns
224 #else
225  v[sub]=vn_1;
226  z[sub]=zn_1;
227 
228  ibicgstab_zvupdates(r,z,v,u,qn_1,alphan, bar, alphadelta, beta, delta, sub);
229 #endif
230 
231  // q = Av
232  AF(q,v,isign);
233 
234 #if 0
235  // t = u - alpha q
236  t[sub] = u - alphan * q; // 8 Nc Ns
237 
238  // s = r - alpha v
239  s[sub] = r - alphan*v; // 8 Nc Ns
240 
241 
242  // This should all be done with one sync point
243  // BIG ALLREDUCE
244 
245 
246  phin = innerProduct(r0,s,sub); // 8 Nc Ns flops/site
247  gamma = innerProduct(f0,s,sub); // 8 Nc Ns flops/site
248  pin = innerProduct(r0,q,sub); // 8 Nc Ns flops/site
249  eta = innerProduct(f0,t,sub); // 8 Nc Ns flops/site
250  theta = innerProduct(t,s,sub); // 8 Nc Ns flops/site
251  kappa = norm2(t,sub); // 4 Nc Ns flops/site
252  rnorm_prev = norm2(r,sub); // 4 Nc Ns flops/site
253 #else
255  r,
256  u,
257  v,
258  q,
259  r0,
260  f0,
261  s,
262  t,
263  phin,
264  pin,
265  gamma,
266  eta,
267  theta,
268  kappa,
269  rnorm_prev,
270  sub);
271 #endif
272 
273  // Collected flopcounts
274  // coefficient recurrences: flopcount.addFlops(86);
275  // z & v updates flopcount.addSiteFlops(38*Nc*Ns, sub);
276  // q = Av flopcount.addFlops(A.nFlops());
277  // s & t updates: flopcount.addSiteFlops(16*Nc*Ns, sub);
278  // 5 inner products, 2 norms: flopcount.addSiteFlops(48*Nc*Ns, sub);
279  flopcount.addFlops(A.nFlops() + 86);
280  flopcount.addSiteFlops(102*Nc*Ns, sub);
281 
282 
283  // Check Norm: See if we converged in the last iteration
284  // If so go to exit. This is yucky but makes the rest
285  // of the logic easier
286  if( toBool(rnorm_prev < rsd_sq ) ) {
287  // Yes we've converged
288  convP = true;
289 
290  // if updateX true, then we have just updated psi
291  // strictly x[sub] should be zero, so it should be OK to add it
292  // but why do the work if you don't need to
293  x_dble[sub] = x;
294  psi[sub]+=x_dble;
295  flopcount.addSiteFlops(2*Nc*Ns,sub);
296  ret.resid = sqrt(rnorm_prev);
297  ret.n_count = n;
298  goto exit;
299  }
300 
301 
302 #if 1
303  // Begin Reliable Updating ideas...
304  rNorm = sqrt(rnorm_prev);
305  if( toBool( rNorm > maxrx) ) maxrx = rNorm;
306  if( toBool( rNorm > maxrr) ) maxrr = rNorm;
307 
308  updateX = toBool ( rNorm < Delta*r0Norm && r0Norm <= maxrx );
309  updateR = toBool ( rNorm < Delta*maxrr && r0Norm <= maxrr ) || updateX;
310 
311  if( updateR ) {
312 
313  // Replace last residuum
314  x_dble[sub] = x;
315 
316  A(dtmp, x_dble, isign); // Use full solution so far
317 
318  // r_dble[sub] = b - tmp2;
319  // r_sq = norm2(r_dble,sub);
320  // r[s] = r_dble;
321  xymz_normx(r_dble, b,dtmp, rnorm_prev,sub);
322  r[sub]=r_dble;
323  flopcount.addFlops(A.nFlops());
324  flopcount.addSiteFlops(6*Nc*Ns,sub);
325 
326 
327  // Must also reset un_1
328  AF(u,r,isign);
329 
330  // Recomputing sigma_{n-1} = < r0, u_{n-1} >
331  // from the new u_{n-1} really helps stability
332  sigman_1 = innerProduct(r0,u,sub);
333  flopcount.addFlops(A.nFlops());
334  flopcount.addSiteFlops(8*Nc*Ns,sub);
335 
336 
337  rNorm = sqrt(rnorm_prev);
338  maxrr = rNorm;
339  rupdates++;
340 
341  if( updateX ) {
342  //QDPIO::cout << "Iter " << k << ": updating x " << std::endl;
343  if( ! updateR ) { x_dble[sub]=x; } // if updateR then this is done already
344  psi[sub] += x_dble; // Add on group accumulated solution in y
345  flopcount.addSiteFlops(2*Nc*Ns,sub);
346 
347  x[sub] = zero; // zero y
348  b[sub] = r_dble;
349  r0Norm = rNorm;
350  maxrx = rNorm;
351  xupdates++;
352  }
353 
354  }
355 
356 #endif
357 
358  if( ! updateR ) {
359 
360  // Carry on with this iteration
361  // Check kappa for breakdown
362  if( toBool(kappa == 0) ) {
363  QDPIO::cerr << "Breakdown || Ms || = || t || = 0 " << std::endl;
364  QDP_abort(1);
365  }
366 
367  // Regular BiCGStab omega_n = <t,s> / <t,t> = theta/kappa
368  omegan = theta/kappa; // 9 flops
369 
370 
371 
372 #if 0
373  // Update r, x
374  // r = s - omega t
375  r[sub] = s - omegan*t;
376 
377  // x = x + omega s + z
378  tmp[sub] = x + omegan*s;
379  x[sub] = tmp + z;
380 #else
381  ibicgstab_rxupdate(omegan,s,t,z,r,x,sub);
382 #endif
383 
384  // Recompute next u = A r
385  AF(u,r,isign);
386 
387  // sigma_n = <r0, A u_n> = gamma_n - omega_n * eta_n;
388  //
389  // NB: sigma_n is never explicitly used only sigma_{n-1}, sigma_{n-2}
390  // So if I stuck this in sigma_n, I'd just end up moving it to sigma_{n-1}
391  // So I'll just stick it straight into sigma_{n-1}
392  sigman_2 = sigman_1; // Preserve
393  sigman_1 = gamma - omegan*eta; // 8 flops
394 
395  // Update past values: Some of this could be saved I am sure
396  rhon_1 = rhon;
397  alphan_1 = alphan;
398  taun_1 = taun;
399  omegan_1 = omegan;
400  pin_1 = pin;
401  phin_1 = phin;
402 
403  vn_1[sub]=v;
404  qn_1[sub]=q;
405  zn_1[sub]=z;
406 
407  // Collected Flops
408  // Omega + Sigma Updates: flopcount.addFlops(17);
409  // r & x updates: flopcount.addSiteFlops(18*Nc*Ns, sub)
410  // u update flopcount.addFlops(A.NFlops)
411  flopcount.addFlops(A.nFlops()+17);
412  flopcount.addSiteFlops(18*Nc*Ns,sub);
413  }
414  }
415 
416 exit: swatch.stop();
417 
418  QDPIO::cout << "InvIBiCGStabReliable: n = " << ret.n_count << " r-updates: " << rupdates << " xr-updates: " << xupdates << std::endl;
419 
420  flopcount.report("reliable_invibicgstab", swatch.getTimeInSeconds());
421 
422  if ( ret.n_count == MaxBiCGStab ) {
423  QDPIO::cerr << "Nonconvergence of IBiCGStab. MaxIters reached " << std::endl;
424  }
425 
426 
428  return ret;
429 
430 }
431 
432 
433 
434 
435 
// Pure single: LatticeFermionF vectors throughout.
// Forwards to RelInvIBiCGStab_a with T = TF = LatticeFermionF, passing the
// same operator A for both the residual-recomputation operator and the
// iteration operator.
436 SystemSolverResults_t
438  const LatticeFermionF& chi,
439  LatticeFermionF& psi,
440  const Real& RsdBiCGStab,
441  const Real& Delta,
442  int MaxBiCGStab,
443  enum PlusMinus isign)
444 
445 {
446  return RelInvIBiCGStab_a<LatticeFermionF,LatticeFermionF, ComplexF>(A,A, chi, psi, RsdBiCGStab, Delta, MaxBiCGStab, isign);
447 }
448 
449  // Pure double: LatticeFermionD vectors throughout.
// Forwards to RelInvIBiCGStab_a with T = TF = LatticeFermionD, passing the
// same operator A for both the residual-recomputation operator and the
// iteration operator.
450 SystemSolverResults_t
452  const LatticeFermionD& chi,
453  LatticeFermionD& psi,
454  const Real& RsdBiCGStab,
455  const Real& Delta,
456  int MaxBiCGStab,
457  enum PlusMinus isign)
458 
459 {
460  return RelInvIBiCGStab_a<LatticeFermionD, LatticeFermionD, ComplexD>(A,A, chi, psi, RsdBiCGStab, Delta, MaxBiCGStab, isign);
461 }
462 
463  // Mixed single/double: chi and psi are LatticeFermionD, while the
// iteration vectors are LatticeFermionF.
// Forwards to RelInvIBiCGStab_a with T = LatticeFermionD, TF = LatticeFermionF,
// passing A and AF as two separate operators.
// NOTE(review): the declarations of A and AF are not visible here; presumably
// A acts on LatticeFermionD and AF on LatticeFermionF -- confirm.
464 SystemSolverResults_t
467  const LatticeFermionD& chi,
468  LatticeFermionD& psi,
469  const Real& RsdBiCGStab,
470  const Real& Delta,
471  int MaxBiCGStab,
472  enum PlusMinus isign)
473 
474 {
475  return RelInvIBiCGStab_a<LatticeFermionD, LatticeFermionF, ComplexF>(A,AF, chi, psi, RsdBiCGStab, Delta, MaxBiCGStab, isign);
476 }
477 
478 
479 #if 0
480 
481 #endif
482 
483 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
SystemSolverResults_t InvIBiCGStabReliable(const LinearOperator< LatticeFermionF > &A, const LatticeFermionF &chi, LatticeFermionF &psi, const Real &RsdBiCGStab, const Real &Delta, int MaxBiCGStab, enum PlusMinus isign)
Bi-CG stabilized.
unsigned n
Definition: ldumul_w.cc:36
int z
Definition: meslate.cc:36
int x
Definition: meslate.cc:34
int t
Definition: meslate.cc:37
Double q
Definition: mesq.cc:17
void xymz_normx(T &x, const T &y, const T &z, Double &x_norm, const Subset &s)
void ibicgstab_rxupdate(const C &omega, const T &s, const T &t, const T &z, T &r, T &x, const Subset &sub)
void ibicgstab_zvupdates(const T &r, T &z, T &v, const T &u, const T &q, const C &alpha, const C &alpha_rat_beta, const C &alpha_delta, const C &beta, const C &delta, const Subset &s)
void ibicgstab_stupdates_reduces(const C &alpha, const T &r, const T &u, const T &v, const T &q, const T &r0, const T &f0, T &s, T &t, C &phi, C &pi, C &gamma, C &eta, C &theta, F &kappa, F &rnorm, const Subset &sub)
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & beta(const int dim)
Definition: stag_phases_s.h:47
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
static multi1d< LatticeColorMatrix > u
LatticeFermion tmp
Definition: mespbg5p_w.cc:36
LinOpSysSolverMGProtoClover::T T
Real rsd_sq
Definition: invbicg.cc:121
@ MINUS
Definition: chromabase.h:45
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
LatticeFermion psi
Definition: mespbg5p_w.cc:35
LatticeFermion eta
Definition: mespbg5p_w.cc:37
A(A, psi, r, Ncb, PLUS)
Complex b
Definition: invbicg.cc:96
SystemSolverResults_t RelInvIBiCGStab_a(const LinearOperator< T > &A, const LinearOperator< TF > &AF, const T &chi, T &psi, const Real &RsdBiCGStab, const Real &Delta, int MaxBiCGStab, enum PlusMinus isign)
Double zero
Definition: invbicg.cc:106
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Definition: gtest.h:7351
int kappa
Definition: pade_trln_w.cc:112
int r0
Definition: qtopcor.cc:41
BiCGStab Solver with reliable updates.
Holds return info from SystemSolver call.
Definition: syssolver.h:17
LatticeFermionF TF
Definition: t_quda_tprec.cc:17