CHROMA
multi_syssolver_mdagm_cg_clover_quda_w.cc
Go to the documentation of this file.
1 /*! \file
2  * \brief Solve a set of shifted MdagM*psi_i=chi linear systems with QUDA's multi-shift CG (clover)
3  */
4 
#include "quda.h"

#include <cstdlib>
#include <vector>
12 
13 namespace Chroma
14 {
15 
16  //! Registration namespace for the QUDA clover multi-shift CG MdagM solver
17  namespace MdagMMultiSysSolverCGQudaCloverEnv
18  {
19  //! Factory callback: construct the QUDA clover multi-shift CG solver.
    //! Parses MultiSysSolverQUDACloverParams from (xml_in, path) and returns a
    //! heap-allocated MdagMMultiSysSolverCGQudaClover built from A and state;
    //! the caller presumably takes ownership of the returned pointer (not visible here).
    //! NOTE(review): the first line of the signature (source line 20) and the
    //! trailing `A` parameter (source line 23) are missing from this rendering;
    //! the full createFerm signature appears in the index at the end of this page.
21  const std::string& path,
22  Handle< FermState< LatticeFermion, multi1d<LatticeColorMatrix>, multi1d<LatticeColorMatrix> > > state,
24  {
25  return new MdagMMultiSysSolverCGQudaClover(A, state,MultiSysSolverQUDACloverParams(xml_in, path));
26  }
27 
28  //! Name to be used as this solver's factory key
    //! (presumably the inverter type string selected from XML — confirm against the factory).
29  const std::string name("MULTI_CG_QUDA_CLOVER_INVERTER");
30 
31  //! Local registration flag: set by registerAll() so later calls skip registration
32  static bool registered = false;
33 
34  //! Register this solver with its factory.
    //! Only the first call does any work; subsequent calls return true immediately.
    //! NOTE(review): the registration statement itself (source line 40) is missing
    //! from this rendering — presumably it updates `success` by registering
    //! createFerm under `name`; confirm against the real file.
35  bool registerAll()
36  {
37  bool success = true;
38  if (! registered)
39  {
41  registered = true;
42  }
43  return success;
44  }
45  }
46 
47  SystemSolverResults_t
48  MdagMMultiSysSolverCGQudaClover::qudaInvertMulti(const T& chi_s,
49  multi1d<T>& psi_s,
50  const multi1d<Real> shifts) const{
51 
52  SystemSolverResults_t ret;
53 
54  void *spinorIn;
55 
56 #ifndef BUILD_QUDA_DEVIFACE_SPINOR
57  spinorIn =(void *)&(chi_s.elem(rb[1].start()).elem(0).elem(0).real());
58 #else
59  // have to do this later
60 #endif
61 
62  void** spinorOut = new void*[ shifts.size() ];
63  if (spinorOut == nullptr ) {
64  QDPIO::cerr << "Couldn't allocate spinorOut" << std::endl;
65  QDP_abort(1);
66  }
67 
68  if ( shifts.size() > QUDA_MAX_MULTI_SHIFT ) {
69  QDPIO::cerr << "You want more shifts than QUDA_MAX_MULTI_SHIFT" << std::endl;
70  QDPIO::cerr << "Requested : " << shifts.size() << " QUDA_MAX_MULTI_SHIFT=" << QUDA_MAX_MULTI_SHIFT << std::endl;
71  QDP_abort(1);
72  }
73 
74  psi_s.resize( shifts.size());
75 
76 #ifndef BUILD_QUDA_DEVIFACE_SPINOR
77  for(int s=0; s < shifts.size(); s++) {
78  //psi_s[s][ rb[1] ] = zero;
79  psi_s[s] = zero; // Sanity check
80  spinorOut[s] = (void *)&(psi_s[s].elem(rb[1].start()).elem(0).elem(0).real());
81  quda_inv_param.offset[s] = toDouble(shifts[s]);
82  }
83 #else
84  std::vector<QDPCache::ArgKey> ids = {chi_s.getId()};
85  for(int s=0; s < shifts.size(); s++) {
86  // psi_s[s][ rb[1] ] = zero;
87  psi_s[s] = zero; // Sanity Check
88  ids.push_back( psi_s[s].getId() );
89  quda_inv_param.offset[s] = toDouble(shifts[s]);
90  }
91  auto dev_ptr = GetMemoryPtr( ids );
92  spinorIn = dev_ptr[0];
93  for(int s=0; s < shifts.size(); s++) {
94  spinorOut[s] = dev_ptr[s+1];
95  }
96 #endif
97 
98  quda_inv_param.num_offset = shifts.size();
99 
100  if( invParam.RsdTarget.size() == 1 ) {
101  for (int i=0; i< quda_inv_param.num_offset; i++) quda_inv_param.tol_offset[i] = toDouble(invParam.RsdTarget[0]);
102  }
103  else {
104  for (int i=0; i< quda_inv_param.num_offset; i++) quda_inv_param.tol_offset[i] = toDouble(invParam.RsdTarget[i]);
105  }
106 
107  // Do the solve here
108  StopWatch swatch1;
109  swatch1.reset();
110  swatch1.start();
111  QDPIO::cout << "CALLING QUDA SOLVER" << std::endl << std::flush ;
112  invertMultiShiftQuda(spinorOut, spinorIn, (QudaInvertParam*)&quda_inv_param);
113  swatch1.stop();
114 
115  // Tidy Up
116  delete [] spinorOut;
117 
118 
119  QDPIO::cout << "QUDA_"<<solver_string<<"_CLOVER_SOLVER: time="<< quda_inv_param.secs <<" s" ;
120  QDPIO::cout << "\tPerformance="<< quda_inv_param.gflops/quda_inv_param.secs<<" GFLOPS" ;
121  QDPIO::cout << "\tTotal Time (incl. load gauge)=" << swatch1.getTimeInSeconds() <<" s"<<std::endl;
122 
123  ret.n_count =quda_inv_param.iter;
124 
125  return ret;
126 
127  }
128 
129 
130 }
131 
Support class for fermion actions and linear operators.
Definition: state.h:94
Class for counted reference semantics.
Definition: handle.h:33
static T & Instance()
Definition: singleton.h:432
std::string getId()
Get the default gauge field named object id.
Register MdagM system solvers.
Solve a MdagM*psi=chi linear system by CG2 using CG.
Factory for producing system solvers for MdagM*psi = chi.
MdagMMultiSystemSolver< LatticeFermion > * createFerm(XMLReader &xml_in, const std::string &path, Handle< FermState< LatticeFermion, multi1d< LatticeColorMatrix >, multi1d< LatticeColorMatrix > > > state, Handle< LinearOperator< LatticeFermion > > A)
Callback function.
const std::string name("MULTI_CG_QUDA_CLOVER_INVERTER")
Name to be used.
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
int i
Definition: pbg5p_w.cc:55
A(A, psi, r, Ncb, PLUS)
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > state
Definition: pbg5p_w.cc:28
Double zero
Definition: invbicg.cc:106
multi1d< LatticeFermion > s(Ncb)
::std::string string
Definition: gtest.h:1979
LatticeFermion T
Definition: t_clover.cc:11