CHROMA
reunit.cc
Go to the documentation of this file.
1 // -*- C++ -*-
2 
3 /*! \file
4  * \ingroup gauge
5  * \author Subsetting added by A. Hart
6  * \param[in,out] xa The descriptor of matrices to be reunitarized.
7  * Must be of type LatticeColorMatrix
8  * \param[in] bad Descriptor of flags indicating sites violating unitarity.
9  * Only written (on the sites of mstag) if ruflag = REUNITARIZE_LABEL;
10  * left untouched for REUNITARIZE and REUNITARIZE_ERROR.
11  * \param[in] ruflag Can also be REUNITARIZE in which case the
12  * matrices are reunitarized but no complaints are made.
13  * \param[out] numbad Total number of matrices violating unitarity.
14  * ONLY USED IF ruflag is testing for ERROR or LABEL.
15  * \param[in] mstag An (un)ordered subset of sites
16  * \brief Reunitarize in place a color matrix to SU(N)
17  *
18  * Reunitarize the matrix xa in place to SU(N) (U(1) when Nc = 1); ruflag selects whether violations are ignored, fatal, or labelled.
19  */
20 
21 #include "chromabase.h"
22 #include "util/gauge/reunit.h"
23 
24 
25 namespace Chroma {
26 
27 
28  namespace ReunitEnv {
29  static double time_spent =0;
30  double getTime() { return time_spent; }
31 
32  template<typename Q>
33  struct fuzzForType {};
34 
35  template<>
36  struct fuzzForType<LatticeColorMatrixF> {
37  static constexpr float value() { return 1.0e-5; }
38  };
39 
40  template<>
41  struct fuzzForType<LatticeColorMatrixD> {
42  static constexpr double value() { return 1.0e-13; }
43  };
44 
45  }
46 
47  template<typename Q, typename C, typename R, typename S>
48  inline
49  void reunit_t(Q& xa,
50  LatticeBoolean& bad,
51  int& numbad,
52  enum Reunitarize ruflag,
53  const S& mstag)
54  {
55  START_CODE();
56 
57  QDP::StopWatch swatch;
58  swatch.reset();
59  swatch.start();
60 
61  multi2d<C> a(Nc, Nc);
62  multi2d<C> b(Nc, Nc);
63  R t1;
64  C t2;
65  R t3;
66  R t4;
67  R sigmasq = 0;
68  multi1d<C> row(Nc);
69 
70  // The initial number of matrices violating unitarity.
71  numbad = 0;
72 
73  // some kind of small floating point number, should be prec. dep.
75 
76  // Extract initial components
77  for(int i=0; i < Nc; ++i)
78  for(int j=0; j < Nc; ++j)
79  (a[i][j])[mstag] = peekColor(xa, i, j);
80 
81  // Use the Nc-dependent reunitarizers
82  switch (Nc)
83  {
84  //---------------------------------------------------------
85  // U(1)
86  case 1:
87  /* normalise the complex number */
88  /* sigmasq = sqrt(u^t . u) */
89  sigmasq[mstag] = sqrt(localNorm2(a[0][0]));
90 
91  /* rescale the field */
92  /* u <- u/sigmasq */
93  (a[0][0])[mstag] /= sigmasq;
94 
95  /* Now, do various things depending on the input flag. */
96  /* For use later, calculate the mean squared deviation */
97  if ( ruflag == REUNITARIZE_ERROR ||
98  ruflag == REUNITARIZE_LABEL )
99  {
100  sigmasq[mstag] = fabs(1-sigmasq);
101  }
102 
103  /* Do things depending on the mean squared deviation */
104  switch (ruflag)
105  {
106  case REUNITARIZE_ERROR:
107  /* Gripe and stop if unitarity is violated. */
108  numbad = toInt(sum(where(sigmasq > fuzz, LatticeInteger(1), LatticeInteger(0))));
109  if ( numbad > 0 )
110  QDP_error_exit("Unitarity violated", numbad);
111  break;
112 
113  case REUNITARIZE_LABEL:
114  /* Label the bad guys if unitarity is violated. */
115  bad[mstag] = sigmasq > fuzz;
116  numbad = toInt( sum(where(bad, LatticeInteger(1), LatticeInteger(0)),
117  mstag) );
118  break;
119  default:
120  break;
121  }
122  break;
123 
124  //---------------------------------------------------------
125  // SU(2)
126  case 2:
127  /* If you want to check unitarity, I have to save the second row somewhere */
128  if ( ruflag == REUNITARIZE_ERROR ||
129  ruflag == REUNITARIZE_LABEL )
130  {
131  for(int c = 0; c < Nc; ++c)
132  (row[c])[mstag] = a[c][Nc-1];
133  }
134 
135  /* normalise the first row */
136  /* t1 = sqrt(u^t . u) */
137  t1[mstag] = localNorm2(a[0][0]);
138  for(int c = 1; c < Nc; ++c)
139  t1[mstag] += localNorm2(a[c][0]);
140  t1[mstag] = sqrt(t1);
141 
142 
143  /* overwrite the first row with the rescaled value */
144  /* u <- u/t1 */
145  t4[mstag] = 1 / t1;
146  for(int c = 0; c < Nc; ++c)
147  (a[c][0])[mstag] *= t4;
148 
149  /* construct the second row from the first row */
150  (a[1][1])[mstag] = adj(a[0][0]);
151  (a[0][1])[mstag] = -adj(a[1][0]);
152 
153  /* Now, do various things depending on the input flag. */
154  /* For use later, calculate the mean squared deviation */
155  if ( ruflag == REUNITARIZE_ERROR ||
156  ruflag == REUNITARIZE_LABEL )
157  {
158  /* Calculate the mean square deviation. */
159  /* sigmasq = (1 - t1)**2 */
160  sigmasq[mstag] = pow(1-t1,2);
161 
162  /* sigmasq <- sqrt(sigmasq + |crow(.)-a(1,.)|**2) */
163  /* overwrite row */
164  for(int c = 0; c < Nc; ++c)
165  sigmasq[mstag] += localNorm2(row[c] - a[c][Nc-1]);
166 
167  sigmasq[mstag] = sqrt(sigmasq);
168  }
169 
170 
171  /* Do things depending on the mean squared deviation */
172  switch (ruflag)
173  {
174  case REUNITARIZE_ERROR:
175  /* Gripe and stop if unitarity is violated. */
176  numbad = toInt(sum(where(sigmasq > fuzz, LatticeInteger(1), LatticeInteger(0))));
177  if ( numbad > 0 )
178  QDP_error_exit("Unitarity violated", numbad);
179  break;
180 
181  case REUNITARIZE_LABEL:
182  /* Label the bad guys if unitarity is violated. */
183  bad[mstag] = sigmasq > fuzz;
184  numbad = toInt(sum(where(bad, LatticeInteger(1), LatticeInteger(0)),
185  mstag));
186  break;
187 
188  default:
189  break;
190  }
191  break;
192 
193  //---------------------------------------------------------
194  // SU(3)
195  case 3:
196  /* If you want to check unitarity, I have to save the third row somewhere */
197  if ( ruflag == REUNITARIZE_ERROR ||
198  ruflag == REUNITARIZE_LABEL )
199  {
200  for(int c = 0; c < Nc; ++c)
201  (row[c])[mstag] = a[c][Nc-1];
202  }
203 
204 
205  /* normalise the first row */
206  /* t1 = sqrt(u^t . u) */
207  t1[mstag] = localNorm2(a[0][0]);
208  for(int c = 1; c < Nc; ++c)
209  t1[mstag] += localNorm2(a[c][0]);
210  t1[mstag] = sqrt(t1);
211 
212 
213  /* overwrite the first row with the rescaled value */
214  /* u <- u/t1 */
215  t4[mstag] = 1 / t1;
216  for(int c = 0; c < Nc; ++c)
217  (a[c][0])[mstag] *= t4;
218 
219  /* calculate the orthogonal component to the second row */
220  /* t2 <- u^t.v */
221  t2[mstag] = adj(a[0][0]) * a[0][1];
222  for(int c = 1; c < Nc; ++c)
223  t2[mstag] += adj(a[c][0]) * a[c][1];
224 
225  /* orthogonalize the second row relative to the first row */
226  /* v <- v - t2*u */
227  for(int c = 0; c < Nc; ++c)
228  (a[c][1])[mstag] -= t2 * a[c][0];
229 
230  /* normalise the second row */
231  /* t3 = sqrt(u^t . u) */
232  t3[mstag] = localNorm2(a[0][1]);
233  for(int c = 1; c < Nc; ++c)
234  t3[mstag] += localNorm2(a[c][1]);
235  t3[mstag] = sqrt(t3);
236 
237  /* overwrite the second row with the rescaled value */
238  /* v <- v/t3 */
239  t4[mstag] = 1 / t3;
240  for(int c = 0; c < Nc; ++c)
241  (a[c][1])[mstag] *= t4;
242 
243 
244  /* the third row is the cross product of the new first and second rows */
245  /* column 1: w(0) = u(1)*v(2) - u(2)*v(1) */
246  (a[0][2])[mstag] = adj(a[1][0]) * adj(a[2][1]) -
247  adj(a[2][0]) * adj(a[1][1]);
248 
249  /* column 2: w(1) = u(2)*v(0) - u(0)*v(2) */
250  (a[1][2])[mstag] = adj(a[2][0]) * adj(a[0][1]) -
251  adj(a[0][0]) * adj(a[2][1]);
252 
253  /* column 3: w(3) = u(1)*v(2) - u(2)*v(1) */
254  (a[2][2])[mstag] = adj(a[0][0]) * adj(a[1][1]) -
255  adj(a[1][0]) * adj(a[0][1]);
256 
257  /* Now, do various things depending on the input flag. */
258  /* For use later, calculate the mean squared deviation */
259  if ( ruflag == REUNITARIZE_ERROR ||
260  ruflag == REUNITARIZE_LABEL )
261  {
262  /* Calculate the mean square deviation. */
263  /* sigmasq = (1 - t1)**2 + |t2|**2 + (1 - t3)**2 */
264  sigmasq[mstag] = pow(1-t1,2) + localNorm2(t2) + pow(1-t3,2);
265 
266  /* sigmasq <- sqrt(sigmasq + |crow(.)-a(2,.)|**2) */
267  /* overwrite row */
268  for(int c = 0; c < Nc; ++c)
269  sigmasq[mstag] += localNorm2(row[c] - a[c][Nc-1]);
270 
271  sigmasq[mstag] = sqrt(sigmasq);
272  }
273 
274 
275  /* Do things depending on the mean squared deviation */
276  switch (ruflag)
277  {
278  case REUNITARIZE_ERROR:
279  /* Gripe and stop if unitarity is violated. */
280  numbad = toInt(sum(where(sigmasq > fuzz, LatticeInteger(1),
281  LatticeInteger(0)),
282  mstag));
283  if ( numbad > 0 )
284  QDP_error_exit("Unitarity violated", numbad);
285  break;
286  case REUNITARIZE_LABEL:
287  /* Label the bad guys if unitarity is violated. */
288  bad[mstag] = sigmasq > fuzz;
289  numbad = toInt(sum(where(bad,LatticeInteger(1), LatticeInteger(0)),
290  mstag));
291 
292  default:
293  break;
294  }
295  break;
296 
297  //---------------------------------------------------------
298  // SU(N > 3)
299  default:
300  if ( Nc > 3 )
301  {
302  /* If you want to check unitarity, I have to save the third row somewhere */
303  if ( ruflag == REUNITARIZE_ERROR ||
304  ruflag == REUNITARIZE_LABEL )
305  {
306  for(int c = 0; c < Nc; ++c)
307  (row[c])[mstag] = a[c][Nc-1];
308  }
309 
310  /* normalise the first row */
311  /* t1 = sqrt(u^t . u) */
312  t1[mstag] = localNorm2(a[0][0]);
313  for(int c = 1; c < Nc; ++c)
314  t1[mstag] += localNorm2(a[c][0]);
315  t1[mstag] = sqrt(t1);
316 
317  if ( ruflag == REUNITARIZE_ERROR ||
318  ruflag == REUNITARIZE_LABEL )
319  {
320  /* Calculate the mean square deviation. */
321  /* sigmasq = (1 - t1)**2 */
322  sigmasq[mstag] = pow(1-t1,2);
323  }
324 
325  /* overwrite the first row with the rescaled value */
326  /* u <- u/t1 */
327  t3[mstag] = 1 / t1;
328  for(int c = 0; c < Nc; ++c)
329  (a[c][0])[mstag] *= t3;
330 
331  /* Do Gramm-Schmidt on the remaining rows */
332  for(int j = 1; j < Nc; j++ )
333  {
334  for(int i = 0; i < j; i++ )
335  {
336  /* t2 <- u^t.v */
337  t2[mstag] = adj(a[0][i]) * a[0][j];
338  for(int c = 1; c < Nc; ++c)
339  {
340  t2[mstag] += adj(a[c][i]) * a[c][j];
341  }
342 
343  if ( (ruflag == REUNITARIZE_ERROR ||
344  ruflag == REUNITARIZE_LABEL) && j < (Nc-1) )
345  sigmasq[mstag] += localNorm2(t2);
346 
347  /* orthogonalize the j-th row relative to the i-th row */
348  /* v <- v - t2*u */
349  for(int c = 0; c < Nc; ++c)
350  {
351  (a[c][j])[mstag] -= t2 * a[c][i];
352  }
353  }
354 
355  /* normalise the j-th row */
356  /* t1 = sqrt(v^t . v) */
357  t1[mstag] = localNorm2(a[0][j]);
358  for(int c = 1; c < Nc; ++c)
359  t1[mstag] += localNorm2(a[c][j]);
360  t1[mstag] = sqrt(t1);
361 
362  /* overwrite the j-th row with the rescaled value */
363  /* v <- v/t1 */
364  t3[mstag] = 1 / t1;
365  for(int c = 0; c < Nc; ++c)
366  (a[c][j])[mstag] *= t3;
367 
368  if ( (ruflag == REUNITARIZE_ERROR ||
369  ruflag == REUNITARIZE_LABEL) && j < (Nc-1) )
370  {
371  /* Calculate the mean square deviation. */
372  /* sigmasq = (1 - t1)**2 */
373  sigmasq[mstag] += pow(1-t1,2);
374  }
375  }
376 
377  /* Now we have a unitary matrix. We need to multiply the last
378  row with a phase to make the determinant 1. */
379  /* We compute the determinant by LU decomposition */
380  for(int j = 0; j < Nc; j++)
381  for(int i = 0; i < Nc; i++)
382  (b[j][i])[mstag] = a[j][i];
383 
384  for(int j = 0; j < Nc; j++)
385  {
386  for(int i = 0; i <= j; i++)
387  {
388  t2[mstag] = b[j][i];
389  for(int c = 0; c < i; c++)
390  t2[mstag] -= b[c][i] * b[j][c];
391 
392  (b[j][i])[mstag] = t2;
393  }
394 
395  for(int i = (j+1); i < Nc; i++)
396  {
397  t2[mstag] = b[j][i];
398  for(int c = 0; c < j; c++)
399  t2[mstag] -= b[c][i] * b[j][c];
400 
401  (b[j][i])[mstag] = adj(b[j][j]) * t2 / localNorm2(b[j][j]);
402  }
403  }
404 
405  /* The determinant */
406  t2[mstag] = b[0][0] * b[1][1];
407  for(int c = 2; c < Nc; c++)
408  t2[mstag] *= b[c][c];
409 
410  /* The phase of the determinant */
411  t4[mstag] = atan2(imag(t2), real(t1));
412  t2[mstag] = cmplx(cos(t4), -sin(t4));
413  for(int c = 0; c < Nc; ++c)
414  (a[c][Nc-1])[mstag] *= t2;
415 
416 
417  /* Now, do various things depending on the input flag. */
418  /* For use later, finish calculating the mean squared deviation */
419  if ( ruflag == REUNITARIZE_ERROR ||
420  ruflag == REUNITARIZE_LABEL )
421  {
422  for(int c = 0; c < Nc; ++c)
423  sigmasq[mstag] += localNorm2(row[c] - a[c][Nc-1]);
424 
425  sigmasq[mstag] = sqrt(sigmasq);
426  }
427 
428  numbad = toInt(sum(where(sigmasq > fuzz, LatticeInteger(1),
429  LatticeInteger(0))));
430 
431  /* Do things depending on the mean squared deviation */
432  switch (ruflag)
433  {
434  case REUNITARIZE_ERROR:
435  /* Gripe and stop if unitarity is violated. */
436  numbad = toInt(sum(where(sigmasq > fuzz, LatticeInteger(1),
437  LatticeInteger(0))));
438  if ( numbad > 0 )
439  QDP_error_exit("Unitarity violated", numbad);
440  break;
441  case REUNITARIZE_LABEL:
442  /* Label the bad guys if unitarity is violated. */
443  bad[mstag] = sigmasq > fuzz;
444  numbad = toInt(sum(where(bad, LatticeInteger(1),
445  LatticeInteger(0)),
446  mstag));
447  default:
448  break;
449  }
450  }
451  else
452  QDP_error_exit("Invalid Nc for reunit, Nc=%d", Nc);
453  }
454 
455  // Insert final reunitarized components
456  for(int i=0; i < Nc; ++i)
457  for(int j=0; j < Nc; ++j)
458  pokeColor(xa[mstag], a[i][j], i, j);
459 
460  swatch.stop();
461  ReunitEnv::time_spent += swatch.getTimeInSeconds();
462  END_CODE();
463  }
464 
465  // Overloaded definitions
466  // SINGLE
467  void reunit(LatticeColorMatrixF3& xa)
468  {
469  START_CODE();
470 
471  LatticeBoolean bad;
472  int numbad;
473 
474  reunit_t<LatticeColorMatrixF3, LatticeComplexF, LatticeRealF, Subset>(xa, bad, numbad, REUNITARIZE, all);
475 
476  END_CODE();
477  }
478 
479  // Overloaded definitions
480  // DOUBLE
481  void reunit(LatticeColorMatrixD3& xa)
482  {
483  START_CODE();
484 
485  LatticeBoolean bad;
486  int numbad;
487 
488  reunit_t<LatticeColorMatrixD3, LatticeComplexD, LatticeRealD, Subset>(xa, bad, numbad, REUNITARIZE, all);
489 
490  END_CODE();
491  }
492 
493 
494  // SINGLE
495  void reunit(LatticeColorMatrixF3& xa,
496  const Subset& mstag)
497  {
498  START_CODE();
499 
500  LatticeBoolean bad;
501  int numbad;
502 
503  reunit_t<LatticeColorMatrixF3, LatticeComplexF, LatticeRealF, Subset>(xa, bad, numbad, REUNITARIZE, mstag);
504 
505  END_CODE();
506  }
507 
508  // DOUBLE
509  void reunit(LatticeColorMatrixD3& xa,
510  const Subset& mstag)
511  {
512  START_CODE();
513 
514  LatticeBoolean bad;
515  int numbad;
516 
517  reunit_t<LatticeColorMatrixD3, LatticeComplexD, LatticeRealD, Subset>(xa, bad, numbad, REUNITARIZE, mstag);
518 
519  END_CODE();
520  }
521 
522  // Overloaded definitions, with numbad and ruflag
523  // Single
524 
525  void reunit(LatticeColorMatrixF3& xa,
526  int& numbad,
527  enum Reunitarize ruflag)
528  {
529  START_CODE();
530 
531  LatticeBoolean bad;
532 
533  reunit_t<LatticeColorMatrixF3, LatticeComplexF, LatticeRealF, Subset>(xa, bad, numbad, REUNITARIZE, all);
534 
535  END_CODE();
536  }
537 
538  // DOUBLE
539  void reunit(LatticeColorMatrixD3& xa,
540  int& numbad,
541  enum Reunitarize ruflag)
542  {
543  START_CODE();
544 
545  LatticeBoolean bad;
546 
547  reunit_t<LatticeColorMatrixD3, LatticeComplexD, LatticeRealD, Subset>(xa, bad, numbad, REUNITARIZE, all);
548 
549  END_CODE();
550  }
551 
552 
553  // SINGLE
554  void reunit(LatticeColorMatrixF3& xa,
555  int& numbad,
556  enum Reunitarize ruflag,
557  const Subset& mstag)
558  {
559  START_CODE();
560 
561  LatticeBoolean bad;
562 
563  reunit_t<LatticeColorMatrixF3, LatticeComplexF, LatticeRealF, Subset>(xa, bad, numbad, REUNITARIZE, mstag);
564 
565  END_CODE();
566  }
567 
568  // DOUBLE
569  void reunit(LatticeColorMatrixD3& xa,
570  int& numbad,
571  enum Reunitarize ruflag,
572  const Subset& mstag)
573  {
574  START_CODE();
575 
576  LatticeBoolean bad;
577 
578  reunit_t<LatticeColorMatrixD3, LatticeComplexD, LatticeRealD, Subset>(xa, bad, numbad, REUNITARIZE, mstag);
579 
580  END_CODE();
581  }
582  // Overloaded definitions, with bad, numbad and ruflag
583  // SINGLE
584  void reunit(LatticeColorMatrixF3& xa,
585  LatticeBoolean& bad,
586  int& numbad,
587  enum Reunitarize ruflag)
588  {
589  reunit_t<LatticeColorMatrixF3, LatticeComplexF, LatticeRealF, Subset>(xa, bad, numbad, ruflag, all);
590  }
591 
592  // DOUBLE
593  void reunit(LatticeColorMatrixD3& xa,
594  LatticeBoolean& bad,
595  int& numbad,
596  enum Reunitarize ruflag)
597  {
598  reunit_t<LatticeColorMatrixD3, LatticeComplexD, LatticeRealD, Subset>(xa, bad, numbad, ruflag, all);
599  }
600 
601  // Single
602  void reunit(LatticeColorMatrixF3& xa,
603  LatticeBoolean& bad,
604  int& numbad,
605  enum Reunitarize ruflag,
606  const Subset& mstag)
607  {
608  reunit_t<LatticeColorMatrixF3, LatticeComplexF, LatticeRealF, Subset>(xa, bad, numbad, ruflag, mstag);
609  }
610 
611 
612  // Double
613  void reunit(LatticeColorMatrixD3& xa,
614  LatticeBoolean& bad,
615  int& numbad,
616  enum Reunitarize ruflag,
617  const Subset& mstag)
618  {
619  reunit_t<LatticeColorMatrixD3, LatticeComplexD, LatticeRealD, Subset>(xa, bad, numbad, ruflag, mstag);
620  }
621 
622 } // End namespace
Primary include file for CHROMA library code.
int numbad
Definition: cool.cc:28
LatticeBoolean bad
Definition: cool.cc:22
unsigned j
Definition: ldumul_w.cc:35
SpinMatrix C()
C = Gamma(10)
Definition: barspinmat_w.cc:29
static double time_spent
Definition: reunit.cc:29
double getTime()
Definition: reunit.cc:30
Asqtad Staggered-Dirac operator.
Definition: klein_gord.cc:10
QDP_error_exit("too many BiCG iterations", n_count, rsd_sq, cp, c, re_rvr, im_rvr, re_a, im_a, re_b, im_b)
LinOpSysSolverMGProtoClover::Q Q
Double c
Definition: invbicg.cc:108
int i
Definition: pbg5p_w.cc:55
void reunit(LatticeColorMatrixF3 &xa)
Definition: reunit.cc:467
Complex a
Definition: invbicg.cc:95
void reunit_t(Q &xa, LatticeBoolean &bad, int &numbad, enum Reunitarize ruflag, const S &mstag)
Definition: reunit.cc:49
Reunitarize
Definition: reunit.h:29
@ REUNITARIZE
Definition: reunit.h:29
@ REUNITARIZE_ERROR
Definition: reunit.h:29
@ REUNITARIZE_LABEL
Definition: reunit.h:29
START_CODE()
Complex b
Definition: invbicg.cc:96
Double sum
Definition: qtopcor.cc:37
Reunitarize in place a color matrix to SU(N)