CHROMA
inv_rel_sumr.cc
Go to the documentation of this file.
1 #include "chromabase.h"
3 
4 namespace Chroma {
5 
6 // Solve a shifted unitary system
7 //
8 // A x = b
9 //
10 // Where A is of the form: A = zeta I + rho U
11 //
12 // zeta is complex, rho > 0 is real, and U is unitary
13 //
14 // We solve with the method described in:
15 //
16 // "A Fast Minimal Residual Algorithm for Shifted Unitary Matrices"
17 // by Carl F. Jagels, and Lothar Reichel
18 // Numerical Linear Algebra with Applications, Vol 1(6), 555-570(1994)
19 //
20 // This paper is referenced by and applied to the Overlap Dirac Operator
21 // by G. Arnold, N. Cundy, J. van den Eshof, A Frommer, S. Krieg, T. Lippert,
22 // K. Schaefer "Numerical Methods for the QCD Overlap Operator: II.
23 // Optimal Krylov Subspace Methods" -- hep-lat/0311025
24 // which is where the name SUMR was coined.
25 //
26 
27 //
28 // This is now with Relaxation in the solver...
29 //
30 //
31 
32 template<typename T>
34  const T& b,
35  T& x,
36  const Complex& zeta,
37  const Real& rho,
38  const Real& epsilon,
39  int MaxSUMR,
40  int& n_count)
41 {
42 
43 
44  /********************************************************************/
45  /* Initialisation */
46  /********************************************************************/
47 
48  // First things first. We need r_0 = b - A x
49  //
50  // b - A x = b - (zeta I + rho U)x = b - zeta x - rho U x
51  //
52  LatticeFermion r = b - zeta*x;
53  LatticeFermion u;
54  U(u, x, PLUS, epsilon);
55  r -= rho*u;
56 
57  // delta_0 = || r ||
58  Real delta = sqrt(norm2(r));
59 
60  // phi_hat_1 = 1 / delta_0
61  Complex phihat = Real(1)/delta;
62 
63  // tau_hat = delta_0 / rho
64  Complex tauhat = delta/rho;
65 
66  LatticeFermion w_old, p, v_old;
67  // w_-1 = p_-1 = v_0 = 0
68  //
69  // Iteration starts at m=1, so w_-1, p_-1 are in fact
70  // w_{m-2}, p_{m-2}, v_{m-1}
71  w_old=zero;
72  p=w_old;
73  v_old = w_old;
74 
75  // phi_0 = s_0 = lambda_0 := r_{-1,0} = 0
76  Complex phi = Real(0);
77  Real s = Real(0);
78  Complex lambda = Real(0);
79  Complex r0 = Real(0);
80 
81  // r_{0,0} := gamma_0 := sigma_0 := c_0 := 1;
82  Complex r1_old = Real(1);
83  Complex gamma = Real(1);
84  Real sigma = Real(1);
85  Complex c = Real(1);
86 
87  // v_1 := \tilde{v}_1 := r_0 / delta_ 0
88  LatticeFermion v, vtilde;
89  Real ftmp = Real(1)/delta;
90  v = ftmp*r;
91  vtilde = ftmp*r;
92 
93  /***********************************************************/
94  /* Start the iteration */
95  /***********************************************************/
96  bool convP = false;
97  Real taumod = sqrt(real(conj(tauhat)*tauhat));
98  for(int iter = 1; iter <= MaxSUMR && !convP; iter++) {
99 
100  // taumod is a bound on || r ||
101  // inner solver criterion is epsilon/|| r ||
102 
103  Real inner_tol = epsilon/taumod;
104 
105  // u := U v_m
106  U(u, v, PLUS, inner_tol);
107 
108  // gamma_m = - < vtilde, u >
109  gamma = - innerProduct(vtilde, u);
110 
111  // sigma = ( ( 1 + | gamma | )(1 - | gamma | ) )^{1/2}
112  //
113  // NB I multiplied out the inner brackets
114  // sigma = ( 1 - | gamma |^2 )^{1/2}
115  sigma = sqrt( Real(1) - real( conj(gamma)*gamma ) );
116 
117  // alpha_m := - gamma_m delta_{m-1}
118  Complex alpha = - gamma * delta;
119 
120  // r_{m-1, m}:= alpha_m * phi_{m-1} + s_{m-1}*zeta/rho;
121  r0 = alpha * phi + s*zeta/rho;
122 
123  // rhat_{m,m} := alpha_m phihat_m + conj(c_{m-1})*zeta/rho;
124  Complex r1hat = alpha*phihat+conj(c)*zeta/rho;
125 
126 
127  // conj(c_m) := rhat_{m,m}/( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
128  // s_m := -sigma_m / ( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
129  //
130  // Note I precompute the denominator into tmp_length.
131  // and I conjugate the expression for conj(c) to get c directly
132  // This trick is from the Wuppertal Group's MATLAB implementation
133  Real abs_rhat_sq = real(conj(r1hat)*r1hat);
134  Real abs_sigma_sq = sigma*sigma;
135  Real tmp_length = sqrt(abs_rhat_sq + abs_sigma_sq);
136 
137  c = conj(r1hat)/tmp_length;
138  s = - sigma / tmp_length;
139 
140  // r_{m,m} := -c_m*rhat_{m,m} + s_m*sigma_m;
141  Complex r1 = -c*r1hat + cmplx(s*sigma,0);
142 
143  // tau_m := -c_m*tauhat_m
144  Complex tau = -c*tauhat;
145 
146  // tauhat_{m+1} := s_m tauhat_m
147  tauhat = s * tauhat;
148 
149  // eta_m := tau_m / r_{m,m}
150  Complex eta = tau / r1;
151 
152  // kappa_{m-1} := r_{m-1,m}/r_{m-1, m-1}
153  Complex kappa = r0/r1_old;
154 
155  // keep r1
156  r1_old=r1;
157 
158  // w_{m-1} := alpha_m * p_{m-2} - ( w_{m-2} - v_{m-1} ) kappa_{m-1}
159  // p_{m-1} := p_{m-2} - ( w_{m-2} - v_{m-1} )*lambda_{m-1}
160  //
161  // NB: I precompute v_{m-2} - v_{m-1} into w_minus_v
162  LatticeFermion w_minus_v;
163  w_minus_v = w_old - v_old;
164 
165  LatticeFermion w = alpha*p - kappa*w_minus_v;
166  p = p - lambda*w_minus_v;
167 
168  // x_{m-1} := x_{m-1} - ( w_{m-1} - v_m ) eta
169  //
170  // NB I compute w_{m-1} - v_m into w_minus_v
171 
172  w_minus_v = w-v;
173  x = x - eta*w_minus_v;
174 
175  // I need to keep w as w_old for the next iteration
176  w_old = w;
177 
178 
179  // At this point the paper writes:
180  //
181  // if (sigma_m = 0) then we have converged
182  // so here I just go on and set the convergence flag in the else clause
183  //
184  if( toBool( sqrt(real(conj(sigma)*sigma)) > epsilon) ) {
185 
186  // delta_m := delta_{m-1} sigma_m
187  delta = delta*sigma;
188 
189  // phi_m := -c_m phihat_m + s_m conj(gamma_m)/delta_m
190  phi = -c*phihat + s*conj(gamma)/delta;
191 
192  // lambda_m := phi_m / r_{m,m}
193  lambda = phi/r1;
194 
195  // phihat_{m+1} := s_m phihat_{m} + conj(c_m)*conj(gamma_m)/delta_m
196  phihat = s* phihat + conj(c) * conj(gamma)/delta;
197 
198  // preserve v as v_old
199  v_old = v;
200 
201  // v_{m+1} := (1/sigma)( u + gamma_m vtilde_m );
202  v = u + gamma*vtilde;
203  Real ftmp2 = Real(1)/sigma;
204  v *= ftmp2;
205 
206  // vtilde_{m+1} := sigma_m vtilde_m + conj(gamma_m)*v_{m+1}
207  Complex gconj = conj(gamma);
208  Complex csigma=cmplx(sigma,0);
209  vtilde = csigma*vtilde + gconj*v;
210 
211  // Normalise vtilde -- I found this in the Wuppertal MATLAB code
212  // It is not prescribed by the Reichel/Jagels paper
213  Real ftmp3 = Real(1)/sqrt(norm2(vtilde));
214  vtilde *= ftmp3;
215 
216  // Check whether we have converged or not:
217  //
218  // converged if | tauhat | < epsilon
219  //
220  taumod = sqrt(real(conj(tauhat)*tauhat));
221  if ( toBool(taumod < epsilon) ) convP = true;
222 
223  QDPIO::cout << "Iter " << iter << ": | tauhat |=" << taumod << std::endl;
224 
225  }
226  else {
227 
228  // Else clause of the if sigma_m == 0. If sigma < epsilon then converged
229  convP = true;
230  }
231 
232  // Count the number of iterations
233  n_count = iter;
234  }
235 
236  // And we are done, either converged or not...
237  if( n_count >= MaxSUMR && ! convP ) {
238  QDPIO::cout << "Solver Nonconvergence Warning " << std::endl;
239  }
240 
241 }
242 
243 template<>
245  const LatticeFermion& b,
246  LatticeFermion& x,
247  const Complex& zeta,
248  const Real& rho,
249  const Real& epsilon,
250  int MaxSUMR,
251  int& n_count)
252 {
253  InvRelSUMR_a(U, b, x, zeta, rho, epsilon, MaxSUMR, n_count);
254 }
255 
256 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
Conjugate-Gradient algorithm for a generic Linear Operator.
int x
Definition: meslate.cc:34
LatticeReal phi
Definition: mesq.cc:27
int epsilon(int i, int j, int k)
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & alpha(const int dim)
Definition: stag_phases_s.h:43
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
static multi1d< LatticeColorMatrix > u
int n_count
Definition: invbicg.cc:78
void InvRelSUMR(const LinearOperator< LatticeFermion > &U, const LatticeFermion &b, LatticeFermion &x, const Complex &zeta, const Real &rho, const Real &epsilon, int MaxSUMR, int &n_count)
Double c
Definition: invbicg.cc:108
LinOpSysSolverMGProtoClover::T T
void InvRelSUMR_a(const LinearOperator< T > &U, const T &b, T &x, const Complex &zeta, const Real &rho, const Real &epsilon, int MaxSUMR, int &n_count)
Definition: inv_rel_sumr.cc:33
@ PLUS
Definition: chromabase.h:45
LatticeFermion eta
Definition: mespbg5p_w.cc:37
Complex b
Definition: invbicg.cc:96
Double zero
Definition: invbicg.cc:106
multi1d< LatticeFermion > s(Ncb)
int kappa
Definition: pade_trln_w.cc:112
int r0
Definition: qtopcor.cc:41
multi1d< LatticeColorMatrix > U
Double ftmp2
Definition: topol.cc:30