CHROMA
minvsumr.cc
Go to the documentation of this file.
1 #include "chromabase.h"
3 
4 
5 namespace Chroma {
6 
7 // Solve a shifted unitary system
8 //
9 // A x = b
10 //
11 // Where X is of the form: A = zeta I + rho U
12 //
13 // rho > 0 and zeta are complex, and U is unitary
14 //
15 // We solve with the method described in:
16 //
17 // "A Fast Minimal Residual Algorithm for Shifted Unitary Matrices"
18 // by Carl F. Jagels, and Lothar Reichel
19 // Numerical Linear Algebra with Applications, Vol 1(6), 555-570(1994)
20 //
21 // This paper is referenced by and applied to the Overlap Dirac Operator
22 // by G. Arnold, N. Cundy, J. van den Eshof, A Frommer, S. Krieg, T. Lippert,
23 // K. Schaefer "Numerical Methods for the QCD Overlap Operator: II.
24 // Optimal Krylov Subspace Methods" -- hep-lat/0311025
25 // which is where the name SUMR was coined.
26 //
27 
28 
29 template<typename T>
31  const T& b,
32  multi1d<T>& x,
33  const multi1d<Complex>& zeta,
34  const multi1d<Real>& rho,
35  const multi1d<Real>& epsilon,
36  int MaxSUMR,
37  int& n_count)
38 {
39  START_CODE();
40 
41  // Sanity check
42  int numroot = x.size();
43  if( zeta.size() != numroot ) {
44  QDPIO::cerr << "zeta size:"<< zeta.size()
45  <<" is different from x.size:"<<numroot << std::endl;
46 
47  QDP_abort(1);
48  }
49 
50  if( rho.size() != numroot) {
51  QDPIO::cerr << "rho size:"<< rho.size()
52  <<" is different from x.size:"<<numroot << std::endl;
53 
54  }
55 
56  if( epsilon.size() != numroot) {
57  QDPIO::cerr << "epsilon size:"<< epsilon.size()
58  <<" is different from x.size:"<<numroot << std::endl;
59 
60  }
61 
62  // Auxiliary condition for stopping is when sigma < epsilon
63  // To keep things consistent this should be the smallest epsilon passed in
64  Real epsilon_sigma = epsilon[0];
65  for(int shift = 0; shift < numroot; shift++) {
66  if ( toBool( epsilon[shift] < epsilon_sigma ) ) {
67  epsilon_sigma = epsilon[shift];
68  }
69  }
70 
71  /********************************************************************/
72  /* Initialisation */
73  /********************************************************************/
74 
75  // First things first. We need r_0 = b - A x
76  // for all systems.
77  // Restriction here, is that the initial std::vector x is zero
78  // and so r_0 = b
79  for(int shift = 0; shift < numroot; shift++) {
80  x[shift] = zero;
81  }
82 
83  LatticeFermion r;
84  r = b;
85 
86  // delta_0_shift = || r_shift ||
87  Real delta = sqrt(norm2(r));
88 
89  multi1d<Complex> phihat(numroot);
90  multi1d<Complex> tauhat(numroot);
91 
92  // This std::vector is part of the Arnoldi process and should
93  // be shift independent.
94  for(int shift = 0; shift < numroot; shift++) {
95  // phi_hat_1 = 1 / delta_0
96  phihat[shift] = Real(1)/delta;
97 
98  // tau_hat = delta_0 / rho
99  tauhat[shift] = delta/rho[shift];
100  }
101 
102  // v_0 = zero
103  LatticeFermion v_old = zero;
104 
105  // These are used in updating the individual solutions
106  // and so are multi massed
107  multi1d<LatticeFermion> w_old(numroot);
108  multi1d<LatticeFermion> p(numroot);
109 
110  // w_-1 = p_-1
111  //
112  // Iteration starts at m=1, so w_-1, p_-1 are in fact
113  // w_{m-2}, p_{m-2}, v_{m-1}
114  for(int shift = 0; shift < numroot; shift++) {
115  w_old[shift]=zero;
116  p[shift]=zero;
117  }
118 
119 
120  // These are part of the Arnoldi process and ar shift independent
121  // gamma_0:= 1, sigma := 1
122  Complex gamma = Real(1);
123  Real sigma = Real(1);
124 
125  // These are all shift dependent
126  multi1d<Complex> phi(numroot);
127  multi1d<Real> s(numroot);
128  multi1d<Complex> lambda(numroot);
129  multi1d<Complex> r0(numroot);
130  multi1d<Complex> r1_old(numroot);
131  multi1d<Complex> c(numroot);
132 
133  // phi_0 = s_0 = lambda_0 := r_{-1,0} = 0
134  // r_{0,0} := c_0 := 1;
135  for(int shift = 0; shift < numroot; shift++) {
136  phi[shift] = Real(0);
137  s[shift] = Real(0);
138  lambda[shift] = Real(0);
139  r0[shift] = Real(0);
140  r1_old[shift] = Real(1);
141  c[shift] = Real(1);
142  }
143 
144  // It is here, that we couple the individual systems.
145  // Because r_0 is the same for all the systems
146  // v_1 is the same for all the system
147 
148  // v_1 := \tilde{v}_1 := r_0 / delta_ 0
149  LatticeFermion v;
150  LatticeFermion vtilde;
151 
152  Real ftmp = Real(1)/delta;
153  v = ftmp*r;
154  vtilde = ftmp*r;
155 
156  /***********************************************************/
157  /* Start the iteration */
158  /***********************************************************/
159  multi1d<bool> convP(numroot);
160  convP = false;
161  bool allConvP = false;
162 
163 
164  for(int iter = 1; iter <= MaxSUMR && !allConvP; iter++) {
165  LatticeFermion u;
166 
167  // Updating u, gamma and sigma are common to all systems.
168 
169  // u := U v_m
170  U(u, v, PLUS);
171 
172  // gamma_m = - < vtilde, u >
173  gamma = - innerProduct(vtilde, u);
174 
175  // sigma = ( ( 1 + | gamma | )(1 - | gamma | ) )^{1/2}
176  //
177  // NB I multiplied out the inner brackets
178  // sigma = ( 1 - | gamma |^2 )^{1/2}
179  sigma = sqrt( Real(1) - real( conj(gamma)*gamma ) );
180 
181  // multi1d<Complex> alpha(numroot);
182  Complex alpha;
183  alpha = -gamma*delta;
184 
185  multi1d<Complex> r1hat(numroot);
186 
187  // the abs_rhat_sq abs_sigma_sq and tmp_length are truly temporary
188  // further, sigma is shift independent so I compute it here.
189  Real abs_rhat_sq;
190  Real abs_sigma_sq = sigma*sigma;
191  Real tmp_length;
192  multi1d<Complex> r1(numroot);
193 
194  // Go through all the shifts and update w, p, and x
195  for(int shift = 0; shift < numroot; shift++) {
196 
197  // Only do unconverged systems
198  if( !convP[shift] ) {
199 
200  // r_{m-1, m}:= alpha_m * phi_{m-1} + s_{m-1}*zeta/rho;
201  r0[shift] = alpha * phi[shift] + s[shift]*zeta[shift]/rho[shift];
202 
203  // rhat_{m,m} := alpha_m phihat_m + conj(c_{m-1})*zeta/rho;
204  r1hat = alpha * phihat[shift]+conj(c[shift])*zeta[shift]/rho[shift];
205 
206  // conj(c_m) := rhat_{m,m}/( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
207  // s_m := -sigma_m / ( | rhat_{m,m} |^2 + | sigma_m |^2 )^{1/2}
208  //
209  // Note I precompute the denominator into tmp_length.
210  // and I conjugate the expression for conj(c) to get c directly
211  // This trick is from the Wuppertal Group's MATLAB implementation
212  abs_rhat_sq = real(conj(r1hat[shift])*r1hat[shift]);
213  tmp_length = sqrt(abs_rhat_sq + abs_sigma_sq);
214 
215  c[shift] = conj(r1hat[shift])/tmp_length;
216  s[shift] = - sigma / tmp_length;
217 
218  // r_{m,m} := -c_m*rhat_{m,m} + s_m*sigma_m;
219  r1[shift] = -c[shift]*r1hat[shift] + cmplx(s[shift]*sigma,0);
220 
221  // tau_m := -c_m*tauhat_m
222  Complex tau = -c[shift]*tauhat[shift];
223 
224  // tauhat_{m+1} := s_m tauhat_m
225  tauhat[shift] = s[shift] * tauhat[shift];
226 
227  // eta_m := tau_m / r_{m,m}
228  Complex eta = tau / r1[shift];
229 
230  // kappa_{m-1} := r_{m-1,m}/r_{m-1, m-1}
231  Complex kappa = r0[shift]/r1_old[shift];
232 
233  // keep r1
234  r1_old[shift]=r1[shift];
235 
236  // w_{m-1} := alpha_m * p_{m-2} - ( w_{m-2} - v_{m-1} ) kappa_{m-1}
237  // p_{m-1} := p_{m-2} - ( w_{m-2} - v_{m-1} )*lambda_{m-1}
238  //
239  // NB: I precompute v_{m-2} - v_{m-1} into w_minus_v
240  LatticeFermion w_minus_v;
241  w_minus_v = w_old[shift] - v_old;
242 
243  LatticeFermion w = alpha*p[shift] - kappa*w_minus_v;
244  p[shift] = p[shift] - lambda[shift]*w_minus_v;
245 
246  // x_{m-1} := x_{m-1} - ( w_{m-1} - v_m ) eta
247  //
248  // NB I compute w_{m-1} - v_m into w_minus_v
249 
250  w_minus_v = w-v;
251  x[shift] = x[shift] - eta*w_minus_v;
252 
253  // I need to keep w as w_old for the next iteration
254  w_old[shift] = w;
255 
256  }
257  }
258 
259  // At this point the paper writes:
260  //
261  // if (sigma_m = 0) then we have converged
262  //
263  // In this case I regard all systems as converted
264  //
265  if( toBool( sqrt(real(conj(sigma)*sigma)) > epsilon_sigma ) ) {
266 
267  // Here we haven't converged
268 
269  // delta_m := delta_{m-1} sigma_m
270  delta = delta*sigma;
271 
272  // Update phi, lambda, phihat
273  for(int shift=0; shift < numroot; shift++) {
274 
275  // But only for the update systems
276  if( !convP[shift] ) {
277 
278  // phi_m := -c_m phihat_m + s_m conj(gamma_m)/delta_m
279  phi[shift] = -c[shift]*phihat[shift]
280  + s[shift]*conj(gamma)/delta;
281 
282  // lambda_m := phi_m / r_{m,m}
283  lambda[shift] = phi[shift]/r1[shift];
284 
285  // phihat_{m+1} := s_m phihat_{m} + conj(c_m)*conj(gamma_m)/delta_m
286  phihat[shift] = s[shift]* phihat[shift]
287  + conj(c[shift]) * conj(gamma)/delta;
288 
289  }
290  }
291 
292  // preserve v as v_old
293  v_old = v;
294 
295  // v_{m+1} := (1/sigma)( u + gamma_m vtilde_m );
296  v = u + gamma*vtilde;
297  Real ftmp2 = Real(1)/sigma;
298  v *= ftmp2;
299 
300  // vtilde_{m+1} := sigma_m vtilde_m + conj(gamma_m)*v_{m+1}
301  Complex gconj = conj(gamma);
302  Complex csigma=cmplx(sigma,0);
303  vtilde = csigma*vtilde + gconj*v;
304 
305  // Normalise vtilde -- I found this in the Wuppertal MATLAB code
306  // It is not prescribed by the Reichel/Jagels paper
307  Real ftmp3 = Real(1)/sqrt(norm2(vtilde));
308  vtilde *= ftmp3;
309 
310  // Check whether we have converged or not:
311  //
312  // converged if | tauhat | < epsilon
313  //
314  // Assume we have all converged, and we BOOLEAN AND
315  // with individual systems
316  allConvP = true;
317 
318  // Go through all the shifted systems
319  for(int shift=0; shift < numroot; shift++) {
320 
321  // Only check unconverged systems
322  if( ! convP[shift] ) {
323 
324  // Get | tauhat |
325  Real taumod = sqrt(real(conj(tauhat[shift])*tauhat[shift]));
326 
327  QDPIO::cout << "Iter " << iter << ": Shift: " << shift
328  <<" | tauhat |=" << taumod << std::endl;
329 
330  // Check convergence
331  if ( toBool(taumod < epsilon[shift] ) ) convP[shift] = true;
332 
333  // Boolean AND with Overall convergence flag
334  allConvP &= convP[shift];
335  }
336 
337  }
338  }
339  else {
340 
341  // Else clause of the if sigma_m == 0. If sigma < epsilon then converged
342  allConvP = true;
343 
344  }
345 
346  // Count the number of iterations
347  n_count = iter;
348  }
349 
350  // And we are done, either converged or not...
351  if( n_count == MaxSUMR && ! allConvP ) {
352  QDPIO::cout << "Solver Nonconvergence Warning " << std::endl;
353  }
354 
355  END_CODE();
356 }
357 
358 template<>
360  const LatticeFermion& b,
361  multi1d<LatticeFermion>& x,
362  const multi1d<Complex>& zeta,
363  const multi1d<Real>& rho,
364  const multi1d<Real>& epsilon,
365  int MaxSUMR,
366  int& n_count)
367 {
368  MInvSUMR_a(U, b, x, zeta, rho, epsilon, MaxSUMR, n_count);
369 }
370 
371 } // end namespace Chroma
Primary include file for CHROMA library code.
Linear Operator.
Definition: linearop.h:27
int x
Definition: meslate.cc:34
LatticeReal phi
Definition: mesq.cc:27
Conjugate-Gradient algorithm for a generic Linear Operator.
int epsilon(int i, int j, int k)
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
static const LatticeInteger & alpha(const int dim)
Definition: stag_phases_s.h:43
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
static multi1d< LatticeColorMatrix > u
void MInvSUMR(const LinearOperator< LatticeFermion > &U, const LatticeFermion &b, multi1d< LatticeFermion > &x, const multi1d< Complex > &zeta, const multi1d< Real > &rho, const multi1d< Real > &epsilon, int MaxSUMR, int &n_count)
Definition: minvsumr.cc:359
int n_count
Definition: invbicg.cc:78
Double c
Definition: invbicg.cc:108
LinOpSysSolverMGProtoClover::T T
@ PLUS
Definition: chromabase.h:45
LatticeFermion eta
Definition: mespbg5p_w.cc:37
START_CODE()
Complex b
Definition: invbicg.cc:96
Double zero
Definition: invbicg.cc:106
multi1d< LatticeFermion > s(Ncb)
void MInvSUMR_a(const LinearOperator< T > &U, const T &b, multi1d< T > &x, const multi1d< Complex > &zeta, const multi1d< Real > &rho, const multi1d< Real > &epsilon, int MaxSUMR, int &n_count)
Definition: minvsumr.cc:30
int kappa
Definition: pade_trln_w.cc:112
int r0
Definition: qtopcor.cc:41
multi1d< LatticeColorMatrix > U
Double ftmp2
Definition: topol.cc:30