CHROMA
reliable_cg.cc
Go to the documentation of this file.
1 /*! \file
2  * \brief Conjugate-Gradient algorithm for a generic Linear Operator
3  */
4 
5 #include "chromabase.h"
7 
8 namespace Chroma {
9 
10  template<typename T, typename TF, typename RF>
11 SystemSolverResults_t
13  const LinearOperator<TF>& AF,
14  const T& chi,
15  T& psi,
16  const Real& RsdCG,
17  const Real& Delta,
18  int MaxCG)
19  {
20  START_CODE();
22 
23  const Subset& s = A.subset();
24 
25  bool convP = false;
26 
27  // First get r = r0 = chi - A psi
28  TF r;
29  T b;
30  T r_dble;
31  T x_dble;
32  int k;
33 
34  StopWatch swatch;
35  FlopCounter flopcount;
36  flopcount.reset();
37  swatch.reset();
38  swatch.start();
39 
40 
41 
42  b[s] = chi;
43  Double chi_norm = norm2(chi,s);
45 
46  {
47  T tmp1, tmp2;
48  A(tmp1, psi, PLUS);
49  A(tmp2, tmp1, MINUS);
50  b[s] -= tmp2;
51  flopcount.addFlops(2*A.nFlops());
52  flopcount.addSiteFlops(2*Nc*Ns,s);
53  }
54 
55  TF x; x[s]=zero;
56 
57  // now work out r= chi - Apsi = chi - r0
58  r[s] = b;
59 
60  Double r_sq = norm2(r,s);
61  flopcount.addSiteFlops(4*Nc*Ns,s);
62 
63 
64  QDPIO::cout << "Reliable CG: || r0 ||/|| b ||=" << sqrt(r_sq/chi_norm) << std::endl;
65 
66 
67  Double rNorm = sqrt(r_sq);
68  Double r0Norm = rNorm;
69  Double maxrx = rNorm;
70  Double maxrr = rNorm;
71  bool updateR = false;
72  bool updateX = false;
73 
74 
75  // Now initialise v = p = 0
76  TF p;
77  Double a, c, d;
78 
79  // The iterations
80  for(k = 0; k < MaxCG && !convP; k++) {
81  if( k == 0 ) {
82  p[s] = r;
83  }
84  else {
85  Double beta = r_sq / c;
86  RF br = beta;
87  p[s] = r + br*p; flopcount.addSiteFlops(4*Nc*Ns,s);
88  }
89 
90  c = r_sq;
91 
92  TF mmp,mp;
93  AF(mp, p, PLUS);
94  d = norm2(mp,s);
95  AF(mmp,mp,MINUS);
96 
97  a = c/d;
98  RF ar = a;
99  x[s] += ar*p;
100  r[s] -= ar*mmp;
101 
102  r_sq = norm2(r,s);
103 
104  // flopcount.addSiteFlops(4*Nc*Ns,s); <mp, mp>
105  // flopcount.addSiteFlops(4*Nc*Ns,s); x += a * p
106  // flopcount.addSiteFlops(4*Nc*Ns,s); r -= a * mm
107  // flopcount.addSiteFlops(4*Nc*Ns,s); norm2(r)
108  flopcount.addSiteFlops(16*Nc*Ns,s);
109  flopcount.addFlops(2*A.nFlops());
110 
111  // Reliable update part...
112  rNorm = sqrt(r_sq);
113  if( toBool( rNorm > maxrx) ) maxrx = rNorm;
114  if( toBool( rNorm > maxrr) ) maxrr = rNorm;
115 
116  updateX = toBool ( rNorm < Delta*r0Norm && r0Norm <= maxrx );
117  updateR = toBool ( rNorm < Delta*maxrr && r0Norm <= maxrr ) || updateX;
118 
119  // Do the R update with real DP residual
120  if( updateR ) {
121 
122  {
123  T tmp1,tmp2;
124  x_dble[s] = x;
125 
126  A(tmp1, x_dble, PLUS); // Use full solution so far
127  A(tmp2, tmp1, MINUS); // Use full solution so far
128 
129  r_dble[s] = b - tmp2;
130  }
131 
132  r[s] = r_dble; // new R = b - Ax
133  r_sq = norm2(r_dble,s);
134 
135  flopcount.addSiteFlops(6*Nc*Ns,s); // 4 from norm2, 2 from r=b-tmp2
136  flopcount.addFlops(2*A.nFlops());
137 
138  rNorm = sqrt(r_sq);
139  maxrr = rNorm;
140 
141  // Group wise x update
142  if( updateX ) {
143  if( ! updateR ) { x_dble[s]=x; } // if updateR then this is done already
144  psi[s] += x_dble; // Add on group accumulated solution in y
145  flopcount.addSiteFlops(2*Nc*Ns,s);
146 
147  x[s] = zero; // zero y
148  b[s] = r_dble;
149  r0Norm = rNorm;
150  maxrx = rNorm;
151  }
152 
153  }
154 
155  // Convergence check
156  if( toBool(r_sq < rsd_sq ) ) {
157  // We've converged.
158 
159  // if updateX true, then we have just updated psi
160  // strictly x[s] should be zero, so it should be OK to add it
161  // but why do the work if you don't need to
162  x_dble[s] = x;
163  psi[s]+=x_dble;
164  flopcount.addSiteFlops(2*Nc*Ns,s);
165  ret.resid = rNorm;
166  ret.n_count = k;
167  convP = true;
168  }
169  else {
170  convP = false;
171  }
172 
173  }
174 
175  // Loop is finished. Report FLOP Count...
176  swatch.stop();
177  flopcount.report("reliable_invcg2", swatch.getTimeInSeconds());
178 
179  // Check for nonconvergence
180  if( k >= MaxCG ) {
181  QDPIO::cout << "Nonconvergence: Reliable CG Failed to converge in " << MaxCG << " iterations " << std::endl;
182  QDP_abort(1);
183  }
184 
185  // Done
186  END_CODE();
187  return ret;
188  }
189 
190 
191 
192 SystemSolverResults_t
194  const LatticeFermionF& chi,
195  LatticeFermionF& psi,
196  const Real& RsdCG,
197  const Real& Delta,
198  int MaxCG)
199 
200 {
201  return RelInvCG_a<LatticeFermionF,LatticeFermionF, RealF>(A,A, chi, psi, RsdCG, Delta, MaxCG);
202 }
203 
204  // Pure double
205 SystemSolverResults_t
207  const LatticeFermionD& chi,
208  LatticeFermionD& psi,
209  const Real& RsdCG,
210  const Real& Delta,
211  int MaxCG)
212 {
213  return RelInvCG_a<LatticeFermionD, LatticeFermionD, RealD>(A,A, chi, psi, RsdCG, Delta, MaxCG);
214 }
215 
216  // Mixed precision: double-precision source/solution (operator A), single-precision inner iteration (operator AF)
217 SystemSolverResults_t
220  const LatticeFermionD& chi,
221  LatticeFermionD& psi,
222  const Real& RsdCG,
223  const Real& Delta,
224  int MaxCG)
225 {
226  return RelInvCG_a<LatticeFermionD, LatticeFermionF, RealF>(A,AF, chi, psi, RsdCG, Delta, MaxCG);
227 }
228 
229 
230 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
SystemSolverResults_t InvCGReliable(const LinearOperator< LatticeFermionF > &A, const LatticeFermionF &chi, LatticeFermionF &psi, const Real &RsdCG, const Real &Delta, int MaxCG)
Conjugate gradient with reliable updates.
Definition: reliable_cg.cc:193
int x
Definition: meslate.cc:34
Double tmp2
Definition: mesq.cc:30
static const LatticeInteger & beta(const int dim)
Definition: stag_phases_s.h:47
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > const multi1d< Real > enum InvType invType const multi1d< Real > & RsdCG
Definition: pbg5p_w.cc:30
Double c
Definition: invbicg.cc:108
LinOpSysSolverMGProtoClover::T T
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > const multi1d< Real > enum InvType invType const multi1d< Real > int MaxCG
Definition: pbg5p_w.cc:32
Real rsd_sq
Definition: invbicg.cc:121
@ MINUS
Definition: chromabase.h:45
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
Complex a
Definition: invbicg.cc:95
LatticeFermion psi
Definition: mespbg5p_w.cc:35
SystemSolverResults_t RelInvCG_a(const LinearOperator< T > &A, const LinearOperator< TF > &AF, const T &chi, T &psi, const Real &RsdCG, const Real &Delta, int MaxCG)
Definition: reliable_cg.cc:12
DComplex d
Definition: invbicg.cc:99
START_CODE()
A(A, psi, r, Ncb, PLUS)
multi1d< LatticeFermion > mp(Ncb)
Complex b
Definition: invbicg.cc:96
Double chi_norm
Definition: invbicg.cc:79
Double zero
Definition: invbicg.cc:106
int k
Definition: invbicg.cc:119
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Definition: gtest.h:7351
BiCGStab Solver with reliable updates.
Holds return info from SystemSolver call.
Definition: syssolver.h:17
LatticeFermionF TF
Definition: t_quda_tprec.cc:17