6 #ifndef __syssolver_mdagm_quda_wilson_h__
7 #define __syssolver_mdagm_quda_wilson_h__
35 namespace MdagMSysSolverQUDAWilsonEnv
48 class MdagMSysSolverQUDAWilson :
public MdagMSystemSolver<LatticeFermion>
// Working types of the solver: fermion (T), a single link matrix (U) and the
// array of link matrices (Q).
51 typedef LatticeFermion
T;
52 typedef LatticeColorMatrix
U;
53 typedef multi1d<LatticeColorMatrix>
Q;
// "F"-suffixed counterparts of T, U and Q.
55 typedef LatticeFermionF
TF;
56 typedef LatticeColorMatrixF
UF;
57 typedef multi1d<LatticeColorMatrixF>
QF;
// NOTE(review): TD, UD and QD alias exactly the same *F types as TF, UF and QF
// above. If they are meant to be the double-precision variants
// (LatticeFermionD / LatticeColorMatrixD) this is a copy-paste slip; confirm
// against the users of these aliases before changing them.
59 typedef LatticeFermionF
TD;
60 typedef LatticeColorMatrixF
UD;
61 typedef multi1d<LatticeColorMatrixF>
QD;
// Underlying word type of T.
63 typedef WordType<T>::Type_t REALT;
// Constructor: stores the linear operator and the solver parameters, fills in
// the QUDA gauge and inverter parameter structs from invParam, and passes the
// gauge links to QUDA with loadGaugeQuda().
//   A_        - linear operator, kept in member A
//   state_    - fermion state; its getLinks() supplies the gauge links
//   invParam_ - solver parameters, copied into member invParam
// NOTE(review): this listing omits lines (braces, case labels, else branches,
// loop headers and some declarations), so the conditions that choose between
// adjacent assignments below are not visible here.
69 MdagMSysSolverQUDAWilson(Handle< LinearOperator<T> > A_,
70 Handle< FermState<T,Q,Q> > state_,
71 const SysSolverQUDAWilsonParams& invParam_) :
72 A(A_), invParam(invParam_)
74 QDPIO::cout <<
"MdagMSysSolverQUDAWilson:" << std::endl;
// Host precision: s is the byte size of the word type of T; cpu_prec is then
// set to single or double (the test on s is not shown in this listing).
79 int s =
sizeof( WordType<T>::Type_t );
81 cpu_prec = QUDA_SINGLE_PRECISION;
84 cpu_prec = QUDA_DOUBLE_PRECISION;
// Device precision, selected by invParam.cudaPrecision.
89 switch( invParam.cudaPrecision ) {
91 gpu_prec = QUDA_HALF_PRECISION;
94 gpu_prec = QUDA_SINGLE_PRECISION;
97 gpu_prec = QUDA_DOUBLE_PRECISION;
// Sloppy precision, selected by invParam.cudaSloppyPrecision; the last
// assignment reuses gpu_prec (presumably the default case).
106 switch( invParam.cudaSloppyPrecision ) {
108 gpu_half_prec = QUDA_HALF_PRECISION;
111 gpu_half_prec = QUDA_SINGLE_PRECISION;
114 gpu_half_prec = QUDA_DOUBLE_PRECISION;
117 gpu_half_prec = gpu_prec;
// Start both parameter structs from QUDA's own initialisers.
122 q_gauge_param = newQudaGaugeParam();
123 quda_inv_param = newQudaInvertParam();
// The subgrid lattice extents become the QUDA gauge dimensions X[0..3].
126 const multi1d<int>& latdims = Layout::subgridLattSize();
128 q_gauge_param.X[0] = latdims[0];
129 q_gauge_param.X[1] = latdims[1];
130 q_gauge_param.X[2] = latdims[2];
131 q_gauge_param.X[3] = latdims[3];
// Wilson-type links handed over in QDP gauge ordering.
136 q_gauge_param.type = QUDA_WILSON_LINKS;
137 q_gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER;
// Time boundary condition follows invParam.AntiPeriodicT.
143 if( invParam.AntiPeriodicT ) {
144 q_gauge_param.t_boundary = QUDA_ANTI_PERIODIC_T;
147 q_gauge_param.t_boundary = QUDA_PERIODIC_T;
151 q_gauge_param.cpu_prec = cpu_prec;
152 q_gauge_param.cuda_prec = gpu_prec;
// Gauge reconstruction, selected by invParam.cudaReconstruct; two of the
// branches shown both yield QUDA_RECONSTRUCT_12.
155 switch( invParam.cudaReconstruct ) {
157 q_gauge_param.reconstruct = QUDA_RECONSTRUCT_NO;
160 q_gauge_param.reconstruct = QUDA_RECONSTRUCT_8;
163 q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
166 q_gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
// Same pair of settings for the sloppy gauge field.
170 q_gauge_param.cuda_prec_sloppy = gpu_half_prec;
172 switch( invParam.cudaSloppyReconstruct ) {
174 q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_NO;
177 q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_8;
180 q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
183 q_gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
// Copy the links of state_ into links_single, direction by direction (the
// declaration of links_single and the loop over mu are not shown).
195 links_single[
mu] = (state_->getLinks())[
mu];
// With invParam.axialGaugeP set, the links are replaced by their transform
// under GFixMat and the field is flagged as gauge fixed; the GAUGE_FIXED_NO
// assignment is presumably the else branch.
199 if( invParam.axialGaugeP ) {
200 QDPIO::cout <<
"Fixing Temporal Gauge" << std::endl;
203 links_single[
mu] = GFixMat*(state_->getLinks())[
mu]*adj(shift(GFixMat,
FORWARD,
mu));
205 q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_YES;
209 q_gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO;
// Anisotropy passed to QUDA: xi_0 / nu in one branch, 1.0 in the other (the
// condition is not shown).
213 const AnisoParam_t& aniso = invParam.WilsonParams.anisoParam;
215 Real gamma_f = aniso.xi_0 / aniso.nu;
216 q_gauge_param.anisotropy = toDouble(gamma_f);
219 q_gauge_param.anisotropy = 1.0;
// Wrap the (possibly gauge-fixed) links in a PeriodicFermState.
224 Handle<FermState<T,Q,Q> > fstate(
new PeriodicFermState<T,Q,Q>(links_single));
// Scale each direction's links by cf[mu] (cf is defined in lines not shown).
229 links_single[
mu] *= cf[
mu];
235 quda_inv_param.dslash_type = QUDA_WILSON_DSLASH;
// Outer inverter and its log label, selected by invParam.solverType; the CG
// pair appears twice, presumably once as the default case.
238 switch( invParam.solverType ) {
240 quda_inv_param.inv_type = QUDA_CG_INVERTER;
241 solver_string =
"CG";
244 quda_inv_param.inv_type = QUDA_BICGSTAB_INVERTER;
245 solver_string =
"BICGSTAB";
248 quda_inv_param.inv_type = QUDA_GCR_INVERTER;
249 solver_string =
"GCR";
252 quda_inv_param.inv_type = QUDA_CG_INVERTER;
253 solver_string =
"CG";
// kappa = 1 / (2 * (1 + 3/anisotropy + Mass)).
258 Real massParam = Real(1) + Real(3)/Real(q_gauge_param.anisotropy) + invParam.WilsonParams.Mass;
260 quda_inv_param.kappa = 1.0/(2*toDouble(massParam));
264 quda_inv_param.clover_coeff = 1.0;
266 quda_inv_param.mass_normalization = QUDA_ASYMMETRIC_MASS_NORMALIZATION;
// Stopping criteria and related controls, taken directly from invParam.
268 quda_inv_param.tol = toDouble(invParam.RsdTarget);
269 quda_inv_param.maxiter = invParam.MaxIter;
270 quda_inv_param.reliable_delta = toDouble(invParam.Delta);
271 quda_inv_param.pipeline = invParam.Pipeline;
274 quda_inv_param.solution_type = QUDA_MATPCDAG_MATPC_SOLUTION;
// Solve type, selected by invParam.solverType: the branches shown give
// NORMOP_PC, three times DIRECT_PC, then NORMOP_PC again.
277 switch( invParam.solverType ) {
279 quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
282 quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
285 quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
288 quda_inv_param.solve_type = QUDA_DIRECT_PC_SOLVE;
292 quda_inv_param.solve_type = QUDA_NORMOP_PC_SOLVE;
// Odd-odd preconditioning, asymmetric or symmetric according to
// invParam.asymmetricP.
298 if ( invParam.asymmetricP ) {
299 QDPIO::cout <<
"Using asymmetric preconditioning" << std::endl;
300 quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD_ASYMMETRIC;
303 QDPIO::cout <<
"Using symmetric preconditioning" << std::endl;
304 quda_inv_param.matpc_type = QUDA_MATPC_ODD_ODD;
307 quda_inv_param.dagger = QUDA_DAG_NO;
// Precisions, source handling, spinor ordering and gamma basis.
310 quda_inv_param.cpu_prec = cpu_prec;
311 quda_inv_param.cuda_prec = gpu_prec;
312 quda_inv_param.cuda_prec_sloppy = gpu_half_prec;
313 quda_inv_param.preserve_source = QUDA_PRESERVE_SOURCE_YES;
314 quda_inv_param.use_init_guess = QUDA_USE_INIT_GUESS_NO;
315 quda_inv_param.dirac_order = QUDA_DIRAC_ORDER;
316 quda_inv_param.gamma_basis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;
// Autotuning switch, controlled by invParam.tuneDslashP.
319 if( invParam.tuneDslashP ) {
320 QDPIO::cout <<
"Enabling Dslash Autotuning" << std::endl;
322 quda_inv_param.tune = QUDA_TUNE_YES;
325 QDPIO::cout <<
"Disabling Dslash Autotuning" << std::endl;
327 quda_inv_param.tune = QUDA_TUNE_NO;
// For each direction: half the product of the other three extents. The
// largest of the four values is used as the gauge padding.
332 multi1d<int> face_size(4);
333 face_size[0] = latdims[1]*latdims[2]*latdims[3]/2;
334 face_size[1] = latdims[0]*latdims[2]*latdims[3]/2;
335 face_size[2] = latdims[0]*latdims[1]*latdims[3]/2;
336 face_size[3] = latdims[0]*latdims[1]*latdims[2]/2;
338 int max_face = face_size[0];
339 for(
int i=1;
i <=3;
i++) {
340 if ( face_size[
i] > max_face ) {
341 max_face = face_size[
i];
346 q_gauge_param.ga_pad = max_face;
347 quda_inv_param.sp_pad = 0;
348 quda_inv_param.cl_pad = 0;
// Inner (preconditioner) solver settings, used when invParam.innerParamsP is
// set; ip is a local copy of *invParam.innerParams.
350 if( invParam.innerParamsP ) {
351 QDPIO::cout <<
"Setting inner solver params" << std::endl;
353 GCRInnerSolverParams ip = *(invParam.innerParams);
// Preconditioner precision is written to both parameter structs; the branches
// shown give HALF, SINGLE, DOUBLE and HALF again.
356 switch( ip.precPrecondition ) {
358 quda_inv_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
359 q_gauge_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
363 quda_inv_param.cuda_prec_precondition = QUDA_SINGLE_PRECISION;
364 q_gauge_param.cuda_prec_precondition = QUDA_SINGLE_PRECISION;
368 quda_inv_param.cuda_prec_precondition = QUDA_DOUBLE_PRECISION;
369 q_gauge_param.cuda_prec_precondition = QUDA_DOUBLE_PRECISION;
372 quda_inv_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
373 q_gauge_param.cuda_prec_precondition = QUDA_HALF_PRECISION;
// Gauge reconstruction for the preconditioner.
377 switch( ip.reconstructPrecondition ) {
379 q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_NO;
382 q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_8;
385 q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_12;
388 q_gauge_param.reconstruct_precondition = QUDA_RECONSTRUCT_12;
// Inner tolerance, iteration cap and Krylov-space size, copied from ip.
392 quda_inv_param.tol_precondition = toDouble(ip.tolPrecondition);
393 quda_inv_param.maxiter_precondition = ip.maxIterPrecondition;
394 quda_inv_param.gcrNkrylov = ip.gcrNkrylov;
// Schwarz type: additive, multiplicative, and additive again in the last
// branch shown.
395 switch( ip.schwarzType ) {
397 quda_inv_param.schwarz_type = QUDA_ADDITIVE_SCHWARZ;
400 quda_inv_param.schwarz_type = QUDA_MULTIPLICATIVE_SCHWARZ;
403 quda_inv_param.schwarz_type = QUDA_ADDITIVE_SCHWARZ;
406 quda_inv_param.precondition_cycle = ip.preconditionCycle;
// Inner-solver verbosity follows ip.verboseInner.
408 if( ip.verboseInner ) {
409 quda_inv_param.verbosity_precondition = QUDA_VERBOSE;
412 quda_inv_param.verbosity_precondition = QUDA_SILENT;
// Inner inverter type: CG, BiCGStab, or MR (MR appears in two branches).
415 switch( ip.invTypePrecondition ) {
417 quda_inv_param.inv_type_precondition = QUDA_CG_INVERTER;
420 quda_inv_param.inv_type_precondition = QUDA_BICGSTAB_INVERTER;
424 quda_inv_param.inv_type_precondition= QUDA_MR_INVERTER;
428 quda_inv_param.inv_type_precondition = QUDA_MR_INVERTER;
// Fixed placeholder values for the preconditioner fields (presumably the
// branch taken when invParam.innerParamsP is not set).
433 QDPIO::cout <<
"Setting Precondition stuff to defaults for not using" << std::endl;
434 quda_inv_param.inv_type_precondition= QUDA_INVALID_INVERTER;
435 quda_inv_param.tol_precondition = 1.0e-1;
436 quda_inv_param.maxiter_precondition = 1000;
437 quda_inv_param.verbosity_precondition = QUDA_SILENT;
438 quda_inv_param.gcrNkrylov = 1;
// Outer-solver verbosity: VERBOSE when invParam.verboseP is set, SUMMARIZE in
// the other assignment shown.
442 if( invParam.verboseP ) {
443 quda_inv_param.verbosity = QUDA_VERBOSE;
446 quda_inv_param.verbosity = QUDA_SUMMARIZE;
// For each direction, take the address of the real part of the (0,0) colour
// element at the first site of links_single[mu], then pass the pointer array
// and the gauge parameters to QUDA.
453 gauge[
mu] = (
void *)&(links_single[
mu].elem(all.start()).elem().elem(0,0).real());
456 loadGaugeQuda((
void *)gauge, &q_gauge_param);
// Destructor: logs "Destructing". Any clean-up done in lines omitted from
// this listing is not visible here.
463 ~MdagMSysSolverQUDAWilson()
465 QDPIO::cout <<
"Destructing" << std::endl;
// Returns the subset reported by the wrapped linear operator A.
470 const Subset& subset()
const {
return A->subset();}
// Body of a solve routine whose signature is not part of this listing. It
// uses psi and chi, calls qudaInvert(), and fills the result struct res with
// the measured residual.
480 SystemSolverResults_t res;
// With invParam.axialGaugeP set, chi and psi are multiplied by GFixMat on
// rb[1] before the solve, and psi is recovered afterwards with adj(GFixMat);
// the second qudaInvert call is presumably the branch that solves on chi
// directly.
491 if ( invParam.axialGaugeP ) {
495 QDPIO::cout <<
"Gauge Fixing source and initial guess" << std::endl;
496 g_chi[ rb[1] ] = GFixMat *
chi;
497 g_psi[ rb[1] ] = GFixMat *
psi;
498 QDPIO::cout <<
"Solving" << std::endl;
499 res = qudaInvert(g_chi,
501 QDPIO::cout <<
"Untransforming solution." << std::endl;
502 psi[ rb[1]] = adj(GFixMat)*g_psi;
506 QDPIO::cout <<
"Calling QUDA Invert" << std::endl;
507 res = qudaInvert(
chi,
// Residual norm over A->subset(); r is built in lines not shown here.
521 res.resid = sqrt(norm2(
r,
A->subset()));
// Residual relative to the norm of the source chi.
// NOTE(review): nothing visible guards against a zero norm of chi - confirm.
524 Double rel_resid = res.resid/sqrt(norm2(
chi,
A->subset()));
// NOTE(review): the label says "_CLOVER_SOLVER" although this class is the
// Wilson solver - confirm whether log consumers rely on it before renaming.
526 QDPIO::cout <<
"QUDA_"<< solver_string <<
"_CLOVER_SOLVER: " << res.n_count <<
" iterations. Rsd = " << res.resid <<
" Relative Rsd = " << rel_resid << std::endl;
// Unless invParam.SilentFailP is set, report an error when the relative
// residual exceeds RsdToleranceFactor * RsdTarget (what follows the message
// is not shown in this listing).
529 if ( ! invParam.SilentFailP ) {
530 if ( toBool( rel_resid > invParam.RsdToleranceFactor*invParam.RsdTarget) ) {
531 QDPIO::cerr <<
"ERROR: QUDA Solver residuum is outside tolerance: QUDA resid="<< rel_resid <<
" Desired =" << invParam.RsdTarget <<
" Max Tolerated = " << invParam.RsdToleranceFactor*invParam.RsdTarget << std::endl;
// Body of a second solve routine whose signature is not part of this listing;
// judging by its log message it is the variant that uses a prediction.
543 SystemSolverResults_t res;
// M^dag*M operator built from the stored operator A.
549 Handle< LinearOperator<T> >
MdagM(
new MdagMLinOp<T>(
A) );
// Elapsed time read from swatch (declared and started in lines not shown).
555 double time = swatch.getTimeInSeconds();
// NOTE(review): same "_CLOVER_SOLVER" label as the other solve routine,
// although this is the Wilson solver - confirm before renaming.
556 QDPIO::cout <<
"QUDA_"<< solver_string <<
"_CLOVER_SOLVER: Total time (with prediction)=" << time << std::endl;
// Default constructor with an empty body (presumably declared in a non-public
// section so that it cannot be used from outside - access specifier not shown).
564 MdagMSysSolverQUDAWilson() {}
// Host precision, device precision and sloppy device precision chosen in the
// constructor.
571 QudaPrecision_s cpu_prec;
572 QudaPrecision_s gpu_prec;
573 QudaPrecision_s gpu_half_prec;
// Linear operator passed to the constructor.
575 Handle< LinearOperator<T> >
A;
// Copy of the solver parameters passed to the constructor.
576 const SysSolverQUDAWilsonParams invParam;
// QUDA gauge and inverter parameter structs filled in by the constructor.
577 QudaGaugeParam q_gauge_param;
578 QudaInvertParam quda_inv_param;
580 SystemSolverResults_t qudaInvert(
const T& chi_s,
Abstract interface for a Chronological Solution predictor.
virtual void newVector(const T &psi)=0
void temporalGauge(multi1d< LatticeColorMatrix > &ug, LatticeColorMatrix &g, int decay_dir)
Temporal gauge fixing.
Class for counted reference semantics.
M^dag*M composition of a linear operator.
bool registerAll()
Register all the factories.
multi1d< Hadron2PtContraction_t > operator()(const multi1d< LatticeColorMatrix > &u)
Asqtad Staggered-Dirac operator.
LinOpSysSolverMGProtoClover::Q Q
LinOpSysSolverMGProtoClover::T T
multi1d< LatticeFermion > chi(Ncb)
multi1d< Real > makeFermCoeffs(const AnisoParam_t &aniso)
Make fermion coefficients.
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Periodic ferm state and a creator.
Reunitarize in place a color matrix to SU(N)
Support class for fermion actions and linear operators.
Handle< LinearOperator< T > > MdagM
multi1d< LatticeColorMatrix > U
multi1d< LatticeColorMatrix > Q
multi1d< LatticeColorMatrixF > QF
multi1d< LatticeColorMatrixD > QD