CHROMA
syssolver_linop_wilson_quda_w.h
Go to the documentation of this file.
1 // -*- C++ -*-
2 /*! \file
3  * \brief Solve a M*psi=chi linear system for Wilson fermions using the QUDA inverter (CG, BiCGStab or GCR)
4  */
5 
6 #ifndef __syssolver_linop_quda_wilson_h__
7 #define __syssolver_linop_quda_wilson_h__
8 
9 #include "chroma_config.h"
10 
11 #ifdef BUILD_QUDA
12 #include <quda.h>
13 
14 #include "handle.h"
15 #include "state.h"
16 #include "syssolver.h"
17 #include "linearop.h"
23 #include "io/aniso_io.h"
24 #include <string>
25 
26 #include "util/gauge/reunit.h"
27 
28 //#include <util_quda.h>
29 
30 namespace Chroma
31 {
32 
33 //! QUDA Wilson linear-operator system solver namespace
34 namespace LinOpSysSolverQUDAWilsonEnv
35 {
36 //! Register the syssolver (declaration only; the definition is not in this header)
37 bool registerAll();
38 }
39 
40 
41 
42 //! Solve a Wilson Fermion System using the QUDA inverter
43 /*! \ingroup invert
44  *** WARNING THIS SOLVER WORKS FOR Wilson FERMIONS ONLY ***
45  */
46 
47 class LinOpSysSolverQUDAWilson : public LinOpSystemSolver<LatticeFermion>
48 {
49 public:
50  typedef LatticeFermion T;
51  typedef LatticeColorMatrix U;
52  typedef multi1d<LatticeColorMatrix> Q;
53 
54  typedef LatticeFermionF TF;
55  typedef LatticeColorMatrixF UF;
56  typedef multi1d<LatticeColorMatrixF> QF;
57 
58  typedef LatticeFermionF TD;
59  typedef LatticeColorMatrixF UD;
60  typedef multi1d<LatticeColorMatrixF> QD;
61 
62  typedef WordType<T>::Type_t REALT;
63  //! Constructor
64  /*!
65  * \param M_ Linear operator ( Read )
66  * \param invParam inverter parameters ( Read )
67  */
68  LinOpSysSolverQUDAWilson(Handle< LinearOperator<T> > A_,
69  Handle< FermState<T,Q,Q> > state_,
70  const SysSolverQUDAWilsonParams& invParam_) :
71  A(A_), invParam(invParam_)
72  {
73  QDPIO::cout << "LinOpSysSolverQUDAWilson:" << std::endl;
74 
75  // FOLLOWING INITIALIZATION in test QUDA program
76 
77  // 1) work out cpu_prec, cuda_prec, cuda_prec_sloppy
78  int s = sizeof( WordType<T>::Type_t );
79  if (s == 4) {
80  cpu_prec = QUDA_SINGLE_PRECISION;
81  }
82  else {
83  cpu_prec = QUDA_DOUBLE_PRECISION;
84  }
85 
86 
87  // Work out GPU precision
88  switch( invParam.cudaPrecision ) {
89  case HALF:
90  gpu_prec = QUDA_HALF_PRECISION;
91  break;
92  case SINGLE:
93  gpu_prec = QUDA_SINGLE_PRECISION;
94  break;
95  case DOUBLE:
96  gpu_prec = QUDA_DOUBLE_PRECISION;
97  break;
98  default:
99  gpu_prec = cpu_prec;
100  break;
101  }
102 
103  // Work out GPU Sloppy precision
104  // Default: No Sloppy
105  switch( invParam.cudaSloppyPrecision ) {
106  case HALF:
107  gpu_half_prec = QUDA_HALF_PRECISION;
108  break;
109  case SINGLE:
110  gpu_half_prec = QUDA_SINGLE_PRECISION;
111  break;
112  case DOUBLE:
113  gpu_half_prec = QUDA_DOUBLE_PRECISION;
114  break;
115  default:
116  gpu_half_prec = gpu_prec;
117  break;
118  }
119 
120  // 2) pull 'new; GAUGE and Invert params
121  q_gauge_param = newQudaGaugeParam();
122  quda_inv_param = newQudaInvertParam();
123 
124  // 3) set lattice size
125  const multi1d<int>& latdims = Layout::subgridLattSize();
126 
127  q_gauge_param.X[0] = latdims[0];
128  q_gauge_param.X[1] = latdims[1];
129  q_gauge_param.X[2] = latdims[2];
130  q_gauge_param.X[3] = latdims[3];
131 
132  // 4) - deferred (anisotropy)
133 
134  // 5) - set QUDA_WILSON_LINKS, QUDA_GAUGE_ORDER
135  q_gauge_param.type = QUDA_WILSON_LINKS;
136  q_gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER; // gauge[mu], p
137 
138  // 6) - set t_boundary
139  // Convention: BC has to be applied already
140  // This flag just tells QUDA that this is so,
141  // so that QUDA can take care in the reconstruct
142  if( invParam.AntiPeriodicT ) {
143  q_gauge_param.t_boundary = QUDA_ANTI_PERIODIC_T;
144  }
145  else {
146  q_gauge_param.t_boundary = QUDA_PERIODIC_T;
147  }
148 
149  // Set cpu_prec, cuda_prec, reconstruct and sloppy versions
150  q_gauge_param.cpu_prec = cpu_prec;
151  q_gauge_param.cuda_prec = gpu_prec;
152 
153 
154  switch( invParam.cudaReconstruct ) {
155  case RECONS_NONE:
156  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_NO;
157  break;
158  case RECONS_8:
159  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_8;
160  break;
161  case RECONS_12:
162  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
163  break;
164  default:
165  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
166  break;
167  };
168 
169  q_gauge_param.cuda_prec_sloppy = gpu_half_prec;
170 
171  switch( invParam.cudaSloppyReconstruct ) {
172  case RECONS_NONE:
173  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO;
174  break;
175  case RECONS_8:
176  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_8;
177  break;
178  case RECONS_12:
179  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
180  break;
181  default:
182  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
183  break;
184  };
185 
186  // Gauge fixing:
187 
188  // These are the links
189  // They may be smeared and the BC's may be applied
190  Q links_single(Nd);
191 
192  // Now downcast to single prec fields.
193  for(int mu=0; mu < Nd; mu++) {
194  links_single[mu] = (state_->getLinks())[mu];
195  }
196 
197  // GaugeFix
198  if( invParam.axialGaugeP ) {
199  QDPIO::cout << "Fixing Temporal Gauge" << std::endl;
200  temporalGauge(links_single, GFixMat, Nd-1);
201  for(int mu=0; mu < Nd; mu++){
202  links_single[mu] = GFixMat*(state_->getLinks())[mu]*adj(shift(GFixMat, FORWARD, mu));
203  }
204  q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_YES;
205  }
206  else {
207  // No GaugeFix
208  q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO; // No Gfix yet
209  }
210 
211  // deferred 4) Gauge Anisotropy
212  const AnisoParam_t& aniso = invParam.WilsonParams.anisoParam;
213  if( aniso.anisoP ) { // Anisotropic case
214  Real gamma_f = aniso.xi_0 / aniso.nu;
215  q_gauge_param.anisotropy = toDouble(gamma_f);
216  }
217  else {
218  q_gauge_param.anisotropy = 1.0;
219  }
220 
221  // MAKE FSTATE BEFORE RESCALING links_single
222  // Because the clover term expects the unrescaled links...
223  Handle<FermState<T,Q,Q> > fstate( new PeriodicFermState<T,Q,Q>(links_single));
224 
225  if( aniso.anisoP ) { // Anisotropic case
226  multi1d<Real> cf=makeFermCoeffs(aniso);
227  for(int mu=0; mu < Nd; mu++) {
228  links_single[mu] *= cf[mu];
229  }
230  }
231 
232  // Now onto the inv param:
233  // Dslash type
234  quda_inv_param.dslash_type = QUDA_WILSON_DSLASH;
235 
236  // Invert type:
237  switch( invParam.solverType ) {
238  case CG:
239  quda_inv_param.inv_type = QUDA_CG_INVERTER;
240  solver_string = "CG";
241  break;
242  case BICGSTAB:
243  quda_inv_param.inv_type = QUDA_BICGSTAB_INVERTER;
244  solver_string = "BICGSTAB";
245  break;
246  case GCR:
247  quda_inv_param.inv_type = QUDA_GCR_INVERTER;
248  solver_string = "GCR";
249  break;
250  default:
251  QDPIO::cerr << "Unknown SOlver type" << std::endl;
252  QDP_abort(1);
253  break;
254  }
255 
256  // Mass
257 
258  Real massParam = Real(1) + Real(3)/Real(q_gauge_param.anisotropy) + invParam.WilsonParams.Mass;
259 
260  //invMassParam = 1.0/massParam;
261  invMassParam = 1.0;
262 
263  quda_inv_param.kappa = 1.0/(2*toDouble(massParam));
264  quda_inv_param.clover_coeff = 0.0; // Always true for Wilson
265 
266  quda_inv_param.tol = toDouble(invParam.RsdTarget);
267  quda_inv_param.maxiter = invParam.MaxIter;
268  quda_inv_param.reliable_delta = toDouble(invParam.Delta);
269  quda_inv_param.pipeline = invParam.Pipeline;
270 
271  // Solution type
272  quda_inv_param.solution_type = QUDA_MATPC_SOLUTION;
273 
274  // Solve type
275  switch( invParam.solverType ) {
276  case CG:
277  quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
278  break;
279  case BICGSTAB:
280  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
281  break;
282 
283  case GCR:
284  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
285  break;
286  case CA_GCR:
287  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
288  break;
289  case MR:
290  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
291  break;
292  default:
293  quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
294  break;
295  }
296 
297 
298  if( invParam.asymmetricP ) {
299  QDPIO::cout << "Using Asymmetric Linop: A_oo - D A^{-1}_ee D" << std::endl;
300  quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
301 
302  }
303  else {
304  QDPIO::cout << "Using Symmetric Linop: 1 - A^{-1}_oo D A^{-1}_ee D" << std::endl;
305  quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
306  }
307 
308  quda_inv_param.dagger = QUDA_DAG_NO;
309  quda_inv_param.mass_normalization = QUDA_ASYMMETRIC_MASS_NORMALIZATION;
310  // quda_inv_param.mass_normalization = QUDA_KAPPA_NORMALIZATION;
311 
312  quda_inv_param.cpu_prec = cpu_prec;
313  quda_inv_param.cuda_prec = gpu_prec;
314  quda_inv_param.cuda_prec_sloppy = gpu_half_prec;
315  quda_inv_param.preserve_source = QUDA_PRESERVE_SOURCE_YES;
316  quda_inv_param.use_init_guess = QUDA_USE_INIT_GUESS_NO;
317  quda_inv_param.dirac_order = QUDA_DIRAC_ORDER;
318  quda_inv_param.gamma_basis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
319  // Autotuning
320  if( invParam.tuneDslashP ) {
321  QDPIO::cout << "Enabling Dslash Autotuning" << std::endl;
322 
323  quda_inv_param.tune = QUDA_TUNE_YES;
324  }
325  else {
326  QDPIO::cout << "Disabling Dslash Autotuning" << std::endl;
327 
328  quda_inv_param.tune = QUDA_TUNE_NO;
329  }
330 
331 
332  // Setup padding
333  multi1d<int> face_size(4);
334  face_size[0] = latdims[1]*latdims[2]*latdims[3]/2;
335  face_size[1] = latdims[0]*latdims[2]*latdims[3]/2;
336  face_size[2] = latdims[0]*latdims[1]*latdims[3]/2;
337  face_size[3] = latdims[0]*latdims[1]*latdims[2]/2;
338 
339  int max_face = face_size[0];
340  for(int i=1; i <=3; i++) {
341  if ( face_size[i] > max_face ) {
342  max_face = face_size[i];
343  }
344  }
345 
346 
347  q_gauge_param.ga_pad = max_face;
348  // PADDING
349  quda_inv_param.sp_pad = 0;
350  quda_inv_param.cl_pad = 0;
351 
352  if( invParam.innerParamsP ) {
353  QDPIO::cout << "Setting inner solver params" << std::endl;
354  // Dereference handle
355  const GCRInnerSolverParams& ip = *(invParam.innerParams);
356 
357  // Set preconditioner precision
358  switch( ip.precPrecondition ) {
359  case HALF:
360  quda_inv_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
361  q_gauge_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
362  break;
363 
364  case SINGLE:
365  quda_inv_param.cuda_prec_precondition = QUDA_SINGLE_PRECISION;
366  q_gauge_param.cuda_prec_precondition = QUDA_SINGLE_PRECISION;
367  break;
368 
369  case DOUBLE:
370  quda_inv_param.cuda_prec_precondition = QUDA_DOUBLE_PRECISION;
371  q_gauge_param.cuda_prec_precondition = QUDA_DOUBLE_PRECISION;
372  break;
373  default:
374  quda_inv_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
375  q_gauge_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
376  break;
377  }
378 
379  switch( ip.reconstructPrecondition ) {
380  case RECONS_NONE:
381  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_NO;
382  break;
383  case RECONS_8:
384  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_8;
385  break;
386  case RECONS_12:
387  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_12;
388  break;
389  default:
390  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_NO;
391  break;
392  };
393 
394  quda_inv_param.tol_precondition = toDouble(ip.tolPrecondition);
395  quda_inv_param.maxiter_precondition = ip.maxIterPrecondition;
396  quda_inv_param.gcrNkrylov = ip.gcrNkrylov;
397  switch( ip.schwarzType ) {
398  case ADDITIVE_SCHWARZ :
399  quda_inv_param.schwarz_type = QUDA_ADDITIVE_SCHWARZ;
400  break;
402  quda_inv_param.schwarz_type = QUDA_MULTIPLICATIVE_SCHWARZ;
403  break;
404  default:
405  quda_inv_param.schwarz_type = QUDA_ADDITIVE_SCHWARZ;
406  break;
407  }
408  quda_inv_param.precondition_cycle = ip.preconditionCycle;
409 
410  if( ip.verboseInner ) {
411  quda_inv_param.verbosity_precondition = QUDA_VERBOSE;
412  }
413  else {
414  quda_inv_param.verbosity_precondition = QUDA_SILENT;
415  }
416 
417  switch( ip.invTypePrecondition ) {
418  case CG:
419  quda_inv_param.inv_type_precondition = QUDA_CG_INVERTER;
420  break;
421  case BICGSTAB:
422  quda_inv_param.inv_type_precondition = QUDA_BICGSTAB_INVERTER;
423  break;
424  case CA_GCR:
425  quda_inv_param.inv_type_precondition= QUDA_CA_GCR_INVERTER;
426  break;
427  case MR:
428  quda_inv_param.inv_type_precondition= QUDA_MR_INVERTER;
429  break;
430 
431  default:
432  quda_inv_param.inv_type_precondition = QUDA_MR_INVERTER;
433  break;
434  }
435  }
436  else {
437  QDPIO::cout << "Setting Precondition stuff to defaults for not using" << std::endl;
438  quda_inv_param.inv_type_precondition= QUDA_INVALID_INVERTER;
439  quda_inv_param.tol_precondition = 1.0e-1;
440  quda_inv_param.maxiter_precondition = 1000;
441  quda_inv_param.verbosity_precondition = QUDA_SILENT;
442  quda_inv_param.gcrNkrylov = 1;
443  }
444 
445 
446 
447  if( invParam.verboseP ) {
448  quda_inv_param.verbosity = QUDA_VERBOSE;
449  }
450  else {
451  quda_inv_param.verbosity = QUDA_SUMMARIZE;
452  }
453 
454  // Set up the links
455  void* gauge[4];
456 
457  for(int mu=0; mu < Nd; mu++) {
458  gauge[mu] = (void *)&(links_single[mu].elem(all.start()).elem().elem(0,0).real());
459 
460  }
461 
462  loadGaugeQuda((void *)gauge, &q_gauge_param);
463 
464 
465 
466  }
467 
468 
469  //! Destructor is automatic
470  ~LinOpSysSolverQUDAWilson()
471  {
472  QDPIO::cout << "Destructing" << std::endl;
473  freeGaugeQuda();
474  }
475 
476  //! Return the subset on which the operator acts
477  const Subset& subset() const {return A->subset();}
478 
479  //! Solver the linear system
480  /*!
481  * \param psi solution ( Modify )
482  * \param chi source ( Read )
483  * \return syssolver results
484  */
485  SystemSolverResults_t operator() (T& psi, const T& chi) const
486  {
487  SystemSolverResults_t res;
488 
489  START_CODE();
490  StopWatch swatch;
491  swatch.start();
492 
493  T chiResc = zero;
494  chiResc[A->subset()] = invMassParam * chi;
495  // T MdagChi;
496 
497  // This is a CGNE. So create new RHS
498  // (*A)(MdagChi, chi, MINUS);
499  // Handle< LinearOperator<T> > MM(new MdagMLinOp<T>(A));
500  if ( invParam.axialGaugeP ) {
501  T g_chi,g_psi;
502 
503  // Gauge Fix source and initial guess
504  QDPIO::cout << "Gauge Fixing source and initial guess" << std::endl;
505  g_chi[ rb[1] ] = GFixMat * chiResc;
506  g_psi[ rb[1] ] = GFixMat * psi;
507  QDPIO::cout << "Solving" << std::endl;
508  res = qudaInvert(g_chi,
509  g_psi);
510  QDPIO::cout << "Untransforming solution." << std::endl;
511  psi[ rb[1]] = adj(GFixMat)*g_psi;
512 
513  }
514  else {
515  res = qudaInvert(chiResc,
516  psi);
517  }
518 
519  swatch.stop();
520 
521 
522  {
523  T r;
524  r[A->subset()]=chi;
525  T tmp;
526  (*A)(tmp, psi, PLUS);
527  r[A->subset()] -= tmp;
528  res.resid = sqrt(norm2(r, A->subset()));
529  }
530 
531  Double rel_resid = res.resid/sqrt(norm2(chi,A->subset()));
532 
533  QDPIO::cout << "QUDA_"<< solver_string <<"_WILSON_SOLVER: " << res.n_count << " iterations. Rsd = " << res.resid << " Relative Rsd = " << rel_resid << std::endl;
534 
535  // Convergence Check/Blow Up
536  if ( ! invParam.SilentFailP ) {
537  if ( toBool( rel_resid > invParam.RsdToleranceFactor*invParam.RsdTarget) ) {
538  QDPIO::cerr << "ERROR: QUDA Solver residuum is outside tolerance: QUDA resid="<< rel_resid << " Desired =" << invParam.RsdTarget << " Max Tolerated = " << invParam.RsdToleranceFactor*invParam.RsdTarget << std::endl;
539  QDP_abort(1);
540  }
541  }
542 
543  END_CODE();
544  return res;
545  }
546 
547 
548 private:
549  // Hide default constructor
550  LinOpSysSolverQUDAWilson() {}
551 
552 #if 1
553  Q links_orig;
554 #endif
555 
556  U GFixMat;
557  QudaPrecision_s cpu_prec;
558  QudaPrecision_s gpu_prec;
559  QudaPrecision_s gpu_half_prec;
560 
561  Handle< LinearOperator<T> > A;
562  const SysSolverQUDAWilsonParams invParam;
563  Real invMassParam;
564  QudaGaugeParam q_gauge_param;
565  QudaInvertParam quda_inv_param;
566 
567  SystemSolverResults_t qudaInvert(const T& chi_s,
568  T& psi_s
569  )const ;
570 
571  std::string solver_string;
572 };
573 
574 
575 } // End namespace
576 
577 
578 #endif
579 #endif
580 
Anisotropy parameters.
int mu
Definition: cool.cc:24
void temporalGauge(multi1d< LatticeColorMatrix > &ug, LatticeColorMatrix &g, int decay_dir)
Temporal gauge fixing.
Class for counted reference semantics.
Linear Operators.
Nd
Definition: meslate.cc:74
bool registerAll()
Register all the factories.
multi1d< Hadron2PtContraction_t > operator()(const multi1d< LatticeColorMatrix > &u)
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
LinOpSysSolverMGProtoClover::Q Q
LatticeFermion tmp
Definition: mespbg5p_w.cc:36
LinOpSysSolverMGProtoClover::T T
int i
Definition: pbg5p_w.cc:55
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
LatticeFermion psi
Definition: mespbg5p_w.cc:35
START_CODE()
A(A, psi, r, Ncb, PLUS)
Double zero
Definition: invbicg.cc:106
multi1d< Real > makeFermCoeffs(const AnisoParam_t &aniso)
Make fermion coefficients.
Definition: aniso_io.cc:63
multi1d< LatticeFermion > s(Ncb)
@ ADDITIVE_SCHWARZ
Definition: enum_quda_io.h:103
@ MULTIPLICATIVE_SCHWARZ
Definition: enum_quda_io.h:104
@ RECONS_12
Definition: enum_quda_io.h:80
@ RECONS_NONE
Definition: enum_quda_io.h:78
@ RECONS_8
Definition: enum_quda_io.h:79
FloatingPoint< double > Double
Definition: gtest.h:7351
::std::string string
Definition: gtest.h:1979
Periodic ferm state and a creator.
#define FORWARD
Definition: primitives.h:82
Reunitarize in place a color matrix to SU(N)
Simple fermionic BC.
Support class for fermion actions and linear operators.
Linear system solvers.
multi1d< LatticeColorMatrix > U
LatticeFermion T
Definition: t_clover.cc:11
multi1d< LatticeColorMatrix > Q
Definition: t_clover.cc:12
multi1d< LatticeColorMatrixF > QF
Definition: t_quda_tprec.cc:19
LatticeColorMatrixF UF
Definition: t_quda_tprec.cc:18
LatticeFermionD TD
Definition: t_quda_tprec.cc:22
LatticeColorMatrixD UD
Definition: t_quda_tprec.cc:23
LatticeFermionF TF
Definition: t_quda_tprec.cc:17
multi1d< LatticeColorMatrixD > QD
Definition: t_quda_tprec.cc:24
Axial gauge fixing.