CHROMA
syssolver_linop_mdwf_array.cc
Go to the documentation of this file.
1 // -*- C++ -*-
2 /*! \file
3  * \brief DWF/SSE double-prec solver
4  */
7 
9 
10 extern "C" {
11 #include <qop-mdwf3.h>
12 };
13 
14 #include "io/aniso_io.h"
16 
17 
18 using namespace QDP;
19 namespace Chroma
20 {
21 
22  namespace LinOpSysSolverMDWFArrayEnv
23  {
24  //! Callback function
26  const std::string& path,
28  LatticeFermion,
29  multi1d<LatticeColorMatrix>,
30  multi1d<LatticeColorMatrix>
31  >
32  > state,
34  {
35  return new LinOpSysSolverMDWFArray(A, state, SysSolverMDWFParams(xml_in, path));
36  }
37 
38  //! Name to be used
39  const std::string name("MDWF_INVERTER");
40 
41  //! Local registration flag
42  static bool registered = false;
43 
44  //! Register all the factories
45  bool registerAll()
46  {
47  bool success = true;
48  if (! registered)
49  {
51  registered = true;
52  }
53  return success;
54  }
55  }
56 
57 
58  //! AVP's DWF Solver interface
59  /*!
60  * \ingroup qprop
61  *
62  * @{
63  */
64 
65  /* Utility functions -- stick these in local anonymous namespace */
66 
67  namespace {
68 
69  double gaugeReader(int mu,
70  const int latt_coord[4],
71  int row,
72  int col,
73  int reim,
74  void *env)
75  {
76  START_CODE();
77  /* Translate arg */
78  multi1d<LatticeColorMatrix>& u = *(multi1d<LatticeColorMatrix>*)env;
79 
80  // Get node and index
81  multi1d<int> coord(Nd);
82  coord = latt_coord;
83  int node = Layout::nodeNumber(coord);
84  int linear = Layout::linearSiteIndex(coord);
85 
86  if (node != Layout::nodeNumber()) {
87 
88  QDPIO::cerr << __func__ << ": wrong coordinates for this node" << std::endl;
89  QDP_abort(1);
90  }
91 
92  // Get the value
93  // NOTE: it would be nice to use the "peek" functions, but they will
94  // broadcast to all nodes the value since they are platform independent.
95  // We don't want that, so we poke into the on-node data
96  double val = (reim == 0) ?
97  toDouble(u[mu].elem(linear).elem().elem(row,col).real()) :
98  toDouble(u[mu].elem(linear).elem().elem(row,col).imag());
99 
100 
101  END_CODE();
102  return val;
103 
104  }
105 
106 
107  // Fermion Reader function - user supplied
108  double fermionReader(const int latt_coord[5],
109  int color,
110  int spin,
111  int reim,
112  void *env)
113  {
114  START_CODE();
115 
116  /* Translate arg */
117  multi1d<LatticeFermion>& psi = *(multi1d<LatticeFermion>*)env;
118 
119  // Get node and index
120  int s = latt_coord[Nd];
121  multi1d<int> coord(Nd);
122  coord = latt_coord;
123  int node = Layout::nodeNumber(coord);
124  int linear = Layout::linearSiteIndex(coord);
125 
126  if (node != Layout::nodeNumber()) {
127 
128  QDPIO::cerr << __func__ << ": wrong coordinates for this node" << std::endl;
129  QDP_abort(1);
130  }
131 
132  // Get the value
133  // NOTE: it would be nice to use the "peek" functions, but they will
134  // broadcast to all nodes the value since they are platform independent.
135  // We don't want that, so we poke into the on-node data
136  double val = (reim == 0) ?
137  double(psi[s].elem(linear).elem(spin).elem(color).real()) :
138  double(psi[s].elem(linear).elem(spin).elem(color).imag());
139 
140  // if (spin >= Ns/2)
141  // val *= -1;
142 
143  END_CODE();
144  return val;
145  }
146 
147 
148 
149 
150  // Fermion Writer function - user supplied
151  void fermionWriter(const int latt_coord[5],
152  int color,
153  int spin,
154  int reim,
155  double val,
156  void *env )
157 
158  {
159  START_CODE();
160 
161  /* Translate arg */
162  multi1d<LatticeFermion>& psi = *(multi1d<LatticeFermion>*)env;
163 
164  // Get node and index
165  int s = latt_coord[Nd];
166  multi1d<int> coord(Nd);
167  coord = latt_coord;
168  int node = Layout::nodeNumber(coord);
169  int linear = Layout::linearSiteIndex(coord);
170 
171  if (node != Layout::nodeNumber()) {
172  QDPIO::cerr << __func__ << ": wrong coordinates for this node" << std::endl;
173  QDP_abort(1);
174  }
175 
176  // Rescale
177  // if (spin >= Ns/2)
178  // val *= -1;
179 
180  // val *= -2.0;
181 
182  // Set the value
183  // NOTE: it would be nice to use the "peek" functions, but they will
184  // broadcast to all nodes the value since they are platform independent.
185  // We don't want that, so we poke into the on-node data
186  if (reim == 0)
187  psi[s].elem(linear).elem(spin).elem(color).real() = val;
188  else
189  psi[s].elem(linear).elem(spin).elem(color).imag() = val;
190 
191  END_CODE();
192  return;
193  }
194 
195  // Env is for the interface spec. I ignore it completely
196  void sublattice_func(int lo[],
197  int hi[],
198  const int node[],
199  void *env)
200  {
201  START_CODE();
202  // Given my node coordinates in node[]
203  // produce the lo/hi pair
204 
205  // This is the size on the local subgrid.
206  // For QDP++ they are all the same.
207  const multi1d<int>& local_subgrid=Layout::subgridLattSize();
208  for(int i=0; i <Nd; i++) {
209  // The lowest coordinate is just the node coordinate
210  // times the local subgrid size in that direction
211  lo[i]=node[i]*local_subgrid[i];
212 
213  // The high is the start of the next 'corner'
214  // I would say my hi is hi[i]-1 but am following the
215  // document conventions
216  hi[i]=(node[i]+1)*local_subgrid[i];
217  }
218 
219  END_CODE();
220  return;
221  }
222  } // End of anonymous namespace
223 
224 //! Solve the linear system
225 /*!
226 * \param psi quark propagator ( Modify )
227 * \param chi source ( Read )
228 * \return solver results: total iteration count (n_count) and final residual norm (resid)
229 */
231  const multi1d<LatticeFermion>& chi) const
232  {
233  START_CODE();
234 
236  res.n_count = 0;
237  int out_iters_single=0;
238  int out_iters_double=0;
239 
240  /* Stuff for flopcounting */
241  double time_sec;
242  long long flops;
243  long long sent; /* Messages? Bytes? Faces? */
244  long long received; /* Messages? Bytes? Receives? */
245 
246 
247  /* Single Precision Branch. */
248  {
249  double out_eps_single;
250 
251  // OK Try to solve for the single precision
252  // Cast to single precision gauge field
253  QOP_F3_MDWF_Gauge *sprec_gauge = NULL;
254  if ( QOP_F3_MDWF_import_gauge(&sprec_gauge,
255  state,
256  gaugeReader,
257  (void *)&u) != 0 ) {
258  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
259  QDP_abort(1);
260  }
261 
262  /* Fermion fields */
263  QOP_F3_MDWF_Fermion *sprec_rhs;
264  QOP_F3_MDWF_Fermion *sprec_x0;
265  QOP_F3_MDWF_Fermion *sprec_soln;
266 
267  if( QOP_F3_MDWF_import_fermion(&sprec_rhs,
268  state,
269  fermionReader,
270  (void *)&chi) != 0 ) {
271  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
272  QDP_abort(1);
273  }
274 
275  if( QOP_F3_MDWF_import_fermion(&sprec_x0,
276  state,
277  fermionReader,
278  (void *)&psi) != 0 ) {
279  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
280  QDP_abort(1);
281  }
282 
283  if( QOP_F3_MDWF_allocate_fermion(&sprec_soln,
284  state) != 0 ) {
285  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
286  QDP_abort(1);
287  }
288 
289  /* Do the solve */
290 
291  double target_epsilon = toDouble(invParam.RsdTarget*invParam.RsdTarget);
292  int max_iteration = invParam.MaxIter;
293  int status;
294 
295  QDPIO::cout << "LinOpSysSolverMDWFArray: Beginning Single Precision Solve" << std::endl;
296  if( ( status=QOP_F3_MDWF_DDW_CG(sprec_soln,
297  &out_iters_single,
298  &out_eps_single,
299  params,
300  sprec_x0,
301  sprec_gauge,
302  sprec_rhs,
303  max_iteration,
304  target_epsilon,
305  QOP_MDWF_LOG_NONE)) != 0 ) {
306  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
307  QDP_abort(1);
308  }
309 
310 /* Get Perf counters after solve */
311  if( QOP_MDWF_performance(&time_sec,
312  &flops,
313  &sent,
314  &received,
315  state) != 0 ) {
316  QDPIO::cerr << "MDWF_Error: "<< QOP_MDWF_error(state) << std::endl;
317  QDP_abort(1);
318  }
319 
320  /* Report status */
321  QDPIO::cout << "LinOpSysSolverMDWFArray Single Prec : status=" << status
322  << " iterations=" << out_iters_single
323  << " resulting epsilon=" << sqrt(out_eps_single) << std::endl;
324 
325  /* Report Flops */
326  FlopCounter flopcount_single;
327  flopcount_single.reset();
328  flopcount_single.addFlops(flops);
329  flopcount_single.report("LinOpSysSolverMDWFArray_Single_Prec:", time_sec);
330 
331  /* Export the solution */
332  if( QOP_F3_MDWF_export_fermion(fermionWriter,
333  &psi,
334  sprec_soln) != 0 ) {
335  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
336  QDP_abort(1);
337  }
338 
339  /* Now I can free the fermions */
340  QOP_F3_MDWF_free_fermion(&sprec_soln);
341  QOP_F3_MDWF_free_fermion(&sprec_x0);
342  QOP_F3_MDWF_free_fermion(&sprec_rhs);
343  QOP_F3_MDWF_free_gauge(&sprec_gauge);
344 
345  res.n_count = out_iters_single;
346 
347  }
348 
349  /* DoublePrecision Branch. */
350  {
351  double out_eps_double;
352 
353 // Now solve in double precision
354 // Import the gauge field in double precision
355  QOP_D3_MDWF_Gauge *dprec_gauge = NULL;
356  if ( QOP_D3_MDWF_import_gauge(&dprec_gauge,
357  state,
358  gaugeReader,
359  (void *)&u) != 0 ) {
360  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
361  QDP_abort(1);
362  }
363 
364  /* Fermion fields */
365  QOP_D3_MDWF_Fermion *dprec_rhs; // Right hand side
366  QOP_D3_MDWF_Fermion *dprec_x0; // Guess
367  QOP_D3_MDWF_Fermion *dprec_soln; // result
368 
369  if( QOP_D3_MDWF_import_fermion(&dprec_rhs,
370  state,
371  fermionReader,
372  (void *)&chi) != 0 ) {
373  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
374  QDP_abort(1);
375  }
376 
377  if( QOP_D3_MDWF_import_fermion(&dprec_x0,
378  state,
379  fermionReader,
380  (void *)&psi) != 0 ) {
381  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
382  QDP_abort(1);
383  }
384 
385  if( QOP_D3_MDWF_allocate_fermion(&dprec_soln,
386  state) != 0 ) {
387  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
388  QDP_abort(1);
389  }
390 
391  /* Do the solve */
392  double target_epsilon = toDouble(invParam.RsdTargetRestart*invParam.RsdTargetRestart);
393  int max_iteration = invParam.MaxIterRestart;
394  int status;
395 
396 
397  QDPIO::cout << "LinOpSysSolverMDWFArray: Beginning Double Precision Solve" << std::endl;
398  if( ( status=QOP_D3_MDWF_DDW_CG(dprec_soln,
399  &out_iters_double,
400  &out_eps_double,
401  params,
402  dprec_x0,
403  dprec_gauge,
404  dprec_rhs,
405  max_iteration,
406  target_epsilon,
407  QOP_MDWF_LOG_NONE)) != 0 ) {
408  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
409  QDP_abort(1);
410  }
411 
412  /* Get Perf counters after solve */
413  if( QOP_MDWF_performance(&time_sec,
414  &flops,
415  &sent,
416  &received,
417  state) != 0 ) {
418  QDPIO::cerr << "MDWF_Error: "<< QOP_MDWF_error(state) << std::endl;
419  QDP_abort(1);
420  }
421 /* Report status */
422  QDPIO::cout << "LinOpSysSolverMDWFArray Double Prec: status=" << status
423  << " iterations=" << out_iters_double
424  << " resulting epsilon=" << sqrt(out_eps_double) << std::endl;
425 
426  /* Report Flops */
427  FlopCounter flopcount_double;
428  flopcount_double.reset();
429  flopcount_double.addFlops(flops);
430  flopcount_double.report("LinOpSysSolverMDWFArray_Double_Prec:", time_sec);
431 
432  /* Export the solution */
433  if( QOP_D3_MDWF_export_fermion(fermionWriter,
434  &psi,
435  dprec_soln) != 0 ) {
436  QDPIO::cerr << "MDWF Error: "<< QOP_MDWF_error(state) << std::endl;
437  QDP_abort(1);
438  }
439 
440  /* Now I can free the fermions */
441  QOP_D3_MDWF_free_fermion(&dprec_soln);
442  QOP_D3_MDWF_free_fermion(&dprec_x0);
443  QOP_D3_MDWF_free_fermion(&dprec_rhs);
444  QOP_D3_MDWF_free_gauge(&dprec_gauge);
445 
446  /* Add the double prec iteration count onto the global count */
447  res.n_count += out_iters_double;
448  }
449 
450  // Compute actual residual
451  {
452  multi1d<LatticeFermion> r(invParam.N5);
453  (*A)(r, psi, PLUS);
454  r -= chi;
455  res.resid = sqrt(norm2(r));
456  }
457  QDPIO::cout << "MDWF Final: single_iters=" << out_iters_single << " double_iters=" << out_iters_double << " total_iters=" << res.n_count << std::endl;
458  QDPIO::cout << "MDWF Final: final absolute unprec residuum="<<res.resid<<std::endl;
459 
460  END_CODE();
461  return res;
462  }
463 
464 
465  /* INIT Function */
467  {
468  START_CODE();
469 
470  if( Nd != 4 ) {
471  QDPIO::cout << "This will only work for Nd=4" << std::endl;
472  QDP_abort(1);
473  }
474 
475  if( Nc != 3 ) {
476  QDPIO::cout << "This will only work for Nc=3" << std::endl;
477  QDP_abort(1);
478  }
479 
480 
481 
482  // I need to call Andrews init function.
483  // For this I need a function that can tell me the
484  multi1d<int> lattice(5);
485  multi1d<int> network(4);
486  multi1d<int> node_coords(4);
487  int master_p;
488 
489  // Lattice is the 5D lattice size
490  for(int mu=0; mu < Nd; mu++) {
491  lattice[mu] = Layout::lattSize()[mu];
492  }
493  lattice[Nd]=invParam.N5;
494 
495  for(int mu=0; mu < Nd; mu++) {
496  network[mu] = Layout::logicalSize()[mu];
497  node_coords[mu] = Layout::nodeCoord()[mu];
498  }
499 
500  // Master_p has to be zero on master node and nonzero
501  // elsewhere. This is odd.
502  if( Layout::primaryNode() ) {
503  master_p = 0;
504  }
505  else {
506  master_p = 1;
507  }
508 
509  // Announce a version just to be nice
510  QDPIO::cout << "LinOpSysSolverMDWFArray: Initializing MDWF Library Version " << QOP_MDWF_version() << std::endl;
511 
512  // OK. Let's call Andrew's routine
513  if( QOP_MDWF_init(&state, lattice.slice(), network.slice(),
514  node_coords.slice(), master_p, sublattice_func,
515  NULL) != 0 ) {
516  // Nonzero return value => error
517  QDPIO::cerr << "MDWF Error: " << QOP_MDWF_error(state) << std::endl;
518  QDP_abort(1);
519  }
520 
521  // Set up the masses etc...
522  u.resize(Nd);
523  u = fermstate->getLinks();
524  Real ff = Real(1);
525 
526  if (invParam.anisoParam.anisoP) {
527  ff = where(invParam.anisoParam.anisoP, invParam.anisoParam.nu / invParam.anisoParam.xi_0, Real(1));
528  for(int mu=0; mu < u.size(); ++mu) {
529  if (mu != invParam.anisoParam.t_dir)
530  u[mu] *= ff;
531  }
532  }
533 
534  // Set the Shamir parameters for now
535  {
536  double a5 = (double)1;
537 
538  // Convention change. Now in the internal code it is 5 - OverMass
539  // To allow using anisotropy, I subtract off the 5 and 'add it back on'
540  // suitable for the anisotropic case
541  double M5 = (double)(-5) + toDouble((double)1 + a5*((double)1 + (double)(Nd-1)*ff - invParam.OverMass));
542 
543  double m_f = toDouble(invParam.Mass);
544 
545 
546  if( QOP_MDWF_set_generic(&params,
547  state,
548  b5_in.slice(),
549  c5_in.slice(),
550  M5 ,m_f) != 0){
551  QDPIO::cerr << "MDWF Error: " << QOP_MDWF_error(state)<< std::endl;
552  QDP_abort(1);
553  }
554 
555 
556  }
557 
558  END_CODE();
559  return;
560  }
561 
562  // Finalize - destructor call
563  void LinOpSysSolverMDWFArray::fini(void)
564  {
565  START_CODE();
566  QDPIO::cout << "MDWFQpropT: Finalizing MDWF Library Version " << QOP_MDWF_version() << std::endl;
567 
568  if (params != NULL) {
569  QOP_MDWF_free_parameters(&params);
570  }
571 
572  if (state != NULL) {
573  QOP_MDWF_fini(&state);
574  }
575 
576  END_CODE();
577  return;
578  }
579 
580 
581  /*! @} */ // end of group qprop
582 }
583 
Anisotropy parameters.
#define END_CODE()
Definition: chromabase.h:65
#define START_CODE()
Definition: chromabase.h:64
Support class for fermion actions and linear operators.
Definition: state.h:94
Class for counted reference semantics.
Definition: handle.h:33
Linear Operator to arrays.
Definition: linearop.h:61
static T & Instance()
Definition: singleton.h:432
int mu
Definition: cool.cc:24
Params params
unsigned s
Definition: ldumul_w.cc:37
unsigned i
Definition: ldumul_w.cc:34
multi1d< int > coord(Nd)
Nd
Definition: meslate.cc:74
double gaugeReader(const void *OuterGauge, void *env, const int latt_coord[4], int mu, int row, int col, int reim)
static bool registered
Local registration flag.
const std::string name("MDWF_INVERTER")
Name to be used.
LinOpSystemSolverArray< LatticeFermion > * createFerm(XMLReader &xml_in, const std::string &path, Handle< FermState< LatticeFermion, multi1d< LatticeColorMatrix >, multi1d< LatticeColorMatrix > > > state, Handle< LinearOperatorArray< LatticeFermion > > A)
Callback function.
bool registerAll()
Register all the factories.
void init(MesonSpecData_t &data, XMLWriter &xml, const std::string &path, const std::string &id_tag, const Params &params)
Do some initialization.
multi1d< Hadron2PtContraction_t > operator()(const multi1d< LatticeColorMatrix > &u)
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
static multi1d< LatticeColorMatrix > u
@ PLUS
Definition: chromabase.h:45
A(A, psi, r, Ncb, PLUS)
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > state
Definition: pbg5p_w.cc:28
::std::string string
Definition: gtest.h:1979
multi1d< LatticeFermion > r(Ncb)
chi
Definition: pade_trln_w.cc:24
psi
Definition: pade_trln_w.cc:191
Holds return info from SystemSolver call.
Definition: syssolver.h:17
Register linop system solvers that solve M*psi=chi.
Factory for solving M*psi=chi where M is not hermitian or pos. def.
DWF/SSE double-prec solver.