CHROMA
invsumr.cc
Go to the documentation of this file.
1 #include "chromabase.h"
3 
4 namespace Chroma {
5 
6 // Solve a shifted unitary system
7 //
8 // A x = b
9 //
10 // where A is of the form: A = zeta I + rho U
11 //
12 // rho > 0 is real, zeta is complex, and U is unitary
13 //
14 // We solve with the method described in:
15 //
16 // "A Fast Minimal Residual Algorithm for Shifted Unitary Matrices"
17 // by Carl F. Jagels, and Lothar Reichel
18 // Numerical Linear Algebra with Applications, Vol 1(6), 555-570(1994)
19 //
20 // This paper is referenced by and applied to the Overlap Dirac Operator
21 // by G. Arnold, N. Cundy, J. van den Eshof, A. Frommer, S. Krieg, T. Lippert,
22 // K. Schaefer "Numerical Methods for the QCD Overlap Operator: II.
23 // Optimal Krylov Subspace Methods" -- hep-lat/0311025
24 // which is where the name SUMR was coined.
25 //
26 
27 
28 template<typename T>
30  const T& b,
31  T& x,
32  const Complex& zeta,
33  const Real& rho,
34  const Real& epsilon,
35  int MaxSUMR,
36  int& n_count)
37 {
38  START_CODE();
39 
40  /********************************************************************/
41  /* Initialisation */
42  /********************************************************************/
43 
44  // First things first. We need r_0 = b - A x
45  //
46  // b - A x = b - (zeta I + rho U)x = b - zeta x - rho U x
47  //
48  LatticeFermion r = b - zeta*x;
49  LatticeFermion u;
50  U(u, x, PLUS);
51  r -= rho*u;
52 
53  // delta_0 = || r ||
54  Real delta = sqrt(norm2(r));
55 
56  // If already converged then exit now.
57  if( toBool( delta < epsilon*sqrt(norm2(b)) ) )
58  {
59  END_CODE();
60  return;
61  }
62 
63  // phi_hat_1 = 1 / delta_0
64  Complex phihat = Real(1)/delta;
65 
66  // tau_hat = delta_0 / rho
67  Complex tauhat = delta/rho;
68 
69  LatticeFermion w_old, p, v_old;
70  // w_-1 = p_-1 = v_0 = 0
71  //
72  // Iteration starts at m=1, so w_-1, p_-1 are in fact
73  // w_{m-2}, p_{m-2}, v_{m-1}
74  w_old=zero;
75  p=w_old;
76  v_old = w_old;
77 
78  // phi_0 = s_0 = lambda_0 := r_{-1,0} = 0
79  Complex phi = Real(0);
80  Real s = Real(0);
81  Complex lambda = Real(0);
82  Complex r0 = Real(0);
83 
84  // r_{0,0} := gamma_0 := sigma_0 := c_0 := 1;
85  Complex r1_old = Real(1);
86  Complex gamma = Real(1);
87  Real sigma = Real(1);
88  Complex c = Real(1);
89 
90  // v_1 := \tilde{v}_1 := r_0 / delta_ 0
91  LatticeFermion v, vtilde;
92  Real ftmp = Real(1)/delta;
93  v = ftmp*r;
94  vtilde = ftmp*r;
95 
96  /***********************************************************/
97  /* Start the iteration */
98  /***********************************************************/
99  bool convP = false;
100  for(int iter = 1; iter <= MaxSUMR && !convP; iter++) {
101 
102  // u := U v_m
103  U(u, v, PLUS);
104 
105  // gamma_m = - < vtilde, u >
106  gamma = - innerProduct(vtilde, u);
107 
108  // sigma = ( ( 1 + | gamma | )(1 - | gamma | ) )^{1/2}
109  //
110  // NB I multiplied out the inner brackets
111  // sigma = ( 1 - | gamma |^2 )^{1/2}
112  sigma = sqrt( Real(1) - real( conj(gamma)*gamma ) );
113 
114  // alpha_m := - gamma_m delta_{m-1}
115  Complex alpha = - gamma * delta;
116 
117  // r_{m-1, m}:= alpha_m * phi_{m-1} + s_{m-1}*zeta/rho;
118  r0 = alpha * phi + s*zeta/rho;
119 
120  // rhat_{m,m} := alpha_m phihat_m + conj(c_{m-1})*zeta/rho;
121  Complex r1hat = alpha*phihat+conj(c)*zeta/rho;
122 
123 
124  // conj(c_m) := rhat_{m,m}/( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
125  // s_m := -sigma_m / ( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
126  //
127  // Note I precompute the denominator into tmp_length.
128  // and I conjugate the expression for conj(c) to get c directly
129  // This trick is from the Wuppertal Group's MATLAB implementation
130  Real abs_rhat_sq = real(conj(r1hat)*r1hat);
131  Real abs_sigma_sq = sigma*sigma;
132  Real tmp_length = sqrt(abs_rhat_sq + abs_sigma_sq);
133 
134  c = conj(r1hat)/tmp_length;
135  s = - sigma / tmp_length;
136 
137  // r_{m,m} := -c_m*rhat_{m,m} + s_m*sigma_m;
138  Complex r1 = -c*r1hat + cmplx(s*sigma,0);
139 
140  // tau_m := -c_m*tauhat_m
141  Complex tau = -c*tauhat;
142 
143  // tauhat_{m+1} := s_m tauhat_m
144  tauhat = s * tauhat;
145 
146  // eta_m := tau_m / r_{m,m}
147  Complex eta = tau / r1;
148 
149  // kappa_{m-1} := r_{m-1,m}/r_{m-1, m-1}
150  Complex kappa = r0/r1_old;
151 
152  // keep r1
153  r1_old=r1;
154 
155  // w_{m-1} := alpha_m * p_{m-2} - ( w_{m-2} - v_{m-1} ) kappa_{m-1}
156  // p_{m-1} := p_{m-2} - ( w_{m-2} - v_{m-1} )*lambda_{m-1}
157  //
158  // NB: I precompute v_{m-2} - v_{m-1} into w_minus_v
159  LatticeFermion w_minus_v;
160  w_minus_v = w_old - v_old;
161 
162  LatticeFermion w = alpha*p - kappa*w_minus_v;
163  p = p - lambda*w_minus_v;
164 
165  // x_{m-1} := x_{m-1} - ( w_{m-1} - v_m ) eta
166  //
167  // NB I compute w_{m-1} - v_m into w_minus_v
168 
169  w_minus_v = w-v;
170  x = x - eta*w_minus_v;
171 
172  // I need to keep w as w_old for the next iteration
173  w_old = w;
174 
175  // At this point the paper writes:
176  //
177  // if (sigma_m = 0) then we have converged
178  // so here I just go on and set the convergence flag in the else clause
179  //
180  if( toBool( sqrt(real(conj(sigma)*sigma)) > epsilon) ) {
181 
182  // delta_m := delta_{m-1} sigma_m
183  delta = delta*sigma;
184 
185  // phi_m := -c_m phihat_m + s_m conj(gamma_m)/delta_m
186  phi = -c*phihat + s*conj(gamma)/delta;
187 
188  // lambda_m := phi_m / r_{m,m}
189  lambda = phi/r1;
190 
191  // phihat_{m+1} := s_m phihat_{m} + conj(c_m)*conj(gamma_m)/delta_m
192  phihat = s* phihat + conj(c) * conj(gamma)/delta;
193 
194  // preserve v as v_old
195  v_old = v;
196 
197  // v_{m+1} := (1/sigma)( u + gamma_m vtilde_m );
198  v = u + gamma*vtilde;
199  Real ftmp2 = Real(1)/sigma;
200  v *= ftmp2;
201 
202  // vtilde_{m+1} := sigma_m vtilde_m + conj(gamma_m)*v_{m+1}
203  Complex gconj = conj(gamma);
204  Complex csigma=cmplx(sigma,0);
205  vtilde = csigma*vtilde + gconj*v;
206 
207  // Normalise vtilde -- I found this in the Wuppertal MATLAB code
208  // It is not prescribed by the Reichel/Jagels paper
209  Real ftmp3 = Real(1)/sqrt(norm2(vtilde));
210  vtilde *= ftmp3;
211 
212  // Check whether we have converged or not:
213  //
214  // converged if | tauhat | < epsilon
215  //
216  Real taumod = sqrt(real(conj(tauhat)*tauhat));
217  if ( toBool(taumod < epsilon) ) convP = true;
218 
219  QDPIO::cout << "Iter " << iter << ": | tauhat |=" << taumod << std::endl;
220 
221  }
222  else {
223 
224  // Else clause of the if sigma_m == 0. If sigma < epsilon then converged
225  convP = true;
226  }
227 
228  // Count the number of iterations
229  n_count = iter;
230  }
231 
232  // And we are done, either converged or not...
233  if( n_count == MaxSUMR && ! convP ) {
234  QDPIO::cout << "Solver Nonconvergence Warning " << std::endl;
235  }
236 
237  END_CODE();
238 }
239 
240 template<>
242  const LatticeFermion& b,
243  LatticeFermion& x,
244  const Complex& zeta,
245  const Real& rho,
246  const Real& epsilon,
247  int MaxSUMR,
248  int& n_count)
249 {
250  InvSUMR_a(U, b, x, zeta, rho, epsilon, MaxSUMR, n_count);
251 }
252 
253 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
Conjugate-Gradient algorithm for a generic Linear Operator.
int x
Definition: meslate.cc:34
LatticeReal phi
Definition: mesq.cc:27
int epsilon(int i, int j, int k)
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & alpha(const int dim)
Definition: stag_phases_s.h:43
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
static multi1d< LatticeColorMatrix > u
int n_count
Definition: invbicg.cc:78
Double c
Definition: invbicg.cc:108
LinOpSysSolverMGProtoClover::T T
@ PLUS
Definition: chromabase.h:45
void InvSUMR_a(const LinearOperator< T > &U, const T &b, T &x, const Complex &zeta, const Real &rho, const Real &epsilon, int MaxSUMR, int &n_count)
Definition: invsumr.cc:29
LatticeFermion eta
Definition: mespbg5p_w.cc:37
void InvSUMR(const LinearOperator< LatticeFermion > &U, const LatticeFermion &b, LatticeFermion &x, const Complex &zeta, const Real &rho, const Real &epsilon, int MaxSUMR, int &n_count)
Definition: invsumr.cc:241
START_CODE()
Complex b
Definition: invbicg.cc:96
Double zero
Definition: invbicg.cc:106
multi1d< LatticeFermion > s(Ncb)
int kappa
Definition: pade_trln_w.cc:112
int r0
Definition: qtopcor.cc:41
multi1d< LatticeColorMatrix > U
Double ftmp2
Definition: topol.cc:30