CHROMA
minv_rel_sumr.cc
Go to the documentation of this file.
1 #include "chromabase.h"
3 
4 namespace Chroma {
5 
6 // Solve a shifted unitary system
7 //
8 // A x = b
9 //
10 // Where A is of the form: A = zeta I + rho U
11 //
12 // rho > 0 and zeta are complex, and U is unitary
13 //
14 // We solve with the method described in:
15 //
16 // "A Fast Minimal Residual Algorithm for Shifted Unitary Matrices"
17 // by Carl F. Jagels, and Lothar Reichel
18 // Numerical Linear Algebra with Applications, Vol 1(6), 555-570(1994)
19 //
20 // This paper is referenced by and applied to the Overlap Dirac Operator
21 // by G. Arnold, N. Cundy, J. van den Eshof, A Frommer, S. Krieg, T. Lippert,
22 // K. Schaefer "Numerical Methods for the QCD Overlap Operator: II.
23 // Optimal Krylov Subspace Methods" -- hep-lat/0311025
24 // which is where the name SUMR was coined.
25 //
26 
27 
28 template<typename T>
30  const T& b,
31  multi1d<T>& x,
32  const multi1d<Complex>& zeta,
33  const multi1d<Real>& rho,
34  const multi1d<Real>& epsilon,
35  int MaxSUMR,
36  int& n_count)
37 {
38  START_CODE();
39 
40  // Sanity check
41  int numroot = x.size();
42  if( zeta.size() != numroot ) {
43  QDPIO::cerr << "zeta size:"<< zeta.size()
44  <<" is different from x.size:"<<numroot << std::endl;
45 
46  QDP_abort(1);
47  }
48 
49  if( rho.size() != numroot) {
50  QDPIO::cerr << "rho size:"<< rho.size()
51  <<" is different from x.size:"<<numroot << std::endl;
52 
53  }
54 
55  if( epsilon.size() != numroot) {
56  QDPIO::cerr << "epsilon size:"<< epsilon.size()
57  <<" is different from x.size:"<<numroot << std::endl;
58 
59  }
60 
61  // Auxiliary condition for stopping is when sigma < epsilon
62  // To keep things consistent this should be the smallest epsilon passed in
63  Real epsilon_sigma = epsilon[0];
64  for(int shift = 0; shift < numroot; shift++) {
65  if ( toBool( epsilon[shift] < epsilon_sigma ) ) {
66  epsilon_sigma = epsilon[shift];
67  }
68  }
69 
70  /********************************************************************/
71  /* Initialisation */
72  /********************************************************************/
73 
74  // First things first. We need r_0 = b - A x
75  // for all systems.
76  // Restriction here, is that the initial std::vector x is zero
77  // and so r_0 = b
78  for(int shift = 0; shift < numroot; shift++) {
79  x[shift] = zero;
80  }
81 
82  LatticeFermion r;
83  r = b;
84 
85  // delta_0_shift = || r_shift ||
86  Real delta = sqrt(norm2(r));
87 
88  multi1d<Complex> phihat(numroot);
89  multi1d<Complex> tauhat(numroot);
90 
91  // This std::vector is part of the Arnoldi process and should
92  // be shift independent.
93  for(int shift = 0; shift < numroot; shift++) {
94  // phi_hat_1 = 1 / delta_0
95  phihat[shift] = Real(1)/delta;
96 
97  // tau_hat = delta_0 / rho
98  tauhat[shift] = delta/rho[shift];
99  }
100 
101  // v_0 = zero
102  LatticeFermion v_old = zero;
103 
104  // These are used in updating the individual solutions
105  // and so are multi massed
106  multi1d<LatticeFermion> w_old(numroot);
107  multi1d<LatticeFermion> p(numroot);
108 
109  // w_-1 = p_-1
110  //
111  // Iteration starts at m=1, so w_-1, p_-1 are in fact
112  // w_{m-2}, p_{m-2}, v_{m-1}
113  for(int shift = 0; shift < numroot; shift++) {
114  w_old[shift]=zero;
115  p[shift]=zero;
116  }
117 
118 
119  // These are part of the Arnoldi process and ar shift independent
120  // gamma_0:= 1, sigma := 1
121  Complex gamma = Real(1);
122  Real sigma = Real(1);
123 
124  // These are all shift dependent
125  multi1d<Complex> phi(numroot);
126  multi1d<Real> s(numroot);
127  multi1d<Complex> lambda(numroot);
128  multi1d<Complex> r0(numroot);
129  multi1d<Complex> r1_old(numroot);
130  multi1d<Complex> c(numroot);
131 
132  // phi_0 = s_0 = lambda_0 := r_{-1,0} = 0
133  // r_{0,0} := c_0 := 1;
134  for(int shift = 0; shift < numroot; shift++) {
135  phi[shift] = Real(0);
136  s[shift] = Real(0);
137  lambda[shift] = Real(0);
138  r0[shift] = Real(0);
139  r1_old[shift] = Real(1);
140  c[shift] = Real(1);
141  }
142 
143  // It is here, that we couple the individual systems.
144  // Because r_0 is the same for all the systems
145  // v_1 is the same for all the system
146 
147  // v_1 := \tilde{v}_1 := r_0 / delta_ 0
148  LatticeFermion v;
149  LatticeFermion vtilde;
150 
151  Real ftmp = Real(1)/delta;
152  v = ftmp*r;
153  vtilde = ftmp*r;
154 
155  /***********************************************************/
156  /* Start the iteration */
157  /***********************************************************/
158  multi1d<bool> convP(numroot);
159  convP = false;
160  bool allConvP = false;
161 
162 
163  for(int iter = 1; iter <= MaxSUMR && !allConvP; iter++) {
164  LatticeFermion u;
165 
166  // Updating u, gamma and sigma are common to all systems.
167 
168  // the inner solver precision is supposed to be
169  // epsilon / || r ||
170  //
171  // we get the bound for || r || from tauhat.
172  // and here we need tauhat from the system with the smallest
173  // unconverged shift
174  int unc=0;
175  while( convP[unc] == true && unc < numroot) { unc++; }
176 
177  // If there are other shifts find the smallest unconverged one.
178  if( unc < numroot ) {
179  Real mod_me = sqrt(real(conj(zeta[unc])*zeta[unc]));
180 
181  // look through the other shifts
182  for(int shift=unc+1; shift < numroot; shift++) {
183 
184  // Only compare if they are unconverged
185  if( convP[shift] == false ) {
186  Real mod_trial = sqrt(real(conj(zeta[shift])*zeta[shift]));
187 
188  if ( toBool(mod_trial < mod_me) ) {
189  // The trial shift has smaller norm than me and is unconverged
190  // so we take it as the smallest and save its mod
191  unc = shift;
192  mod_me = mod_trial;
193  }
194 
195  }
196  }
197  }
198  else {
199  QDPIO::cerr << "All systems appear to be converged. I shouldnt be here" << std::endl;
200  }
201 
202  // unc now contains the shift index of the system with the smallest
203  // shift that is still unconverged. Use the tauhat of this for the
204  // relaxation.
205 
206  Real inner_tol = epsilon_sigma / sqrt(real(conj(tauhat[unc])*tauhat[unc]));
207 
208 
209 
210  // u := U v_m
211  U(u, v, PLUS, inner_tol);
212 
213  // gamma_m = - < vtilde, u >
214  gamma = - innerProduct(vtilde, u);
215 
216  // sigma = ( ( 1 + | gamma | )(1 - | gamma | ) )^{1/2}
217  //
218  // NB I multiplied out the inner brackets
219  // sigma = ( 1 - | gamma |^2 )^{1/2}
220  sigma = sqrt( Real(1) - real( conj(gamma)*gamma ) );
221 
222  // multi1d<Complex> alpha(numroot);
223  Complex alpha;
224  alpha = -gamma*delta;
225 
226  multi1d<Complex> r1hat(numroot);
227 
228  // the abs_rhat_sq abs_sigma_sq and tmp_length are truly temporary
229  // further, sigma is shift independent so I compute it here.
230  Real abs_rhat_sq;
231  Real abs_sigma_sq = sigma*sigma;
232  Real tmp_length;
233  multi1d<Complex> r1(numroot);
234 
235  // Go through all the shifts and update w, p, and x
236  for(int shift = 0; shift < numroot; shift++) {
237 
238  // Only do unconverged systems
239  if( !convP[shift] ) {
240 
241  // r_{m-1, m}:= alpha_m * phi_{m-1} + s_{m-1}*zeta/rho;
242  r0[shift] = alpha * phi[shift] + s[shift]*zeta[shift]/rho[shift];
243 
244  // rhat_{m,m} := alpha_m phihat_m + conj(c_{m-1})*zeta/rho;
245  r1hat = alpha * phihat[shift]+conj(c[shift])*zeta[shift]/rho[shift];
246 
247  // conj(c_m) := rhat_{m,m}/( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
248  // s_m := -sigma_m / ( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
249  //
250  // Note I precompute the denominator into tmp_length.
251  // and I conjugate the expression for conj(c) to get c directly
252  // This trick is from the Wuppertal Group's MATLAB implementation
253  abs_rhat_sq = real(conj(r1hat[shift])*r1hat[shift]);
254  tmp_length = sqrt(abs_rhat_sq + abs_sigma_sq);
255 
256  c[shift] = conj(r1hat[shift])/tmp_length;
257  s[shift] = - sigma / tmp_length;
258 
259  // r_{m,m} := -c_m*rhat_{m,m} + s_m*sigma_m;
260  r1[shift] = -c[shift]*r1hat[shift] + cmplx(s[shift]*sigma,0);
261 
262  // tau_m := -c_m*tauhat_m
263  Complex tau = -c[shift]*tauhat[shift];
264 
265  // tauhat_{m+1} := s_m tauhat_m
266  tauhat[shift] = s[shift] * tauhat[shift];
267 
268  // eta_m := tau_m / r_{m,m}
269  Complex eta = tau / r1[shift];
270 
271  // kappa_{m-1} := r_{m-1,m}/r_{m-1, m-1}
272  Complex kappa = r0[shift]/r1_old[shift];
273 
274  // keep r1
275  r1_old[shift]=r1[shift];
276 
277  // w_{m-1} := alpha_m * p_{m-2} - ( w_{m-2} - v_{m-1} ) kappa_{m-1}
278  // p_{m-1} := p_{m-2} - ( w_{m-2} - v_{m-1} )*lambda_{m-1}
279  //
280  // NB: I precompute v_{m-2} - v_{m-1} into w_minus_v
281  LatticeFermion w_minus_v;
282  w_minus_v = w_old[shift] - v_old;
283 
284  LatticeFermion w = alpha*p[shift] - kappa*w_minus_v;
285  p[shift] = p[shift] - lambda[shift]*w_minus_v;
286 
287  // x_{m-1} := x_{m-1} - ( w_{m-1} - v_m ) eta
288  //
289  // NB I compute w_{m-1} - v_m into w_minus_v
290 
291  w_minus_v = w-v;
292  x[shift] = x[shift] - eta*w_minus_v;
293 
294  // I need to keep w as w_old for the next iteration
295  w_old[shift] = w;
296 
297  }
298  }
299 
300  // At this point the paper writes:
301  //
302  // if (sigma_m = 0) then we have converged
303  //
304  // In this case I regard all systems as converted
305  //
306  if( toBool( sqrt(real(conj(sigma)*sigma)) > epsilon_sigma ) ) {
307 
308  // Here we haven't converged
309 
310  // delta_m := delta_{m-1} sigma_m
311  delta = delta*sigma;
312 
313  // Update phi, lambda, phihat
314  for(int shift=0; shift < numroot; shift++) {
315 
316  // But only for the update systems
317  if( !convP[shift] ) {
318 
319  // phi_m := -c_m phihat_m + s_m conj(gamma_m)/delta_m
320  phi[shift] = -c[shift]*phihat[shift]
321  + s[shift]*conj(gamma)/delta;
322 
323  // lambda_m := phi_m / r_{m,m}
324  lambda[shift] = phi[shift]/r1[shift];
325 
326  // phihat_{m+1} := s_m phihat_{m} + conj(c_m)*conj(gamma_m)/delta_m
327  phihat[shift] = s[shift]* phihat[shift]
328  + conj(c[shift]) * conj(gamma)/delta;
329 
330  }
331  }
332 
333  // preserve v as v_old
334  v_old = v;
335 
336  // v_{m+1} := (1/sigma)( u + gamma_m vtilde_m );
337  v = u + gamma*vtilde;
338  Real ftmp2 = Real(1)/sigma;
339  v *= ftmp2;
340 
341  // vtilde_{m+1} := sigma_m vtilde_m + conj(gamma_m)*v_{m+1}
342  Complex gconj = conj(gamma);
343  Complex csigma=cmplx(sigma,0);
344  vtilde = csigma*vtilde + gconj*v;
345 
346  // Normalise vtilde -- I found this in the Wuppertal MATLAB code
347  // It is not prescribed by the Reichel/Jagels paper
348  Real ftmp3 = Real(1)/sqrt(norm2(vtilde));
349  vtilde *= ftmp3;
350 
351  // Check whether we have converged or not:
352  //
353  // converged if | tauhat | < epsilon
354  //
355  // Assume we have all converged, and we BOOLEAN AND
356  // with individual systems
357  allConvP = true;
358 
359  // Go through all the shifted systems
360  for(int shift=0; shift < numroot; shift++) {
361 
362  // Only check unconverged systems
363  if( ! convP[shift] ) {
364 
365  // Get | tauhat |
366  Real taumod = sqrt(real(conj(tauhat[shift])*tauhat[shift]));
367 
368  QDPIO::cout << "Iter " << iter << ": Shift: " << shift
369  <<" | tauhat |=" << taumod << std::endl;
370 
371  // Check convergence
372  if ( toBool(taumod < epsilon[shift] ) ) convP[shift] = true;
373 
374  // Boolean AND with Overall convergence flag
375  allConvP &= convP[shift];
376  }
377 
378  }
379  }
380  else {
381 
382  // Else clause of the if sigma_m == 0. If sigma < epsilon then converged
383  allConvP = true;
384 
385  }
386 
387  // Count the number of iterations
388  n_count = iter;
389  }
390 
391  // And we are done, either converged or not...
392  if( n_count == MaxSUMR && ! allConvP ) {
393  QDPIO::cout << "Solver Nonconvergence Warning " << std::endl;
394  QDP_abort(1);
395  }
396 
397  END_CODE();
398 }
399 
400 template<>
402  const LatticeFermion& b,
403  multi1d<LatticeFermion>& x,
404  const multi1d<Complex>& zeta,
405  const multi1d<Real>& rho,
406  const multi1d<Real>& epsilon,
407  int MaxSUMR,
408  int& n_count)
409 {
410  MInvRelSUMR_a(U, b, x, zeta, rho, epsilon, MaxSUMR, n_count);
411 }
412 
413 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
int x
Definition: meslate.cc:34
LatticeReal phi
Definition: mesq.cc:27
Conjugate-Gradient algorithm for a generic Linear Operator.
int epsilon(int i, int j, int k)
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & alpha(const int dim)
Definition: stag_phases_s.h:43
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
static multi1d< LatticeColorMatrix > u
int n_count
Definition: invbicg.cc:78
void MInvRelSUMR_a(const LinearOperator< T > &U, const T &b, multi1d< T > &x, const multi1d< Complex > &zeta, const multi1d< Real > &rho, const multi1d< Real > &epsilon, int MaxSUMR, int &n_count)
Double c
Definition: invbicg.cc:108
LinOpSysSolverMGProtoClover::T T
@ PLUS
Definition: chromabase.h:45
LatticeFermion eta
Definition: mespbg5p_w.cc:37
void MInvRelSUMR(const LinearOperator< LatticeFermion > &U, const LatticeFermion &b, multi1d< LatticeFermion > &x, const multi1d< Complex > &zeta, const multi1d< Real > &rho, const multi1d< Real > &epsilon, int MaxSUMR, int &n_count)
START_CODE()
Complex b
Definition: invbicg.cc:96
Double zero
Definition: invbicg.cc:106
multi1d< LatticeFermion > s(Ncb)
int kappa
Definition: pade_trln_w.cc:112
int r0
Definition: qtopcor.cc:41
multi1d< LatticeColorMatrix > U
Double ftmp2
Definition: topol.cc:30