CHROMA
mre_shifted_predictor.h
Go to the documentation of this file.
1 // -*- C++ -*-
2 /*! \file
3  * \brief Minimal residual predictor
4  *
5  * Predictors for HMC
6  */
7 
8 #ifndef __mre_shifted_predictor_h__
9 #define __mre_shifted_predictor_h__
10 
11 #include "chromabase.h"
12 #include "handle.h"
17 namespace Chroma
18 {
19 
//! Registration hooks for the shifted minimal-residual chronological predictor
/*! @ingroup predictor */
namespace MinimalResidualExtrapolationShifted4DChronoPredictorEnv
{
  //! Name string for this predictor (defined in the implementation file)
  extern const std::string name;

  //! Registration entry point for this predictor; returns a bool status
  bool registerAll();
}
26 
27  //! Minimal residual predictor
28  /*! @ingroup predictor */
29  template<typename T, typename R>
31  {
32  private:
36 
37 
38  void
40  T& psi,
41  const T& chi,
42  const R& shift,
43  enum PlusMinus isign)
44  {
45  START_CODE();
46 
47  const Subset& s= M.subset();
48 
49 
50 #if 1
51  T rvec;
52  {
53 
54  rvec[s] = chi;
55  T tmp;
56  M(tmp, psi, isign);
57  tmp[s] += shift*psi;
58  rvec[s] -= tmp;
59  // Double norm_r = sqrt(norm2(rvec,s));
60  // Double norm_chi = sqrt(norm2(chi,s));
61  // QDPIO::cout << "MRE Predictor: before prediction || r || / || b || =" << norm_r/norm_chi << std::endl;
62  }
63 #endif
64 
65  int Nvec = chrono_buf->size();
66 
67 
68  // Construct an orthonormal basis from the
69  // vectors in the buffer. Stick to notation of paper and call these
70  // v
71  // multi1d<T> v(Nvec);
72 
73 
74  // Now I need to form G_n m = v_[n]^{dag} A v[m]
75  multi2d<DComplex> G(Nvec,Nvec);
76  multi1d<DComplex> b(Nvec);
77 
78  for(int m = 0 ; m < Nvec; m++) {
79  T v_m;
80  T Mv_m;
81  chrono_buf->get(m, v_m);
82  chrono_bufM->get(m, Mv_m);
83  // M(Mv_m, v_m, isign);
84  b[m] = innerProduct(v_m, rvec, s);
85  for(int n = 0; n < Nvec; n++) {
86  T v_n;
87  chrono_buf->get(n, v_n);
88  G(n,m) = innerProduct(v_n, Mv_m, s);
89  DComplex dcshift(shift);
90  G(n,m) += dcshift*innerProduct(v_n, v_m,s);
91  }
92  }
93 
94 
95  // Solve G_nm a_m = b_n:
96 
97  // First LU decompose G in place
98  multi1d<DComplex> a(Nvec);
99 
100  LUSolve(a, G, b);
101 
102 #if 0
103  // Check solution
104  multi1d<DComplex> Ga(Nvec);
105 
106  for(int i=0; i < Nvec; i++) {
107  Ga[i] = G(i,0)*a[0];
108  for(int j=1; j < Nvec; j++) {
109  Ga[i] += G(i,j)*a[j];
110  }
111  }
112 
113  multi1d<DComplex> r(Nvec);
114  for(int i=0; i < Nvec; i++) {
115  r[i] = b[i]-Ga[i];
116  }
117 
118 
119  QDPIO::cout << "Constraint Eq Solution Check" << std::endl;
120  for(int i=0; i < Nvec; i++) {
121  QDPIO::cout << " r[ " << i << "] = " << r[i] << std::endl;
122  }
123 #endif
124 
125  // Create teh lnear combination
126  {
127  T v;
128  chrono_buf->get(0,v);
129  psi[s] += Complex(a[0])*v;
130  for(int n=1; n < Nvec; n++) {
131  chrono_buf->get(n, v);
132  psi[s] += Complex(a[n])*v;
133  }
134  }
135 
136 #if 0
137  {
138  // T rvec;
139  rvec[s] = chi;
140  T tmp;
141  M(tmp, psi, isign);
142  tmp[s] += shift*psi;
143 
144  rvec[s] -= tmp;
145  Double norm_r = sqrt(norm2(rvec,s));
146  Double norm_chi = sqrt(norm2(chi,s));
147  QDPIO::cout << "MRE Predictor: after prediction || r || / || b || =" << norm_r/norm_chi << std::endl;
148  }
149 #endif
150 
151  END_CODE();
152  }
153 
154 
155 
156  public:
157 
159  chrono_buf(new CircularBuffer<T>(max_chrono)),
160  chrono_bufM(new CircularBuffer<T>(max_chrono)),
161  M(M_) {}
162 
163  // Destructor is automagic
165 
166  void predictX(T& X,
167  const R& shift,
168  const T& chi)
169  {
170  START_CODE();
171  StopWatch swatch;
172  swatch.reset();
173  swatch.start();
174 
175  int Nvec = chrono_buf->size();
176  switch(Nvec) {
177  case 0:
178  {
179  return;
180  }
181  break;
182 #if 0
183  case 1:
184  {
185  QDPIO::cout << "MRE Predictor: Only 1 std::vector stored. Giving you last solution " << std::endl;
186  chrono_buf->get(0,X);
187  }
188  break;
189 #endif
190  default:
191  {
192  QDPIO::cout << "MRE Predictor: Finding X extrapolation with "<< Nvec << " vectors" << std::endl;
193 
194  // Expect M is either MdagM if we use chi
195  // or M if we minimize against Y
196  find_extrap_solution(X, chi, shift, PLUS);
197  }
198  break;
199  }
200 
201  swatch.stop();
202  QDPIO::cout << "MRE_PREDICT_X_TIME = " << swatch.getTimeInSeconds() << " s" << std::endl;
203 
204  END_CODE();
205  }
206 
207 
208 
209  // No internal state so reset is a nop
210  void reset(void) {
211  chrono_buf->reset();
212  }
213 
214 
215  void newXVector(const T& X)
216  {
217  START_CODE();
218  chrono_buf->push(X);
219  T Mv;
220  M(Mv,X,PLUS);
221  chrono_bufM->push(Mv);
222 
223 
224  QDPIO::cout << "MREPredictor: number of X vectors stored is = " << chrono_buf->size() << std::endl;
225 
226  END_CODE();
227  }
228 
229 
230 
231  void replaceXHead(const T& v_)
232  {
233  START_CODE();
234  const Subset& s = M.subset();
235 
236  chrono_buf->replaceHead(v_);
237 
238  T Mv;
239  M(Mv, v_, PLUS);
240  chrono_bufM->replaceHead(Mv);
241 
242 
243  QDPIO::cout << "MREPredictor: number of X vectors stored is = " << chrono_buf->size() << std::endl;
244 
245  END_CODE();
246  }
247 
248 
249  };
250 
251 
252 } // End Namespace Chroma
253 
254 #endif
Primary include file for CHROMA library code.
Chronological predictor for HMC.
Monomial factories.
Circular buffers.
Circular Buffer.
Class for counted reference semantics.
Definition: handle.h:33
Linear Operator.
Definition: linearop.h:27
MinimalResidualExtrapolationShifted4DChronoPredictor(unsigned int max_chrono, const LinearOperator< T > &M_)
void find_extrap_solution(T &psi, const T &chi, const R &shift, enum PlusMinus isign)
void LUSolve(multi1d< DComplex > &a, const multi2d< DComplex > &M, const multi1d< DComplex > &b)
Solve M a = b by LU decomposition with partial pivoting.
Definition: lu_solve.cc:8
Class for counted reference semantics.
unsigned j
Definition: ldumul_w.cc:35
unsigned n
Definition: ldumul_w.cc:36
LU solver.
static int m[4]
Definition: make_seeds.cc:16
BinaryReturn< C1, C2, FnInnerProduct >::Type_t innerProduct(const QDPSubType< T1, C1 > &s1, const QDPType< T2, C2 > &s2)
multi1d< LatticeColorMatrix > G
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
LatticeFermion tmp
Definition: mespbg5p_w.cc:36
LinOpSysSolverMGProtoClover::T T
int i
Definition: pbg5p_w.cc:55
@ PLUS
Definition: chromabase.h:45
multi1d< LatticeFermion > chi(Ncb)
Complex a
Definition: invbicg.cc:95
LatticeFermion psi
Definition: mespbg5p_w.cc:35
START_CODE()
Complex b
Definition: invbicg.cc:96
multi1d< LatticeFermion > s(Ncb)
FloatingPoint< double > Double
Definition: gtest.h:7351
::std::string string
Definition: gtest.h:1979