CHROMA
syssolver_mdagm_wilson_quda_w.h
Go to the documentation of this file.
1 // -*- C++ -*-
2 /*! \file
 * \brief Solve a MdagM*psi=chi linear system for Wilson fermions using the QUDA GPU inverter
4  */
5 
6 #ifndef __syssolver_mdagm_quda_wilson_h__
7 #define __syssolver_mdagm_quda_wilson_h__
8 
9 #include "chroma_config.h"
10 
11 #ifdef BUILD_QUDA
12 
13 #include "handle.h"
14 #include "state.h"
15 #include "syssolver.h"
16 #include "linearop.h"
17 #include "lmdagm.h"
23 #include "io/aniso_io.h"
24 #include <string>
25 
26 #include "util/gauge/reunit.h"
27 
28 #include <quda.h>
29 
30 
31 namespace Chroma
32 {
33 
34  //! Richardson system solver namespace
35  namespace MdagMSysSolverQUDAWilsonEnv
36  {
37  //! Register the syssolver
38  bool registerAll();
39  }
40 
41 
42 
43  //! Solve a Wilson Fermion System using the QUDA inverter
44  /*! \ingroup invert
45  *** WARNING THIS SOLVER WORKS FOR Wilson FERMIONS ONLY ***
46  */
47 
48  class MdagMSysSolverQUDAWilson : public MdagMSystemSolver<LatticeFermion>
49  {
50  public:
51  typedef LatticeFermion T;
52  typedef LatticeColorMatrix U;
53  typedef multi1d<LatticeColorMatrix> Q;
54 
55  typedef LatticeFermionF TF;
56  typedef LatticeColorMatrixF UF;
57  typedef multi1d<LatticeColorMatrixF> QF;
58 
59  typedef LatticeFermionF TD;
60  typedef LatticeColorMatrixF UD;
61  typedef multi1d<LatticeColorMatrixF> QD;
62 
63  typedef WordType<T>::Type_t REALT;
64  //! Constructor
65  /*!
66  * \param M_ Linear operator ( Read )
67  * \param invParam inverter parameters ( Read )
68  */
69  MdagMSysSolverQUDAWilson(Handle< LinearOperator<T> > A_,
70  Handle< FermState<T,Q,Q> > state_,
71  const SysSolverQUDAWilsonParams& invParam_) :
72  A(A_), invParam(invParam_)
73  {
74  QDPIO::cout << "MdagMSysSolverQUDAWilson:" << std::endl;
75 
76  // FOLLOWING INITIALIZATION in test QUDA program
77 
78  // 1) work out cpu_prec, cuda_prec, cuda_prec_sloppy
79  int s = sizeof( WordType<T>::Type_t );
80  if (s == 4) {
81  cpu_prec = QUDA_SINGLE_PRECISION;
82  }
83  else {
84  cpu_prec = QUDA_DOUBLE_PRECISION;
85  }
86 
87 
88  // Work out GPU precision
89  switch( invParam.cudaPrecision ) {
90  case HALF:
91  gpu_prec = QUDA_HALF_PRECISION;
92  break;
93  case SINGLE:
94  gpu_prec = QUDA_SINGLE_PRECISION;
95  break;
96  case DOUBLE:
97  gpu_prec = QUDA_DOUBLE_PRECISION;
98  break;
99  default:
100  gpu_prec = cpu_prec;
101  break;
102  }
103 
104  // Work out GPU Sloppy precision
105  // Default: No Sloppy
106  switch( invParam.cudaSloppyPrecision ) {
107  case HALF:
108  gpu_half_prec = QUDA_HALF_PRECISION;
109  break;
110  case SINGLE:
111  gpu_half_prec = QUDA_SINGLE_PRECISION;
112  break;
113  case DOUBLE:
114  gpu_half_prec = QUDA_DOUBLE_PRECISION;
115  break;
116  default:
117  gpu_half_prec = gpu_prec;
118  break;
119  }
120 
121  // 2) pull 'new; GAUGE and Invert params
122  q_gauge_param = newQudaGaugeParam();
123  quda_inv_param = newQudaInvertParam();
124 
125  // 3) set lattice size
126  const multi1d<int>& latdims = Layout::subgridLattSize();
127 
128  q_gauge_param.X[0] = latdims[0];
129  q_gauge_param.X[1] = latdims[1];
130  q_gauge_param.X[2] = latdims[2];
131  q_gauge_param.X[3] = latdims[3];
132 
133  // 4) - deferred (anisotropy)
134 
135  // 5) - set QUDA_WILSON_LINKS, QUDA_GAUGE_ORDER
136  q_gauge_param.type = QUDA_WILSON_LINKS;
137  q_gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER; // gauge[mu], p
138 
139  // 6) - set t_boundary
140  // Convention: BC has to be applied already
141  // This flag just tells QUDA that this is so,
142  // so that QUDA can take care in the reconstruct
143  if( invParam.AntiPeriodicT ) {
144  q_gauge_param.t_boundary = QUDA_ANTI_PERIODIC_T;
145  }
146  else {
147  q_gauge_param.t_boundary = QUDA_PERIODIC_T;
148  }
149 
150  // Set cpu_prec, cuda_prec, reconstruct and sloppy versions
151  q_gauge_param.cpu_prec = cpu_prec;
152  q_gauge_param.cuda_prec = gpu_prec;
153 
154 
155  switch( invParam.cudaReconstruct ) {
156  case RECONS_NONE:
157  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_NO;
158  break;
159  case RECONS_8:
160  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_8;
161  break;
162  case RECONS_12:
163  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
164  break;
165  default:
166  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
167  break;
168  };
169 
170  q_gauge_param.cuda_prec_sloppy = gpu_half_prec;
171 
172  switch( invParam.cudaSloppyReconstruct ) {
173  case RECONS_NONE:
174  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO;
175  break;
176  case RECONS_8:
177  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_8;
178  break;
179  case RECONS_12:
180  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
181  break;
182  default:
183  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
184  break;
185  };
186 
187  // Gauge fixing:
188 
189  // These are the links
190  // They may be smeared and the BC's may be applied
191  Q links_single(Nd);
192 
193  // Now downcast to single prec fields.
194  for(int mu=0; mu < Nd; mu++) {
195  links_single[mu] = (state_->getLinks())[mu];
196  }
197 
198  // GaugeFix
199  if( invParam.axialGaugeP ) {
200  QDPIO::cout << "Fixing Temporal Gauge" << std::endl;
201  temporalGauge(links_single, GFixMat, Nd-1);
202  for(int mu=0; mu < Nd; mu++){
203  links_single[mu] = GFixMat*(state_->getLinks())[mu]*adj(shift(GFixMat, FORWARD, mu));
204  }
205  q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_YES;
206  }
207  else {
208  // No GaugeFix
209  q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO; // No Gfix yet
210  }
211 
212  // deferred 4) Gauge Anisotropy
213  const AnisoParam_t& aniso = invParam.WilsonParams.anisoParam;
214  if( aniso.anisoP ) { // Anisotropic case
215  Real gamma_f = aniso.xi_0 / aniso.nu;
216  q_gauge_param.anisotropy = toDouble(gamma_f);
217  }
218  else {
219  q_gauge_param.anisotropy = 1.0;
220  }
221 
222  // MAKE FSTATE BEFORE RESCALING links_single
223  // Because the clover term expects the unrescaled links...
224  Handle<FermState<T,Q,Q> > fstate( new PeriodicFermState<T,Q,Q>(links_single));
225 
226  if( aniso.anisoP ) { // Anisotropic case
227  multi1d<Real> cf=makeFermCoeffs(aniso);
228  for(int mu=0; mu < Nd; mu++) {
229  links_single[mu] *= cf[mu];
230  }
231  }
232 
233  // Now onto the inv param:
234  // Dslash type
235  quda_inv_param.dslash_type = QUDA_WILSON_DSLASH;
236 
237  // Invert type:
238  switch( invParam.solverType ) {
239  case CG:
240  quda_inv_param.inv_type = QUDA_CG_INVERTER;
241  solver_string = "CG";
242  break;
243  case BICGSTAB:
244  quda_inv_param.inv_type = QUDA_BICGSTAB_INVERTER;
245  solver_string = "BICGSTAB";
246  break;
247  case GCR:
248  quda_inv_param.inv_type = QUDA_GCR_INVERTER;
249  solver_string = "GCR";
250  break;
251  default:
252  quda_inv_param.inv_type = QUDA_CG_INVERTER;
253  solver_string = "CG";
254  break;
255  }
256 
257 
258  Real massParam = Real(1) + Real(3)/Real(q_gauge_param.anisotropy) + invParam.WilsonParams.Mass;
259 
260  quda_inv_param.kappa = 1.0/(2*toDouble(massParam));
261 
262  // FIXME: If QUDA ever starts to compute our clover term we will need to fix this
263  // Right now it is a dummy value, since we pass in the clover term
264  quda_inv_param.clover_coeff = 1.0;
265 
266  quda_inv_param.mass_normalization = QUDA_ASYMMETRIC_MASS_NORMALIZATION;
267 
268  quda_inv_param.tol = toDouble(invParam.RsdTarget);
269  quda_inv_param.maxiter = invParam.MaxIter;
270  quda_inv_param.reliable_delta = toDouble(invParam.Delta);
271  quda_inv_param.pipeline = invParam.Pipeline;
272 
273  // Solution type
274  quda_inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
275 
276  // Solve type
277  switch( invParam.solverType ) {
278  case CG:
279  quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
280  break;
281  case BICGSTAB:
282  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
283  break;
284  case GCR:
285  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
286  break;
287  case MR:
288  quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
289  break;
290 
291  default:
292  quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
293 
294  break;
295  }
296 
297 
298  if ( invParam.asymmetricP ) {
299  QDPIO::cout << "Using asymmetric preconditioning" << std::endl;
300  quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD_ASYMMETRIC;
301  }
302  else {
303  QDPIO::cout << "Using symmetric preconditioning" << std::endl;
304  quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
305  }
306 
307  quda_inv_param.dagger = QUDA_DAG_NO;
308 
309 
310  quda_inv_param.cpu_prec = cpu_prec;
311  quda_inv_param.cuda_prec = gpu_prec;
312  quda_inv_param.cuda_prec_sloppy = gpu_half_prec;
313  quda_inv_param.preserve_source = QUDA_PRESERVE_SOURCE_YES;
314  quda_inv_param.use_init_guess = QUDA_USE_INIT_GUESS_NO;
315  quda_inv_param.dirac_order = QUDA_DIRAC_ORDER;
316  quda_inv_param.gamma_basis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
317 
318  // Autotuning
319  if( invParam.tuneDslashP ) {
320  QDPIO::cout << "Enabling Dslash Autotuning" << std::endl;
321 
322  quda_inv_param.tune = QUDA_TUNE_YES;
323  }
324  else {
325  QDPIO::cout << "Disabling Dslash Autotuning" << std::endl;
326 
327  quda_inv_param.tune = QUDA_TUNE_NO;
328  }
329 
330 
331  // Setup padding
332  multi1d<int> face_size(4);
333  face_size[0] = latdims[1]*latdims[2]*latdims[3]/2;
334  face_size[1] = latdims[0]*latdims[2]*latdims[3]/2;
335  face_size[2] = latdims[0]*latdims[1]*latdims[3]/2;
336  face_size[3] = latdims[0]*latdims[1]*latdims[2]/2;
337 
338  int max_face = face_size[0];
339  for(int i=1; i <=3; i++) {
340  if ( face_size[i] > max_face ) {
341  max_face = face_size[i];
342  }
343  }
344 
345 
346  q_gauge_param.ga_pad = max_face;
347  quda_inv_param.sp_pad = 0;
348  quda_inv_param.cl_pad = 0;
349 
350  if( invParam.innerParamsP ) {
351  QDPIO::cout << "Setting inner solver params" << std::endl;
352  // Dereference handle
353  GCRInnerSolverParams ip = *(invParam.innerParams);
354 
355  // Set preconditioner precision
356  switch( ip.precPrecondition ) {
357  case HALF:
358  quda_inv_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
359  q_gauge_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
360  break;
361 
362  case SINGLE:
363  quda_inv_param.cuda_prec_precondition = QUDA_SINGLE_PRECISION;
364  q_gauge_param.cuda_prec_precondition = QUDA_SINGLE_PRECISION;
365  break;
366 
367  case DOUBLE:
368  quda_inv_param.cuda_prec_precondition = QUDA_DOUBLE_PRECISION;
369  q_gauge_param.cuda_prec_precondition = QUDA_DOUBLE_PRECISION;
370  break;
371  default:
372  quda_inv_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
373  q_gauge_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
374  break;
375  }
376 
377  switch( ip.reconstructPrecondition ) {
378  case RECONS_NONE:
379  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_NO;
380  break;
381  case RECONS_8:
382  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_8;
383  break;
384  case RECONS_12:
385  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_12;
386  break;
387  default:
388  q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_12;
389  break;
390  };
391 
392  quda_inv_param.tol_precondition = toDouble(ip.tolPrecondition);
393  quda_inv_param.maxiter_precondition = ip.maxIterPrecondition;
394  quda_inv_param.gcrNkrylov = ip.gcrNkrylov;
395  switch( ip.schwarzType ) {
396  case ADDITIVE_SCHWARZ :
397  quda_inv_param.schwarz_type = QUDA_ADDITIVE_SCHWARZ;
398  break;
400  quda_inv_param.schwarz_type = QUDA_MULTIPLICATIVE_SCHWARZ;
401  break;
402  default:
403  quda_inv_param.schwarz_type = QUDA_ADDITIVE_SCHWARZ;
404  break;
405  }
406  quda_inv_param.precondition_cycle = ip.preconditionCycle;
407 
408  if( ip.verboseInner ) {
409  quda_inv_param.verbosity_precondition = QUDA_VERBOSE;
410  }
411  else {
412  quda_inv_param.verbosity_precondition = QUDA_SILENT;
413  }
414 
415  switch( ip.invTypePrecondition ) {
416  case CG:
417  quda_inv_param.inv_type_precondition = QUDA_CG_INVERTER;
418  break;
419  case BICGSTAB:
420  quda_inv_param.inv_type_precondition = QUDA_BICGSTAB_INVERTER;
421 
422  break;
423  case MR:
424  quda_inv_param.inv_type_precondition= QUDA_MR_INVERTER;
425  break;
426 
427  default:
428  quda_inv_param.inv_type_precondition = QUDA_MR_INVERTER;
429  break;
430  }
431  }
432  else {
433  QDPIO::cout << "Setting Precondition stuff to defaults for not using" << std::endl;
434  quda_inv_param.inv_type_precondition= QUDA_INVALID_INVERTER;
435  quda_inv_param.tol_precondition = 1.0e-1;
436  quda_inv_param.maxiter_precondition = 1000;
437  quda_inv_param.verbosity_precondition = QUDA_SILENT;
438  quda_inv_param.gcrNkrylov = 1;
439  }
440 
441 
442  if( invParam.verboseP ) {
443  quda_inv_param.verbosity = QUDA_VERBOSE;
444  }
445  else {
446  quda_inv_param.verbosity = QUDA_SUMMARIZE;
447  }
448 
449  // Set up the links
450  void* gauge[4];
451 
452  for(int mu=0; mu < Nd; mu++) {
453  gauge[mu] = (void *)&(links_single[mu].elem(all.start()).elem().elem(0,0).real());
454 
455  }
456  loadGaugeQuda((void *)gauge, &q_gauge_param);
457 
458 
459  }
460 
461 
462  //! Destructor is automatic
463  ~MdagMSysSolverQUDAWilson()
464  {
465  QDPIO::cout << "Destructing" << std::endl;
466  freeGaugeQuda();
467  }
468 
469  //! Return the subset on which the operator acts
470  const Subset& subset() const {return A->subset();}
471 
472  //! Solver the linear system
473  /*!
474  * \param psi solution ( Modify )
475  * \param chi source ( Read )
476  * \return syssolver results
477  */
478  SystemSolverResults_t operator() (T& psi, const T& chi ) const
479  {
480  SystemSolverResults_t res;
481 
482  START_CODE();
483  StopWatch swatch;
484  swatch.start();
485 
486  // T MdagChi;
487 
488  // This is a CGNE. So create new RHS
489  // (*A)(MdagChi, chi, MINUS);
490  // Handle< LinearOperator<T> > MM(new MdagMMdagM<T>(A));
491  if ( invParam.axialGaugeP ) {
492  T g_chi,g_psi;
493 
494  // Gauge Fix source and initial guess
495  QDPIO::cout << "Gauge Fixing source and initial guess" << std::endl;
496  g_chi[ rb[1] ] = GFixMat * chi;
497  g_psi[ rb[1] ] = GFixMat * psi;
498  QDPIO::cout << "Solving" << std::endl;
499  res = qudaInvert(g_chi,
500  g_psi);
501  QDPIO::cout << "Untransforming solution." << std::endl;
502  psi[ rb[1]] = adj(GFixMat)*g_psi;
503 
504  }
505  else {
506  QDPIO::cout << "Calling QUDA Invert" << std::endl;
507  res = qudaInvert(chi,
508  psi);
509  }
510 
511  swatch.stop();
512 
513 
514  {
515  T r;
516  r[A->subset()]=chi;
517  T tmp,tmp2;
518  (*A)(tmp, psi, PLUS);
519  (*A)(tmp2, tmp, MINUS);
520  r[A->subset()] -= tmp2;
521  res.resid = sqrt(norm2(r, A->subset()));
522  }
523 
524  Double rel_resid = res.resid/sqrt(norm2(chi,A->subset()));
525 
526  QDPIO::cout << "QUDA_"<< solver_string <<"_CLOVER_SOLVER: " << res.n_count << " iterations. Rsd = " << res.resid << " Relative Rsd = " << rel_resid << std::endl;
527 
528  // Convergence Check/Blow Up
529  if ( ! invParam.SilentFailP ) {
530  if ( toBool( rel_resid > invParam.RsdToleranceFactor*invParam.RsdTarget) ) {
531  QDPIO::cerr << "ERROR: QUDA Solver residuum is outside tolerance: QUDA resid="<< rel_resid << " Desired =" << invParam.RsdTarget << " Max Tolerated = " << invParam.RsdToleranceFactor*invParam.RsdTarget << std::endl;
532  QDP_abort(1);
533  }
534  }
535 
536  END_CODE();
537  return res;
538  }
539 
540 
541  SystemSolverResults_t operator() (T& psi, const T& chi, Chroma::AbsChronologicalPredictor4D<T>& predictor ) const
542  {
543  SystemSolverResults_t res;
544 
545  START_CODE();
546  StopWatch swatch;
547  swatch.start();
548  {
549  Handle< LinearOperator<T> > MdagM( new MdagMLinOp<T>(A) );
550  predictor(psi, (*MdagM), chi);
551  }
552  res = (*this)(psi, chi);
553  predictor.newVector(psi);
554  swatch.stop();
555  double time = swatch.getTimeInSeconds();
556  QDPIO::cout << "QUDA_"<< solver_string <<"_CLOVER_SOLVER: Total time (with prediction)=" << time << std::endl;
557  END_CODE();
558  return res;
559  }
560 
561 
562  private:
563  // Hide default constructor
564  MdagMSysSolverQUDAWilson() {}
565 
566 #if 1
567  Q links_orig;
568 #endif
569 
570  U GFixMat;
571  QudaPrecision_s cpu_prec;
572  QudaPrecision_s gpu_prec;
573  QudaPrecision_s gpu_half_prec;
574 
575  Handle< LinearOperator<T> > A;
576  const SysSolverQUDAWilsonParams invParam;
577  QudaGaugeParam q_gauge_param;
578  QudaInvertParam quda_inv_param;
579 
580  SystemSolverResults_t qudaInvert(const T& chi_s,
581  T& psi_s
582  )const ;
583 
584  std::string solver_string;
585  };
586 
587 
588 } // End namespace
589 
590 #endif // BUILD_QUDA
591 #endif
592 
Anisotropy parameters.
Abstract interface for a Chronological Solution predictor.
virtual void newVector(const T &psi)=0
int mu
Definition: cool.cc:24
void temporalGauge(multi1d< LatticeColorMatrix > &ug, LatticeColorMatrix &g, int decay_dir)
Temporal gauge fixing.
Class for counted reference semantics.
Linear Operators.
M^dag*M composition of a linear operator.
Nd
Definition: meslate.cc:74
Double tmp2
Definition: mesq.cc:30
bool registerAll()
Register all the factories.
multi1d< Hadron2PtContraction_t > operator()(const multi1d< LatticeColorMatrix > &u)
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
LinOpSysSolverMGProtoClover::Q Q
LatticeFermion tmp
Definition: mespbg5p_w.cc:36
LinOpSysSolverMGProtoClover::T T
int i
Definition: pbg5p_w.cc:55
@ MINUS
Definition: chromabase.h:45
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
LatticeFermion psi
Definition: mespbg5p_w.cc:35
START_CODE()
A(A, psi, r, Ncb, PLUS)
multi1d< Real > makeFermCoeffs(const AnisoParam_t &aniso)
Make fermion coefficients.
Definition: aniso_io.cc:63
multi1d< LatticeFermion > s(Ncb)
@ ADDITIVE_SCHWARZ
Definition: enum_quda_io.h:103
@ MULTIPLICATIVE_SCHWARZ
Definition: enum_quda_io.h:104
@ RECONS_12
Definition: enum_quda_io.h:80
@ RECONS_NONE
Definition: enum_quda_io.h:78
@ RECONS_8
Definition: enum_quda_io.h:79
FloatingPoint< double > Double
Definition: gtest.h:7351
::std::string string
Definition: gtest.h:1979
Periodic ferm state and a creator.
#define FORWARD
Definition: primitives.h:82
Reunitarize in place a color matrix to SU(N)
Simple fermionic BC.
Support class for fermion actions and linear operators.
Linear system solvers.
Handle< LinearOperator< T > > MdagM
multi1d< LatticeColorMatrix > U
LatticeFermion T
Definition: t_clover.cc:11
multi1d< LatticeColorMatrix > Q
Definition: t_clover.cc:12
multi1d< LatticeColorMatrixF > QF
Definition: t_quda_tprec.cc:19
LatticeColorMatrixF UF
Definition: t_quda_tprec.cc:18
LatticeFermionD TD
Definition: t_quda_tprec.cc:22
LatticeColorMatrixD UD
Definition: t_quda_tprec.cc:23
LatticeFermionF TF
Definition: t_quda_tprec.cc:17
multi1d< LatticeColorMatrixD > QD
Definition: t_quda_tprec.cc:24
Axial gauge fixing.