CHROMA
inv_rel_cg2.cc
Go to the documentation of this file.
1 /*! \file
2  * \brief Relaxed Conjugate-Gradient (CGNE) algorithm for a generic Linear Operator
3  */
4 
5 #include "chromabase.h"
7 
8 namespace Chroma {
9 
10 //! Conjugate-Gradient (CGNE) algorithm for a generic Linear Operator
11 /*! \ingroup invert
12  * This subroutine uses the Conjugate Gradient (CG) algorithm to find
13  * the solution of the set of linear equations
14  *
15  * Chi = A . Psi
16  *
17  * where A = M^dag . M
18  *
19  * Algorithm:
20 
21  * Psi[0] := initial guess; Linear interpolation (argument)
22  * r[0] := Chi - M^dag . M . Psi[0] ; Initial residual
23  * p[1] := r[0] ; Initial direction
24  * IF |r[0]| <= RsdCG |Chi| THEN RETURN; Converged?
25  * FOR k FROM 1 TO MaxCG DO CG iterations
26  * a[k] := |r[k-1]|**2 / <Mp[k],Mp[k]> ;
27  * Psi[k] += a[k] p[k] ; New solution std::vector
28  * r[k] -= a[k] M^dag . M . p[k] ; New residual
29  * IF |r[k]| <= RsdCG |Chi| THEN RETURN; Converged?
30  * b[k+1] := |r[k]|**2 / |r[k-1]|**2 ;
31  * p[k+1] := r[k] + b[k+1] p[k]; New direction
32  *
33  * Arguments:
34  *
35  * \param M Linear Operator (Read)
36  * \param chi Source (Read)
37  * \param psi Solution (Modify)
38  * \param RsdCG CG residual accuracy (Read)
39  * \param MaxCG Maximum CG iterations (Read)
40  * \param n_count Number of CG iteration (Write)
41  *
42  * Local Variables:
43  *
44  * p Direction std::vector
45  * r Residual std::vector
46  * cp | r[k] |**2
47  * c | r[k-1] |**2
48  * k CG iteration counter
49  * a a[k]
50  * b b[k+1]
51  * d < p[k], A.p[k] >
52  * Mp Temporary for M.p
53  *
54  * Subroutines:
55  * +
56  * A Apply matrix M or M to std::vector
57  *
58  * Operations:
59  *
60  * 2 A + 2 Nc Ns + N_Count ( 2 A + 10 Nc Ns )
61  */
62 
63 template<typename T>
65  const T& chi,
66  T& psi,
67  const Real& RsdCG,
68  int MaxCG,
69  int& n_count)
70 {
71  const Subset& s = M.subset();
72 
73  // Real rsd_sq = (RsdCG * RsdCG) * Real(norm2(chi,s));
74  Real chi_sq = Real(norm2(chi,s));
75 
76  QDPIO::cout << "chi_norm = " << sqrt(chi_sq) << std::endl;
77  Real rsd_sq = (RsdCG * RsdCG) * chi_sq;
78 
79  // +
80  // r[0] := Chi - A . Psi[0] where A = M . M
81 
82  // +
83  // r := [ Chi - M(u) . M(u) . psi ]
84  T r, mp, mmp;
85  psi[s] = zero;
86 
87  r[s] = chi;
88  // p[1] := r[0]
89  T p;
90  p = zero;
91  p[s] = r;
92 
93  // Cp = |r[0]|^2
94  Double c = norm2(r, s); /* 2 Nc Ns flops */
95  Double cp = c;
96 
97  Double zeta = Double(1)/c;
98 
99  QDPIO::cout << "InvRelCG2: k = 0 c = " << cp << " rsd_sq = " << rsd_sq << std::endl;
100 
101  // IF |r[0]| <= RsdCG |Chi| THEN RETURN;
102  if ( toBool(c <= rsd_sq) )
103  {
104  n_count = 0;
105  return;
106  }
107 
108  //
109  // FOR k FROM 1 TO MaxCG DO
110  //
111  Real a;
112  Double d;
113 
114  for(int k = 1; k <= MaxCG; ++k)
115  {
116  // Inner tolerance = epsilon || chi || || p || sqrt(zeta) / 2
117  //
118  // The || p || part is taken care of the fact that we are using
119  // relative residua in the inner solve. The factor of 2 is because
120  // we apply the operator twice...
121  Real inner_tol = sqrt(rsd_sq)*sqrt(zeta)/Real(2);
122 
123  // Compute M^{dag} M p
124  M(mp, p, PLUS, inner_tol);
125  M(mmp, mp, MINUS, inner_tol);
126 
127  // d = < M^{dag} M p, p>
128  d = innerProductReal(mmp, p, s); /* 2 Nc Ns flops */
129 
130  // a[k] := | r |**2 / < p[k], Ap[k] > ;
131  a = Real(c)/Real(d);
132 
133  // Psi[k] += a[k] p[k]
134  psi[s] += a * p; /* 2 Nc Ns flops */
135 
136  // r[k] -= a[k] M^{dag}M. p[k] ;
137  r[s] -= a * mmp;
138 
139  // IF |r[k]| <= RsdCG |Chi| THEN RETURN;
140 
141  // cp = | r[k] |**2
142  c = norm2(r, s); /* 2 Nc Ns flops */
143 
144  // update relaxation factor
145  zeta += Double(1)/c;
146 
147  // Update p as:
148  // p[k+1] := r[k] + c/cp * p
149  //
150  // we put c/cp into b
151  //
152  // b[k+1] := |r[k]|**2 / |r[k-1]|**2
153  Real b = Real(c) / Real(cp);
154 
155  // p[k+1] := r[k] + b[k+1] p[k]
156  p[s] = r + b*p; /* Nc Ns flops */
157 
158  cp = c;
159 
160  QDPIO::cout << "InvCG: k = " << k << " cp = " << cp << std::endl;
161 
162  if ( toBool(cp <= rsd_sq) )
163  {
164  n_count = k;
165  return;
166  }
167  }
168  n_count = MaxCG;
169  QDPIO::cerr << "Nonconvergence Warning n_count = " << n_count << std::endl;
170 }
171 
172 
173 // Fix here for now
174 template<>
176  const LatticeFermion& chi,
177  LatticeFermion& psi,
178  const Real& RsdCG,
179  int MaxCG,
180  int& n_count)
181 {
183 }
184 
185 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
virtual const Subset & subset() const =0
Return the subset on which the operator acts.
void InvRelCG2_a(const LinearOperator< T > &M, const T &chi, T &psi, const Real &RsdCG, int MaxCG, int &n_count)
Conjugate-Gradient (CGNE) algorithm for a generic Linear Operator.
Definition: inv_rel_cg2.cc:64
void InvRelCG2(const LinearOperator< LatticeFermion > &M, const LatticeFermion &chi, LatticeFermion &psi, const Real &RsdCG, int MaxCG, int &n_count)
Relaxed Conjugate-Gradient (CGNE) algorithm for a generic Linear Operator.
Definition: inv_rel_cg2.cc:175
Conjugate-Gradient algorithm for a generic Linear Operator.
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > const multi1d< Real > enum InvType invType const multi1d< Real > & RsdCG
Definition: pbg5p_w.cc:30
int n_count
Definition: invbicg.cc:78
Double c
Definition: invbicg.cc:108
LinOpSysSolverMGProtoClover::T T
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > const multi1d< Real > enum InvType invType const multi1d< Real > int MaxCG
Definition: pbg5p_w.cc:32
Real rsd_sq
Definition: invbicg.cc:121
@ MINUS
Definition: chromabase.h:45
@ PLUS
Definition: chromabase.h:45
Double cp
Definition: invbicg.cc:107
multi1d< LatticeFermion > chi(Ncb)
Complex a
Definition: invbicg.cc:95
LatticeFermion psi
Definition: mespbg5p_w.cc:35
DComplex d
Definition: invbicg.cc:99
multi1d< LatticeFermion > mp(Ncb)
Complex b
Definition: invbicg.cc:96
Double zero
Definition: invbicg.cc:106
int k
Definition: invbicg.cc:119
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Definition: gtest.h:7351