CHROMA
multi_syssolver_mdagm_cg_wilson_quda_w.h
Go to the documentation of this file.
1 // -*- C++ -*-
2 /*! \file
3  * \brief Solve the multi-shift systems (MdagM + shift_i)*psi_i = chi for Wilson fermions using QUDA's CG solver
4  */
5 
6 #ifndef __multi_syssolver_mdagm_cg_wilson_quda_w_h__
7 #define __multi_syssolver_mdagm_cg_wilson_quda_w_h__
8 
9 #include "chroma_config.h"
10 
11 #ifdef BUILD_QUDA
12 
13 #include "handle.h"
14 #include "syssolver.h"
15 #include "linearop.h"
16 #include "lmdagm.h"
21 #include "io/aniso_io.h"
22 #include <string>
23 
24 #include "util/gauge/reunit.h"
25 
26 #include <quda.h>
27 
28 namespace Chroma
29 {
30 
//! Namespace holding the registration hook for the QUDA Wilson multi-shift CG solver
namespace MdagMMultiSysSolverCGQudaWilsonEnv
{
  /*!
   * Register the syssolver.
   *
   * \return status flag from the registration
   *         (NOTE(review): exact meaning is defined in the implementation
   *         file, which is not visible here -- confirm)
   */
  bool registerAll();
}
37 
38  //! Solve a CG2 system. Here, the operator is NOT assumed to be hermitian
39  /*! \ingroup invert
40  */
41  class MdagMMultiSysSolverCGQudaWilson : public MdagMMultiSystemSolver<LatticeFermion>
42  {
43  public:
44  typedef LatticeFermion T;
45  typedef LatticeColorMatrix U;
46  typedef multi1d<LatticeColorMatrix> Q;
47  typedef multi1d<LatticeColorMatrix> P;
48 
49  typedef LatticeFermionF TF;
50  typedef LatticeColorMatrixF UF;
51  typedef multi1d<LatticeColorMatrixF> QF;
52  typedef multi1d<LatticeColorMatrixF> PF;
53 
54  typedef LatticeFermionD TD;
55  typedef LatticeColorMatrixD UD;
56  typedef multi1d<LatticeColorMatrixD> QD;
57  typedef multi1d<LatticeColorMatrixD> PD;
58 
59  typedef WordType<T>::Type_t REALT;
60  //! Constructor
61  /*!
62  * \param M_ Linear operator ( Read )
63  * \param invParam inverter parameters ( Read )
64  */
65  MdagMMultiSysSolverCGQudaWilson(Handle< LinearOperator<T> > M_,
66  Handle< FermState<T,P,Q> > state_,
67  const SysSolverQUDAWilsonParams& invParam_) :
68  A(M_), invParam(invParam_)
69 
70  {
71  QDPIO::cout << "MdagMMultiSysSolverCGQUDAWilson: " << std::endl;
72  // FOLLOWING INITIALIZATION in test QUDA program
73 
74  // 1) work out cpu_prec, cuda_prec, cuda_prec_sloppy
75  int s = sizeof( WordType<T>::Type_t );
76  if (s == 4) {
77  cpu_prec = QUDA_SINGLE_PRECISION;
78  }
79  else {
80  cpu_prec = QUDA_DOUBLE_PRECISION;
81  }
82 
83 
84  // Work out GPU precision
85  switch( invParam.cudaPrecision ) {
86  case HALF:
87  gpu_prec = QUDA_HALF_PRECISION;
88  break;
89  case SINGLE:
90  gpu_prec = QUDA_SINGLE_PRECISION;
91  break;
92  case DOUBLE:
93  gpu_prec = QUDA_DOUBLE_PRECISION;
94  break;
95  default:
96  gpu_prec = cpu_prec;
97  break;
98  }
99 
100  gpu_half_prec = gpu_prec;
101 
102  // 2) pull 'new; GAUGE and Invert params
103  //
104  QDPIO::cout << " Calling new QUDA Invert Param" << std::endl;
105  q_gauge_param = newQudaGaugeParam();
106  quda_inv_param = newQudaInvertParam();
107 
108  // 3) set lattice size
109  const multi1d<int>& latdims = Layout::subgridLattSize();
110 
111  q_gauge_param.X[0] = latdims[0];
112  q_gauge_param.X[1] = latdims[1];
113  q_gauge_param.X[2] = latdims[2];
114  q_gauge_param.X[3] = latdims[3];
115 
116  // 4) - deferred (anisotropy)
117 
118  // 5) - set QUDA_WILSON_LINKS, QUDA_GAUGE_ORDER
119  q_gauge_param.type = QUDA_WILSON_LINKS;
120  q_gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER; // gauge[mu], p
121 
122  // 6) - set t_boundary
123  // Convention: BC has to be applied already
124  // This flag just tells QUDA that this is so,
125  // so that QUDA can take care in the reconstruct
126  if( invParam.AntiPeriodicT ) {
127  q_gauge_param.t_boundary = QUDA_ANTI_PERIODIC_T;
128  }
129  else {
130  q_gauge_param.t_boundary = QUDA_PERIODIC_T;
131  }
132 
133  // Set cpu_prec, cuda_prec, reconstruct and sloppy versions
134  q_gauge_param.cpu_prec = cpu_prec;
135  q_gauge_param.cuda_prec = gpu_prec;
136 
137 
138  switch( invParam.cudaReconstruct ) {
139  case RECONS_NONE:
140  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_NO;
141  break;
142  case RECONS_8:
143  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_8;
144  break;
145  case RECONS_12:
146  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
147  break;
148  default:
149  q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
150  break;
151  };
152 
153  q_gauge_param.cuda_prec_sloppy = gpu_half_prec;
154 
155  switch( invParam.cudaSloppyReconstruct ) {
156  case RECONS_NONE:
157  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO;
158  break;
159  case RECONS_8:
160  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_8;
161  break;
162  case RECONS_12:
163  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
164  break;
165  default:
166  q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
167  break;
168  };
169 
170  // Gauge fixing:
171 
172  // These are the links
173  // They may be smeared and the BC's may be applied
174  Q links_single(Nd);
175 
176  // Now downcast to single prec fields.
177  for(int mu=0; mu < Nd; mu++) {
178  links_single[mu] = (state_->getLinks())[mu];
179  }
180 
181  // GaugeFix
182  if( invParam.axialGaugeP ) {
183  QDPIO::cout << "Fixing Temporal Gauge" << std::endl;
184  temporalGauge(links_single, GFixMat, Nd-1);
185  for(int mu=0; mu < Nd; mu++){
186  links_single[mu] = GFixMat*(state_->getLinks())[mu]*adj(shift(GFixMat, FORWARD, mu));
187  }
188  q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_YES;
189  }
190  else {
191  // No GaugeFix
192  q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO; // No Gfix yet
193  }
194 
195  // deferred 4) Gauge Anisotropy
196  const AnisoParam_t& aniso = invParam.WilsonParams.anisoParam;
197  if( aniso.anisoP ) { // Anisotropic case
198  Real gamma_f = aniso.xi_0 / aniso.nu;
199  q_gauge_param.anisotropy = toDouble(gamma_f);
200  }
201  else {
202  q_gauge_param.anisotropy = 1.0;
203  }
204 
205  // MAKE FSTATE BEFORE RESCALING links_single
206  // Because the clover term expects the unrescaled links...
207  Handle<FermState<T,Q,Q> > fstate( new PeriodicFermState<T,Q,Q>(links_single));
208 
209  if( aniso.anisoP ) { // Anisotropic case
210  multi1d<Real> cf=makeFermCoeffs(aniso);
211  for(int mu=0; mu < Nd; mu++) {
212  links_single[mu] *= cf[mu];
213  }
214  }
215 
216  // Now onto the inv param:
217  // Dslash type
218  quda_inv_param.dslash_type = QUDA_WILSON_DSLASH;
219  solver_string = "MULTI_CG";
220  quda_inv_param.inv_type = QUDA_CG_INVERTER;
221 
222  // Mass
223  Real massParam = Real(1) + Real(3)/Real(q_gauge_param.anisotropy) + invParam.WilsonParams.Mass;
224  quda_inv_param.kappa = 1.0/(2*toDouble(massParam));
225 
226 
227  // FIXME: We set clover coeff to a dummy value. This is dumb
228  // If we ever get QUDA to compute our clvoer term we will need to fix this.
229  // Right now it doesn't matter because we pass our own clover term
230  quda_inv_param.clover_coeff = 1.0; // dummy value
231 
232  quda_inv_param.mass_normalization = QUDA_ASYMMETRIC_MASS_NORMALIZATION;
233 
234  quda_inv_param.tol = toDouble(invParam.RsdTarget);
235  quda_inv_param.maxiter = invParam.MaxIter;
236  quda_inv_param.reliable_delta = toDouble(invParam.Delta);
237  quda_inv_param.pipeline = invParam.Pipeline;
238 
239  // Solution type
240  quda_inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
241 
242  // Solve type
243  switch( invParam.solverType ) {
244  case CG:
245  quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
246  break;
247  default:
248  QDPIO::cerr << "Only CG Is currently implemented for multi-shift" << std::endl;
249  QDP_abort(1);
250 
251  break;
252  }
253 
254  if( invParam.asymmetricP ) {
255  QDPIO::cout << "Asymmetric LinOP" << std::endl;
256  quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD_ASYMMETRIC;
257  }
258  else {
259  QDPIO::cout << "Symmetric LinOp" << std::endl;
260  quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
261  }
262 
263  quda_inv_param.dagger = QUDA_DAG_NO;
264 
265 
266  quda_inv_param.cpu_prec = cpu_prec;
267  quda_inv_param.cuda_prec = gpu_prec;
268  quda_inv_param.cuda_prec_sloppy = gpu_half_prec;
269  quda_inv_param.preserve_source = QUDA_PRESERVE_SOURCE_YES;
270  quda_inv_param.dirac_order = QUDA_DIRAC_ORDER;
271  quda_inv_param.gamma_basis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
272 
273  // Autotuning
274  if( invParam.tuneDslashP ) {
275  QDPIO::cout << "Enabling Dslash Autotuning" << std::endl;
276 
277  quda_inv_param.tune = QUDA_TUNE_YES;
278  }
279  else {
280  QDPIO::cout << "Disabling Dslash Autotuning" << std::endl;
281 
282  quda_inv_param.tune = QUDA_TUNE_NO;
283  }
284 
285  // PADDING
286 
287  // Setup padding
288  multi1d<int> face_size(4);
289  face_size[0] = latdims[1]*latdims[2]*latdims[3]/2;
290  face_size[1] = latdims[0]*latdims[2]*latdims[3]/2;
291  face_size[2] = latdims[0]*latdims[1]*latdims[3]/2;
292  face_size[3] = latdims[0]*latdims[1]*latdims[2]/2;
293 
294  int max_face = face_size[0];
295  for(int i=1; i <=3; i++) {
296  if ( face_size[i] > max_face ) {
297  max_face = face_size[i];
298  }
299  }
300 
301 
302  q_gauge_param.ga_pad = max_face;
303  quda_inv_param.sp_pad = 0;
304  quda_inv_param.cl_pad = 0;
305 
306 
307  // Setting GCR Preconditioner to defaults, as we don't use it..
308  // This is kinda yucky.
309 
310  QDPIO::cout << "Setting Precondition stuff to defaults for not using" << std::endl;
311  quda_inv_param.inv_type_precondition= QUDA_INVALID_INVERTER;
312  quda_inv_param.tol_precondition = 1.0e-1;
313  quda_inv_param.maxiter_precondition = 1000;
314  quda_inv_param.verbosity_precondition = QUDA_SILENT;
315  quda_inv_param.gcrNkrylov = 1;
316 
317  if( invParam.verboseP ) {
318  quda_inv_param.verbosity = QUDA_VERBOSE;
319  }
320  else {
321  quda_inv_param.verbosity = QUDA_SUMMARIZE;
322  }
323 
324  // Set up the links
325  void* gauge[4];
326 
327  for(int mu=0; mu < Nd; mu++) {
328  gauge[mu] = (void *)&(links_single[mu].elem(all.start()).elem().elem(0,0).real());
329 
330  }
331 
332  loadGaugeQuda((void *)gauge, &q_gauge_param);
333 
334  }
335 
336  //! Destructor is automatic
337  ~MdagMMultiSysSolverCGQudaWilson() {
338  QDPIO::cout << "Destructing" << std::endl;
339  freeGaugeQuda();
340  }
341 
342  //! Return the subset on which the operator acts
343  const Subset& subset() const {return A->subset();}
344 
345  //! Solver the linear system
346  /*!
347  * \param psi solution ( Modify )
348  * \param chi source ( Read )
349  * \return syssolver results
350  */
351  SystemSolverResults_t operator() (multi1d<T>& psi, const multi1d<Real>& shifts, const T& chi) const
352  {
353  START_CODE();
354  StopWatch swatch;
355  swatch.reset();
356  swatch.start();
357  SystemSolverResults_t res;
358  res.n_count = 0;
359 
360  if ( invParam.axialGaugeP ) {
361  T g_chi;
362  multi1d<T> g_psi(psi.size());
363 
364  // Gauge Fix source and initial guess
365  QDPIO::cout << "Gauge Fixing source and initial guess" << std::endl;
366  g_chi[ rb[1] ] = GFixMat * chi;
367  for(int s=0; s < psi.size(); s++) {
368  g_psi[s][ rb[1] ] = zero; // All initial guesses are zero
369  }
370 
371  QDPIO::cout << "Solving" << std::endl;
372  res = qudaInvertMulti(
373  g_chi,
374  g_psi,
375  shifts);
376  QDPIO::cout << "Untransforming solution." << std::endl;
377  for(int s=0; s< psi.size(); s++) {
378  psi[s][ rb[1]] = adj(GFixMat)*g_psi[s];
379  }
380 
381  }
382  else {
383 
384  res = qudaInvertMulti(chi,
385  psi,
386  shifts);
387 
388  }
389 
390  swatch.stop();
391  double time = swatch.getTimeInSeconds();
392 
393  if (invParam.verboseP ) {
394  Double chinorm=norm2(chi, A->subset());
395  multi1d<Double> r_rel(shifts.size());
396 
397  for(int i=0; i < shifts.size(); i++) {
398  T tmp1,tmp2;
399  tmp1 = zero;
400  tmp2 = zero;
401 
402  (*A)(tmp1, psi[i], PLUS);
403  (*A)(tmp2, tmp1, MINUS); // tmp2 = A^\dagger A psi
404  tmp2[ A->subset() ] += shifts[i]* psi[i]; // tmp2 = ( A^\dagger A + shift_i ) psi
405  T r;
406  r = zero;
407 
408  r[ A->subset() ] = chi - tmp2;
409  r_rel[i] = sqrt(norm2(r, A->subset())/chinorm );
410  QDPIO::cout << "r[" <<i <<"] = " << r_rel[i] << std::endl;
411  }
412  }
413  QDPIO::cout << "MULTI_CG_QUDA_CLOVER_SOLVER: " << res.n_count << " iterations. Rsd = " << res.resid << std::endl;
414  QDPIO::cout << "MULTI_CG_QUDA_CLOVER_SOLVER: "<<time<< " sec" << std::endl;
415  END_CODE();
416 
417  return res;
418  }
419 
420 
421  private:
422  // Hide default constructor
423  MdagMMultiSysSolverCGQudaWilson() {}
424  U GFixMat;
425  QudaPrecision_s cpu_prec;
426  QudaPrecision_s gpu_prec;
427  QudaPrecision_s gpu_half_prec;
428 
429  Handle< LinearOperator<T> > A;
430  const SysSolverQUDAWilsonParams invParam;
431  QudaGaugeParam q_gauge_param;
432  mutable QudaInvertParam quda_inv_param;
433 
434  SystemSolverResults_t qudaInvertMulti(const T& chi_s,
435  multi1d<T>& psi_s,
436  multi1d<Real> shifts
437  )const ;
438 
439  std::string solver_string;
440 
441  };
442 
443 
444 } // End namespace
445 
446 #endif
447 #endif
448 
Anisotropy parameters.
int mu
Definition: cool.cc:24
void temporalGauge(multi1d< LatticeColorMatrix > &ug, LatticeColorMatrix &g, int decay_dir)
Temporal gauge fixing.
Class for counted reference semantics.
Linear Operators.
M^dag*M composition of a linear operator.
Nd
Definition: meslate.cc:74
Double tmp2
Definition: mesq.cc:30
multi1d< LatticeColorMatrix > P
multi1d< Hadron2PtContraction_t > operator()(const multi1d< LatticeColorMatrix > &u)
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
LinOpSysSolverMGProtoClover::Q Q
LinOpSysSolverMGProtoClover::T T
int i
Definition: pbg5p_w.cc:55
@ MINUS
Definition: chromabase.h:45
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
LatticeFermion psi
Definition: mespbg5p_w.cc:35
START_CODE()
A(A, psi, r, Ncb, PLUS)
Double zero
Definition: invbicg.cc:106
multi1d< Real > makeFermCoeffs(const AnisoParam_t &aniso)
Make fermion coefficients.
Definition: aniso_io.cc:63
multi1d< LatticeFermion > s(Ncb)
@ RECONS_12
Definition: enum_quda_io.h:80
@ RECONS_NONE
Definition: enum_quda_io.h:78
@ RECONS_8
Definition: enum_quda_io.h:79
FloatingPoint< double > Double
Definition: gtest.h:7351
::std::string string
Definition: gtest.h:1979
Periodic ferm state and a creator.
#define FORWARD
Definition: primitives.h:82
Reunitarize in place a color matrix to SU(N)
Simple fermionic BC.
Linear system solvers.
multi1d< LatticeColorMatrix > U
LatticeFermion T
Definition: t_clover.cc:11
multi1d< LatticeColorMatrix > Q
Definition: t_clover.cc:12
multi1d< LatticeColorMatrixF > QF
Definition: t_quda_tprec.cc:19
LatticeColorMatrixF UF
Definition: t_quda_tprec.cc:18
multi1d< LatticeColorMatrixF > PF
Definition: t_quda_tprec.cc:20
LatticeFermionD TD
Definition: t_quda_tprec.cc:22
LatticeColorMatrixD UD
Definition: t_quda_tprec.cc:23
LatticeFermionF TF
Definition: t_quda_tprec.cc:17
multi1d< LatticeColorMatrixD > PD
Definition: t_quda_tprec.cc:25
multi1d< LatticeColorMatrixD > QD
Definition: t_quda_tprec.cc:24
Axial gauge fixing.