CHROMA
invbicrstab.cc
Go to the documentation of this file.
1 /*! \file
2  * \brief BiCRStab algorithm for a generic Linear Operator
3  */
4 
5 #include "chromabase.h"
7 
8 namespace Chroma {
9 
10  template<typename T, typename CR>
11 SystemSolverResults_t
13  const T& chi,
14  T& psi,
15  const Real& RsdBiCGStab,
16  int MaxBiCGStab,
17  enum PlusMinus isign)
18 
19 {
21  StopWatch swatch;
22  FlopCounter flopcount;
23  flopcount.reset();
24  const Subset& s = A.subset();
25  bool convP = false;
26 
27  swatch.reset();
28  swatch.start();
29 
30  Double chi_sq = norm2(chi,s);
31  flopcount.addSiteFlops(4*Nc*Ns,s);
32 
33 
34  Double rsd_sq = RsdBiCGStab*RsdBiCGStab*chi_sq;
35 
36  // First get r = r0 = chi - A psi
37  T r;
38  T r0;
39 
40  // Get A psi, use r0 as a temporary
41  A(r0, psi, isign);
42  flopcount.addFlops(A.nFlops());
43 
44  // now work out r= chi - Apsi = chi - r0
45  r[s] = chi - r0;
46  flopcount.addSiteFlops(2*Nc*Ns,s);
47 
48 
49  // The main difference between BICGStab and BiCRStab
50  // The shadow residual r0* -> A^\dagger r_0*
51 #if 1
52  if( isign == PLUS ) {
53  A(r0,r, MINUS);
54  }
55  else {
56  A(r0,r, PLUS);
57  }
58 #else
59  A(r0,r,isign);
60 #endif
61 
62  // Everything else stays the same
63 
64  // Now initialise v = p = 0
65  T p;
66  T v;
67 
68  p[s] = zero;
69  v[s] = zero;
70 
71  T tmp;
72  T t;
73 
74  ComplexD rho, rho_prev, alpha, omega;
75 
76  // rho_0 := alpha := omega = 1
77  // Iterations start at k=1, so rho_0 is in rho_prev
78  rho_prev = Double(1);
79  alpha = Double(1);
80  omega = Double(1);
81 
82  // The iterations
83  for(int k = 1; k <= MaxBiCGStab && !convP ; k++) {
84 
85  // rho_{k+1} = < r_0 | r >
86  rho = innerProduct(r0,r,s);
87 
88 
89  if( toBool( real(rho) == 0 ) && toBool( imag(rho) == 0 ) ) {
90  QDPIO::cout << "BiCGStab breakdown: rho = 0" << std::endl;
91  QDP_abort(1);
92  }
93 
94  // beta = ( rho_{k+1}/rho_{k})(alpha/omega)
95  ComplexD beta;
96  beta = ( rho / rho_prev ) * (alpha/omega);
97 
98  // p = r + beta(p - omega v)
99 
100  // first work out p - omega v
101  // into tmp
102  // then do p = r + beta tmp
103  CR omega_r = omega;
104  CR beta_r = beta;
105  tmp[s] = p - omega_r*v;
106  p[s] = r + beta_r*tmp;
107 
108 
109  // v = Ap
110  A(v,p,isign);
111 
112 
113  // alpha = rho_{k+1} / < r_0 | v >
114  // put <r_0 | v > into tmp
115  DComplex ctmp = innerProduct(r0,v,s);
116 
117 
118  if( toBool( real(ctmp) == 0 ) && toBool( imag(ctmp) == 0 ) ) {
119  QDPIO::cout << "BiCGStab breakdown: <r_0|v> = 0" << std::endl;
120  QDP_abort(1);
121  }
122 
123  alpha = rho / ctmp;
124 
125  // Done with rho now, so save it into rho_prev
126  rho_prev = rho;
127 
128  // s = r - alpha v
129  // I can overlap s with r, because I recompute it at the end.
130  CR alpha_r = alpha;
131  r[s] -= alpha_r*v;
132 
133 
134  // t = As = Ar
135  A(t,r,isign);
136  // omega = < t | s > / < t | t > = < t | r > / norm2(t);
137 
138  // This does the full 5D norm
139  Double t_norm = norm2(t,s);
140 
141 
142  if( toBool(t_norm == 0) ) {
143  QDPIO::cerr << "Breakdown || Ms || = || t || = 0 " << std::endl;
144  QDP_abort(1);
145  }
146 
147  // accumulate <t | s > = <t | r> into omega
148  omega = innerProduct(t,r,s);
149  omega /= t_norm;
150 
151  // psi = psi + omega s + alpha p
152  // = psi + omega r + alpha p
153  //
154  // use tmp to compute psi + omega r
155  // then add in the alpha p
156  omega_r = omega;
157  alpha_r = alpha;
158  tmp[s] = psi + omega_r*r;
159  psi[s] = tmp + alpha_r*p;
160 
161 
162 
163  // r = s - omega t = r - omega t1G
164 
165 
166  r[s] -= omega_r*t;
167 
168 
169  Double r_norm = norm2(r,s);
170 
171 
172  // QDPIO::cout << "Iteration " << k << " : r = " << r_norm << std::endl;
173  if( toBool(r_norm < rsd_sq ) ) {
174  convP = true;
175  ret.resid = sqrt(r_norm);
176  ret.n_count = k;
177 
178  }
179  else {
180  convP = false;
181  }
182 
183  //-------BiCGStab Flopcounting --------------------------------------
184  // flopcount.addSiteFlops(8*Nc*Ns,s); // <r0|r>
185  // flopcount.addSiteFlops(16*Nc*Ns,s); // p = r + beta p - beta_omega v
186  // flopcount.addSiteFlops(8*Nc*Ns,s); // <r0 | v>
187  // flopcount.addSiteFlops(8*Nc*Ns,s); // r -= alpha v
188  // flopcount.addSiteFlops(8*Nc*Ns, s); // < t, r>
189  // flopcount.addSiteFlops(4*Nc*Ns, s); // < t, t>
190  // flopcount.addSiteFlops(16*Nc*Ns,s); // psi += omega r + alpha_p
191  // flopcount.addSiteFlops(8*Nc*Ns,s); // r -=omega t
192  // flopcount.addSiteFlops(4*Nc*Ns,s); // norm2(r)
193  // flopcount.addFlops(2*A.nFlops()); // = 80*Nc*Ns cbsite flops + 2*A
194  //----------------------------------------------------------------------
195  flopcount.addSiteFlops(80*Nc*Ns,s);
196  flopcount.addFlops(2*A.nFlops());
197 
198 
199  }
200 
201  swatch.stop();
202 
203  QDPIO::cout << "InvBiCRStab: k = " << ret.n_count << " resid = " << ret.resid << std::endl;
204  flopcount.report("invbicrstab", swatch.getTimeInSeconds());
205 
206  if ( ret.n_count == MaxBiCGStab ) {
207  QDPIO::cerr << "Nonconvergence of BiCGStab. MaxIters reached " << std::endl;
208  }
209 
210  return ret;
211 }
212 
#if 0
// Generic LatticeFermion entry point -- compiled out.  The explicit
// LatticeFermionF / LatticeFermionD specializations in this file are the
// ones that get built.  Kept for reference; names refer to the BiCRStab
// routines of this file.
template<>
SystemSolverResults_t
InvBiCRStab(const LinearOperator<LatticeFermion>& A,
            const LatticeFermion& chi,
            LatticeFermion& psi,
            const Real& RsdBiCGStab,
            int MaxBiCGStab,
            enum PlusMinus isign)
{
  return InvBiCRStab_a<LatticeFermion, Complex>(A, chi, psi, RsdBiCGStab, MaxBiCGStab, isign);
}
#endif
228 
229 template<>
230 SystemSolverResults_t
232  const LatticeFermionF& chi,
233  LatticeFermionF& psi,
234  const Real& RsdBiCGStab,
235  int MaxBiCGStab,
236  enum PlusMinus isign)
237 
238 {
239  return InvBiCRStab_a<LatticeFermionF, ComplexF>(A, chi, psi, RsdBiCGStab, MaxBiCGStab, isign);
240 }
241 
242 template<>
243 SystemSolverResults_t
245  const LatticeFermionD& chi,
246  LatticeFermionD& psi,
247  const Real& RsdBiCGStab,
248  int MaxBiCGStab,
249  enum PlusMinus isign)
250 
251 {
252  return InvBiCRStab_a<LatticeFermionD, ComplexD>(A, chi, psi, RsdBiCGStab, MaxBiCGStab, isign);
253 }
254 
255 // Staggered
256 template<>
257 SystemSolverResults_t
259  const LatticeStaggeredFermion& chi,
260  LatticeStaggeredFermion& psi,
261  const Real& RsdBiCGStab,
262  int MaxBiCGStab,
263  enum PlusMinus isign)
264 
265 {
266  return InvBiCRStab_a<LatticeStaggeredFermion, Complex>(A, chi, psi, RsdBiCGStab, MaxBiCGStab, isign);
267 }
268 
269 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
Conjugate-Gradient algorithm for a generic Linear Operator.
int t
Definition: meslate.cc:37
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & beta(const int dim)
Definition: stag_phases_s.h:47
static const LatticeInteger & alpha(const int dim)
Definition: stag_phases_s.h:43
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
LatticeFermion tmp
Definition: mespbg5p_w.cc:36
LinOpSysSolverMGProtoClover::T T
SystemSolverResults_t InvBiCRStab_a(const LinearOperator< T > &A, const T &chi, T &psi, const Real &RsdBiCGStab, int MaxBiCGStab, enum PlusMinus isign)
Definition: invbicrstab.cc:12
SystemSolverResults_t InvBiCGStab(const LinearOperator< LatticeFermionF > &A, const LatticeFermionF &chi, LatticeFermionF &psi, const Real &RsdBiCGStab, int MaxBiCGStab, enum PlusMinus isign)
Definition: invbicgstab.cc:222
Real rsd_sq
Definition: invbicg.cc:121
SystemSolverResults_t InvBiCRStab(const LinearOperator< LatticeFermionF > &A, const LatticeFermionF &chi, LatticeFermionF &psi, const Real &RsdBiCGStab, int MaxBiCGStab, enum PlusMinus isign)
Definition: invbicrstab.cc:231
@ MINUS
Definition: chromabase.h:45
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
Complex omega
Definition: invbicg.cc:97
LatticeFermion psi
Definition: mespbg5p_w.cc:35
A(A, psi, r, Ncb, PLUS)
Double zero
Definition: invbicg.cc:106
int k
Definition: invbicg.cc:119
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Definition: gtest.h:7351
Double r_norm
Definition: pade_trln_w.cc:86
int r0
Definition: qtopcor.cc:41
Holds return info from SystemSolver call.
Definition: syssolver.h:17