6 #ifndef __clover_term_ptx_w_h__
7 #define __clover_term_ptx_w_h__
9 #warning "Using QPD-JIT/PTX clover"
55 inline T&
elem(
int i) {
return this->arrayF(
i); }
56 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
64 F[0].setup( rhs.elem(0) );
65 F[1].setup( rhs.elem(1) );
68 inline const T&
elem(
int i)
const {
return F[
i]; }
92 typedef typename WordType<T>::Type_t
Type_t;
98 typedef typename WordType<T>::Type_t
Type_t;
122 for (
int i = 0 ;
i < 2 * Nc ;
i++ )
127 inline T&
elem(
int i) {
return this->arrayF(
i); }
128 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
136 for (
int i=0;
i<2*Nc;++
i)
164 typedef typename WordType<T>::Type_t
Type_t;
170 typedef typename WordType<T>::Type_t
Type_t;
189 struct PTriOffJIT:
public BaseJIT<T,2*Nc*Nc-Nc>
194 for (
int i = 0 ;
i < 2*Nc*Nc-Nc ;
i++ )
199 inline T&
elem(
int i) {
return this->arrayF(
i); }
200 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
208 for (
int i=0;
i<2*Nc*Nc-Nc;++
i)
236 typedef typename WordType<T>::Type_t
Type_t;
242 typedef typename WordType<T>::Type_t
Type_t;
246 #if defined(QDP_USE_PROFILING)
248 struct LeafFunctor<
PComp<
T>, PrintTag>
251 static int apply(
const PComp<T> &
s,
const PrintTag &f)
258 struct LeafFunctor<
PTriDia<
T>, PrintTag>
261 static int apply(
const PTriDia<T> &
s,
const PrintTag &f)
268 struct LeafFunctor<
PTriOff<
T>, PrintTag>
271 static int apply(
const PTriOff<T> &
s,
const PrintTag &f)
288 struct QUDAPackedClovSite {
296 template<
typename T,
typename U>
301 typedef typename WordType<T>::Type_t
REALT;
303 typedef OLattice< PScalar< PScalar< RScalar< Word< REALT> > > > >
LatticeREAL;
304 typedef OScalar< PScalar< PScalar< RScalar< Word< REALT> > > > >
RealT;
381 const multi1d<U>&
getU()
const {
return u;}
396 OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >
tri_dia;
397 OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > >
tri_off;
403 template<
typename T,
typename U>
407 template<
typename T,
typename U>
419 fbc = fs->getFermBC();
423 if (fbc.operator->() == 0) {
424 QDPIO::cerr <<
"PTXCloverTerm: error: fbc is null" << std::endl;
429 RealT ff = where(param.anisoParam.anisoP, Real(1) / param.anisoParam.xi_0, Real(1));
430 param.clovCoeffR *= Real(0.5) * ff;
431 param.clovCoeffT *= Real(0.5);
441 RealT ff = where(param.anisoParam.anisoP, param.anisoParam.nu / param.anisoParam.xi_0, Real(1));
442 diag_mass = 1 + (
Nd-1)*ff + param.Mass;
451 choles_done.resize(rb.numSubsets());
452 for(
int i=0;
i < rb.numSubsets();
i++) {
466 template<
typename T,
typename U>
477 fbc = fs->getFermBC();
481 if (fbc.operator->() == 0) {
482 QDPIO::cerr <<
"PTXCloverTerm: error: fbc is null" << std::endl;
487 RealT ff = where(param.anisoParam.anisoP, Real(1) / param.anisoParam.xi_0, Real(1));
488 param.clovCoeffR *=
RealT(0.5) * ff;
489 param.clovCoeffT *=
RealT(0.5);
499 RealT ff = where(param.anisoParam.anisoP, param.anisoParam.nu / param.anisoParam.xi_0, Real(1));
500 diag_mass = 1 + (
Nd-1)*ff + param.Mass;
506 makeClov(f, diag_mass);
508 choles_done.resize(rb.numSubsets());
509 for(
int i=0;
i < rb.numSubsets();
i++) {
510 choles_done[
i] =
false;
599 template<
typename RealT,
typename U,
typename X,
typename Y>
601 const RealT& diag_mass,
611 AddressLeaf addr_leaf;
613 int junk_0 = forEach(diag_mass, addr_leaf, NullCombine());
614 int junk_1 = forEach(f0, addr_leaf, NullCombine());
615 int junk_2 = forEach(
f1, addr_leaf, NullCombine());
616 int junk_3 = forEach(
f2, addr_leaf, NullCombine());
617 int junk_4 = forEach(
f3, addr_leaf, NullCombine());
618 int junk_5 = forEach(
f4, addr_leaf, NullCombine());
619 int junk_6 = forEach(
f5, addr_leaf, NullCombine());
620 int junk_7 = forEach(tri_dia, addr_leaf, NullCombine());
621 int junk_8 = forEach(tri_off, addr_leaf, NullCombine());
625 int hi = Layout::sitesOnNode();
627 std::vector<void*> addr;
629 addr.push_back( &lo );
632 addr.push_back( &hi );
635 int addr_dest=addr.size();
636 for(
int i=0;
i < addr_leaf.addr.size(); ++
i) {
637 addr.push_back( &addr_leaf.addr[
i] );
641 jit_launch(
function,Layout::sitesOnNode(),addr);
646 template<
typename RealT,
typename U,
typename X,
typename Y>
659 typedef typename WordType<RealT>::Type_t REALT;
663 jit_start_new_function();
665 jit_value r_lo = jit_add_param( jit_ptx_type::s32 );
666 jit_value r_hi = jit_add_param( jit_ptx_type::s32 );
668 jit_value r_idx = jit_geom_get_linear_th_idx();
670 jit_value r_out_of_range = jit_ins_ge( r_idx , r_hi );
671 jit_ins_exit( r_out_of_range );
673 ParamLeaf param_leaf( r_idx );
675 typedef typename LeafFunctor<RealT, ParamLeaf>::Type_t RealTJIT;
676 RealTJIT diag_mass_jit(forEach(diag_mass, param_leaf, TreeCombine()));
678 typedef typename LeafFunctor<U, ParamLeaf>::Type_t UJIT;
679 UJIT f0_jit(forEach(f0, param_leaf, TreeCombine()));
680 UJIT f1_jit(forEach(
f1, param_leaf, TreeCombine()));
681 UJIT f2_jit(forEach(
f2, param_leaf, TreeCombine()));
682 UJIT f3_jit(forEach(
f3, param_leaf, TreeCombine()));
683 UJIT f4_jit(forEach(
f4, param_leaf, TreeCombine()));
684 UJIT f5_jit(forEach(
f5, param_leaf, TreeCombine()));
685 auto& f0_j = f0_jit.elem(JitDeviceLayout::Coalesced);
686 auto& f1_j = f1_jit.elem(JitDeviceLayout::Coalesced);
687 auto& f2_j = f2_jit.elem(JitDeviceLayout::Coalesced);
688 auto& f3_j = f3_jit.elem(JitDeviceLayout::Coalesced);
689 auto& f4_j = f4_jit.elem(JitDeviceLayout::Coalesced);
690 auto& f5_j = f5_jit.elem(JitDeviceLayout::Coalesced);
692 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
693 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
694 auto& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::Coalesced);
696 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
697 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
698 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced);
700 typename REGType< typename RealTJIT::Subtype_t >::Type_t diag_mass_reg;
701 diag_mass_reg.setup( diag_mass_jit.elem( JitDeviceLayout::Scalar ) );
703 for(
int jj = 0; jj < 2; jj++) {
704 for(
int ii = 0; ii < 2*Nc; ii++) {
705 tri_dia_j.elem(jj).elem(ii) = diag_mass_reg.elem().elem();
711 RComplexREG<WordREG<REALT> > E_minus;
712 RComplexREG<WordREG<REALT> > B_minus;
713 RComplexREG<WordREG<REALT> > ctmp_0;
714 RComplexREG<WordREG<REALT> > ctmp_1;
715 RScalarREG<WordREG<REALT> > rtmp_0;
716 RScalarREG<WordREG<REALT> > rtmp_1;
719 for(
int i = 0;
i < Nc; ++
i) {
720 ctmp_0 = f5_j.elem().elem(
i,
i);
721 ctmp_0 -= f0_j.elem().elem(
i,
i);
722 rtmp_0 = imag(ctmp_0);
723 tri_dia_j.elem(0).elem(
i) += rtmp_0;
725 tri_dia_j.elem(0).elem(
i+Nc) -= rtmp_0;
727 ctmp_1 = f5_j.elem().elem(
i,
i);
728 ctmp_1 += f0_j.elem().elem(
i,
i);
729 rtmp_1 = imag(ctmp_1);
730 tri_dia_j.elem(1).elem(
i) -= rtmp_1;
732 tri_dia_j.elem(1).elem(
i+Nc) += rtmp_1;
735 for(
int i = 1;
i < Nc; ++
i) {
736 for(
int j = 0;
j <
i; ++
j) {
739 int elem_tmp = (
i+Nc)*(
i+Nc-1)/2 +
j+Nc;
741 ctmp_0 = f0_j.elem().elem(
i,
j);
742 ctmp_0 -= f5_j.elem().elem(
i,
j);
743 tri_off_j.elem(0).elem(
elem_ij) = timesI(ctmp_0);
745 zero_rep( tri_off_j.elem(0).elem(elem_tmp) );
746 tri_off_j.elem(0).elem(elem_tmp) -= tri_off_j.elem(0).elem(
elem_ij);
748 ctmp_1 = f5_j.elem().elem(
i,
j);
749 ctmp_1 += f0_j.elem().elem(
i,
j);
750 tri_off_j.elem(1).elem(
elem_ij) = timesI(ctmp_1);
752 zero_rep( tri_off_j.elem(1).elem(elem_tmp) );
753 tri_off_j.elem(1).elem(elem_tmp) -= tri_off_j.elem(1).elem(
elem_ij);
757 for(
int i = 0;
i < Nc; ++
i) {
758 for(
int j = 0;
j < Nc; ++
j) {
763 E_minus = f2_j.elem().elem(
i,
j);
764 E_minus = timesI( E_minus );
766 E_minus += f4_j.elem().elem(
i,
j);
769 B_minus = f3_j.elem().elem(
i,
j);
770 B_minus = timesI( B_minus );
772 B_minus -= f1_j.elem().elem(
i,
j);
774 tri_off_j.elem(0).elem(
elem_ij) = B_minus - E_minus;
776 tri_off_j.elem(1).elem(
elem_ij) = E_minus + B_minus;
780 return jit_get_cufunction(
"ptx_make_clov.ptx");
787 template<
typename T,
typename U>
793 QDPIO::cerr << __func__ <<
": expecting Nd==4" << std::endl;
798 QDPIO::cerr << __func__ <<
": expecting Ns==4" << std::endl;
802 U f0 = f[0] * getCloverCoeff(0,1);
803 U f1 = f[1] * getCloverCoeff(0,2);
804 U f2 = f[2] * getCloverCoeff(0,3);
805 U f3 = f[3] * getCloverCoeff(1,2);
806 U f4 = f[4] * getCloverCoeff(1,3);
807 U f5 = f[5] * getCloverCoeff(2,3);
811 static CUfunction
function;
813 if (
function == NULL)
827 template<
typename T,
typename U>
835 ldagdlinv(tr_log_diag_,
cb);
847 template<
typename T,
typename U>
852 if( choles_done[
cb] ==
false )
854 QDPIO::cout << __func__ <<
": Error: you have not done the Cholesky.on this operator on this subset" << std::endl;
855 QDPIO::cout <<
"You sure you should not be asking invclov?" << std::endl;
864 return sum(tr_log_diag_, rb[
cb]);
869 template<
typename T,
typename X,
typename Y>
876 if (!
s.hasOrderedRep())
877 QDP_error_exit(
"ldagdlinv on subset with unordered representation not implemented");
879 AddressLeaf addr_leaf;
881 int junk_0 = forEach(tr_log_diag, addr_leaf, NullCombine());
882 int junk_2 = forEach(tri_dia, addr_leaf, NullCombine());
883 int junk_3 = forEach(tri_off, addr_leaf, NullCombine());
889 std::vector<void*> addr;
891 addr.push_back( &lo );
894 addr.push_back( &hi );
897 int addr_dest=addr.size();
898 for(
int i=0;
i < addr_leaf.addr.size(); ++
i) {
899 addr.push_back( &addr_leaf.addr[
i] );
903 jit_launch(
function,
s.numSiteTable(),addr);
910 template<
typename U,
typename T,
typename X,
typename Y>
916 typedef typename WordType<U>::Type_t REALT;
922 jit_start_new_function();
924 jit_value r_lo = jit_add_param( jit_ptx_type::s32 );
925 jit_value r_hi = jit_add_param( jit_ptx_type::s32 );
927 jit_value r_idx_thread = jit_geom_get_linear_th_idx();
929 jit_value r_out_of_range = jit_ins_gt( r_idx_thread , jit_ins_sub( r_hi , r_lo ) );
930 jit_ins_exit( r_out_of_range );
932 jit_value r_idx = jit_ins_add( r_lo , r_idx_thread );
934 ParamLeaf param_leaf( r_idx );
937 typedef typename LeafFunctor<T, ParamLeaf>::Type_t TJIT;
938 TJIT tr_log_diag_jit(forEach(tr_log_diag, param_leaf, TreeCombine()));
939 auto& tr_log_diag_j = tr_log_diag_jit.elem(JitDeviceLayout::Coalesced);
941 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
942 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
943 auto& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::Coalesced);
945 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
946 tri_dia_r.setup( tri_dia_j );
949 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
950 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
951 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced);
953 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
954 tri_off_r.setup( tri_off_j );
958 RScalarREG<WordREG<REALT> > zip;
962 int site_neg_logdet=0;
966 RScalarREG<WordREG<REALT> > inv_d[6] ;
967 RComplexREG<WordREG<REALT> > inv_offd[15] ;
968 RComplexREG<WordREG<REALT> > v[6] ;
969 RScalarREG<WordREG<REALT> > diag_g[6] ;
971 for(
int i=0;
i < N;
i++) {
972 inv_d[
i] = tri_dia_r.elem(
block).elem(
i);
975 for(
int i=0;
i < 15;
i++) {
976 inv_offd[
i] = tri_off_r.elem(
block).elem(
i);
980 for(
int j=0;
j < N; ++
j) {
982 for(
int i=0;
i <
j;
i++) {
985 RComplexREG<WordREG<REALT> > A_ii = cmplx( inv_d[
i], zip );
990 v[
j] = cmplx(inv_d[
j],zip);
992 for(
int k=0;
k <
j;
k++) {
993 int elem_jk =
j*(
j-1)/2 +
k;
994 v[
j] -= inv_offd[elem_jk]*v[
k];
997 inv_d[
j] = real( v[
j] );
999 for(
int k=
j+1;
k < N;
k++) {
1000 int elem_kj =
k*(
k-1)/2 +
j;
1001 for(
int l=0;
l <
j;
l++) {
1002 int elem_kl =
k*(
k-1)/2 +
l;
1003 inv_offd[elem_kj] -= inv_offd[elem_kl] * v[
l];
1005 inv_offd[elem_kj] /= v[
j];
1011 RScalarREG<WordREG<REALT> >
one(1.0);
1014 for(
int i=0;
i < N;
i++) {
1015 diag_g[
i] =
one/inv_d[
i];
1021 tr_log_diag_j.elem().elem() += log(fabs(inv_d[
i]));
1025 if( inv_d[
i].elem() < 0 ) {
1045 RComplexREG<WordREG<REALT> >
sum;
1046 for(
int k = 0;
k < N; ++
k) {
1048 for(
int i = 0;
i <
k; ++
i) {
1055 v[
k] = cmplx(diag_g[
k],zip);
1057 for(
int i =
k+1;
i < N; ++
i) {
1060 for(
int j =
k;
j <
i; ++
j) {
1076 for(
int i = N-2; (int)
i >= (
int)
k; --
i) {
1077 for(
int j =
i+1;
j < N; ++
j) {
1085 inv_d[
k] = real(v[
k]);
1086 for(
int i =
k+1;
i < N; ++
i) {
1088 int elem_ik =
i*(
i-1)/2+
k;
1089 inv_offd[elem_ik] = v[
i];
1094 for(
int i=0;
i < N;
i++) {
1095 tri_dia_j.elem(
block).elem(
i) = inv_d[
i];
1097 for(
int i=0;
i < 15;
i++) {
1098 tri_off_j.elem(
block).elem(
i) = inv_offd[
i];
1102 return jit_get_cufunction(
"ptx_ldagdlinv.ptx");
1110 template<
typename T,
typename U>
1117 QDPIO::cerr << __func__ <<
": Matrix is too small" << std::endl;
1122 tr_log_diag[rb[
cb]] =
zero;
1126 static CUfunction
function;
1128 if (
function == NULL)
1129 function = function_ldagdlinv_build<U>(tr_log_diag, tri_dia, tri_off, rb[
cb] );
1135 choles_done[
cb] =
true;
1189 template<
typename U,
typename X,
typename Y>
1197 if (!
s.hasOrderedRep())
1198 QDP_error_exit(
"triacntr on subset with unordered representation not implemented");
1200 AddressLeaf addr_leaf;
1202 int junk_0 = forEach(B, addr_leaf, NullCombine());
1203 int junk_2 = forEach(tri_dia, addr_leaf, NullCombine());
1204 int junk_3 = forEach(tri_off, addr_leaf, NullCombine());
1210 std::vector<void*> addr;
1212 addr.push_back( &lo );
1215 addr.push_back( &hi );
1218 addr.push_back( &mat );
1221 int addr_dest=addr.size();
1222 for(
int i=0;
i < addr_leaf.addr.size(); ++
i) {
1223 addr.push_back( &addr_leaf.addr[
i] );
1227 jit_launch(
function,
s.numSiteTable(),addr);
1233 template<
typename U,
typename X,
typename Y>
1242 typedef typename WordType<U>::Type_t REALT;
1248 jit_start_new_function();
1250 jit_value r_lo = jit_add_param( jit_ptx_type::s32 );
1251 jit_value r_hi = jit_add_param( jit_ptx_type::s32 );
1252 jit_value r_mat = jit_add_param( jit_ptx_type::s32 );
1254 jit_value r_idx_thread = jit_geom_get_linear_th_idx();
1256 jit_value r_out_of_range = jit_ins_gt( r_idx_thread , jit_ins_sub( r_hi , r_lo ) );
1257 jit_ins_exit( r_out_of_range );
1259 jit_value r_idx = jit_ins_add( r_lo , r_idx_thread );
1261 ParamLeaf param_leaf( r_idx );
1264 typedef typename LeafFunctor<U, ParamLeaf>::Type_t UJIT;
1265 UJIT B_jit(forEach(B, param_leaf, TreeCombine()));
1266 auto& B_j = B_jit.elem(JitDeviceLayout::Coalesced);
1268 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
1269 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
1270 auto& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::Coalesced);
1272 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
1273 tri_dia_r.setup( tri_dia_j );
1276 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
1277 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
1278 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced);
1280 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
1281 tri_off_r.setup( tri_off_j );
1290 jit_label_t case_10;
1291 jit_label_t case_12;
1294 jit_ins_branch( case_0 , jit_ins_eq( r_mat , jit_value(0) ) );
1295 jit_ins_branch( case_3 , jit_ins_eq( r_mat , jit_value(3) ) );
1296 jit_ins_branch( case_5 , jit_ins_eq( r_mat , jit_value(5) ) );
1297 jit_ins_branch( case_6 , jit_ins_eq( r_mat , jit_value(6) ) );
1298 jit_ins_branch( case_9 , jit_ins_eq( r_mat , jit_value(9) ) );
1299 jit_ins_branch( case_10 , jit_ins_eq( r_mat , jit_value(10) ) );
1300 jit_ins_branch( case_12 , jit_ins_eq( r_mat , jit_value(12) ) );
1309 jit_ins_label( case_0 );
1311 RComplexREG<WordREG<REALT> > lctmp0;
1312 RScalarREG< WordREG<REALT> > lr_zero0;
1313 RScalarREG< WordREG<REALT> > lrtmp0;
1317 for(
int i0 = 0; i0 < Nc; ++i0) {
1319 lrtmp0 = tri_dia_r.elem(0).elem(i0);
1320 lrtmp0 += tri_dia_r.elem(0).elem(i0+Nc);
1321 lrtmp0 += tri_dia_r.elem(1).elem(i0);
1322 lrtmp0 += tri_dia_r.elem(1).elem(i0+Nc);
1323 B_j.elem().elem(i0,i0) = cmplx(lrtmp0,lr_zero0);
1328 for(
int i0 = 1; i0 < Nc; ++i0) {
1330 int elem_ijb0 = (i0+Nc)*(i0+Nc-1)/2 + Nc;
1332 for(
int j0 = 0; j0 < i0; ++j0) {
1334 lctmp0 = tri_off_r.elem(0).elem(elem_ij0);
1335 lctmp0 += tri_off_r.elem(0).elem(elem_ijb0);
1336 lctmp0 += tri_off_r.elem(1).elem(elem_ij0);
1337 lctmp0 += tri_off_r.elem(1).elem(elem_ijb0);
1339 B_j.elem().elem(j0,i0) = lctmp0;
1340 B_j.elem().elem(i0,j0) = adj(lctmp0);
1350 jit_ins_label( case_3 );
1358 RComplexREG<WordREG<REALT> > lctmp3;
1359 RScalarREG<WordREG<REALT> > lr_zero3;
1360 RScalarREG<WordREG<REALT> > lrtmp3;
1364 for(
int i3 = 0; i3 < Nc; ++i3) {
1366 lrtmp3 = tri_dia_r.elem(0).elem(i3+Nc);
1367 lrtmp3 -= tri_dia_r.elem(0).elem(i3);
1368 lrtmp3 -= tri_dia_r.elem(1).elem(i3);
1369 lrtmp3 += tri_dia_r.elem(1).elem(i3+Nc);
1370 B_j.elem().elem(i3,i3) = cmplx(lr_zero3,lrtmp3);
1375 for(
int i3 = 1; i3 < Nc; ++i3) {
1377 int elem_ijb3 = (i3+Nc)*(i3+Nc-1)/2 + Nc;
1379 for(
int j3 = 0; j3 < i3; ++j3) {
1381 lctmp3 = tri_off_r.elem(0).elem(elem_ijb3);
1382 lctmp3 -= tri_off_r.elem(0).elem(elem_ij3);
1383 lctmp3 -= tri_off_r.elem(1).elem(elem_ij3);
1384 lctmp3 += tri_off_r.elem(1).elem(elem_ijb3);
1386 B_j.elem().elem(j3,i3) = timesI(adj(lctmp3));
1387 B_j.elem().elem(i3,j3) = timesI(lctmp3);
1397 jit_ins_label( case_5 );
1404 RComplexREG<WordREG<REALT> > lctmp5;
1405 RScalarREG<WordREG<REALT> > lrtmp5;
1407 for(
int i5 = 0; i5 < Nc; ++i5) {
1409 int elem_ij5 = (i5+Nc)*(i5+Nc-1)/2;
1411 for(
int j5 = 0; j5 < Nc; ++j5) {
1413 int elem_ji5 = (j5+Nc)*(j5+Nc-1)/2 + i5;
1416 lctmp5 = adj(tri_off_r.elem(0).elem(elem_ji5));
1417 lctmp5 -= tri_off_r.elem(0).elem(elem_ij5);
1418 lctmp5 += adj(tri_off_r.elem(1).elem(elem_ji5));
1419 lctmp5 -= tri_off_r.elem(1).elem(elem_ij5);
1421 B_j.elem().elem(i5,j5) = lctmp5;
1430 jit_ins_label( case_6 );
1437 RComplexREG<WordREG<REALT> > lctmp6;
1438 RScalarREG<WordREG<REALT> > lrtmp6;
1440 for(
int i6 = 0; i6 < Nc; ++i6) {
1442 int elem_ij6 = (i6+Nc)*(i6+Nc-1)/2;
1444 for(
int j6 = 0; j6 < Nc; ++j6) {
1446 int elem_ji6 = (j6+Nc)*(j6+Nc-1)/2 + i6;
1448 lctmp6 = adj(tri_off_r.elem(0).elem(elem_ji6));
1449 lctmp6 += tri_off_r.elem(0).elem(elem_ij6);
1450 lctmp6 += adj(tri_off_r.elem(1).elem(elem_ji6));
1451 lctmp6 += tri_off_r.elem(1).elem(elem_ij6);
1453 B_j.elem().elem(i6,j6) = timesMinusI(lctmp6);
1462 jit_ins_label( case_9 );
1469 RComplexREG<WordREG<REALT> > lctmp9;
1470 RScalarREG<WordREG<REALT> > lrtmp9;
1472 for(
int i9 = 0; i9 < Nc; ++i9) {
1474 int elem_ij9 = (i9+Nc)*(i9+Nc-1)/2;
1476 for(
int j9 = 0; j9 < Nc; ++j9) {
1478 int elem_ji9 = (j9+Nc)*(j9+Nc-1)/2 + i9;
1480 lctmp9 = adj(tri_off_r.elem(0).elem(elem_ji9));
1481 lctmp9 += tri_off_r.elem(0).elem(elem_ij9);
1482 lctmp9 -= adj(tri_off_r.elem(1).elem(elem_ji9));
1483 lctmp9 -= tri_off_r.elem(1).elem(elem_ij9);
1485 B_j.elem().elem(i9,j9) = timesI(lctmp9);
1494 jit_ins_label( case_10 );
1501 RComplexREG<WordREG<REALT> > lctmp10;
1502 RScalarREG<WordREG<REALT> > lrtmp10;
1504 for(
int i10 = 0; i10 < Nc; ++i10) {
1506 int elem_ij10 = (i10+Nc)*(i10+Nc-1)/2;
1508 for(
int j10 = 0; j10 < Nc; ++j10) {
1510 int elem_ji10 = (j10+Nc)*(j10+Nc-1)/2 + i10;
1512 lctmp10 = adj(tri_off_r.elem(0).elem(elem_ji10));
1513 lctmp10 -= tri_off_r.elem(0).elem(elem_ij10);
1514 lctmp10 -= adj(tri_off_r.elem(1).elem(elem_ji10));
1515 lctmp10 += tri_off_r.elem(1).elem(elem_ij10);
1517 B_j.elem().elem(i10,j10) = lctmp10;
1526 jit_ins_label( case_12 );
1534 RComplexREG<WordREG<REALT> > lctmp12;
1535 RScalarREG<WordREG<REALT> > lr_zero12;
1536 RScalarREG<WordREG<REALT> > lrtmp12;
1540 for(
int i12 = 0; i12 < Nc; ++i12) {
1542 lrtmp12 = tri_dia_r.elem(0).elem(i12);
1543 lrtmp12 -= tri_dia_r.elem(0).elem(i12+Nc);
1544 lrtmp12 -= tri_dia_r.elem(1).elem(i12);
1545 lrtmp12 += tri_dia_r.elem(1).elem(i12+Nc);
1546 B_j.elem().elem(i12,i12) = cmplx(lr_zero12,lrtmp12);
1551 for(
int i12 = 1; i12 < Nc; ++i12) {
1553 int elem_ijb12 = (i12+Nc)*(i12+Nc-1)/2 + Nc;
1555 for(
int j12 = 0; j12 < i12; ++j12) {
1557 lctmp12 = tri_off_r.elem(0).elem(elem_ij12);
1558 lctmp12 -= tri_off_r.elem(0).elem(elem_ijb12);
1559 lctmp12 -= tri_off_r.elem(1).elem(elem_ij12);
1560 lctmp12 += tri_off_r.elem(1).elem(elem_ijb12);
1562 B_j.elem().elem(i12,j12) = timesI(lctmp12);
1563 B_j.elem().elem(j12,i12) = timesI(adj(lctmp12));
1572 return jit_get_cufunction(
"ptx_triacntr.ptx");
1578 template<
typename T,
typename U>
1585 if ( mat < 0 || mat > 15 )
1587 QDPIO::cerr << __func__ <<
": Gamma out of range: mat = " << mat << std::endl;
1593 static CUfunction
function;
1595 if (
function == NULL)
1596 function = function_triacntr_build<U>( B, tri_dia, tri_off, mat, rb[
cb] );
1605 template<
typename T,
typename U>
1611 if( param.anisoParam.anisoP ) {
1612 if (
mu==param.anisoParam.t_dir ||
nu == param.anisoParam.t_dir) {
1613 return param.clovCoeffT;
1617 return param.clovCoeffR;
1623 return param.clovCoeffR;
1631 template<
typename T,
typename X,
typename Y>
1639 if (!
s.hasOrderedRep())
1640 QDP_error_exit(
"clover on subset with unordered representation not implemented");
1644 AddressLeaf addr_leaf;
1646 int junk_0 = forEach(
chi, addr_leaf, NullCombine());
1647 int junk_1 = forEach(
psi, addr_leaf, NullCombine());
1648 int junk_2 = forEach(tri_dia, addr_leaf, NullCombine());
1649 int junk_3 = forEach(tri_off, addr_leaf, NullCombine());
1655 std::vector<void*> addr;
1657 addr.push_back( &lo );
1660 addr.push_back( &hi );
1663 int addr_dest=addr.size();
1664 for(
int i=0;
i < addr_leaf.addr.size(); ++
i) {
1665 addr.push_back( &addr_leaf.addr[
i] );
1669 jit_launch(
function,
s.numSiteTable(),addr);
1675 template<
typename T,
typename X,
typename Y>
1687 jit_start_new_function();
1689 jit_value r_lo = jit_add_param( jit_ptx_type::s32 );
1690 jit_value r_hi = jit_add_param( jit_ptx_type::s32 );
1692 jit_value r_idx_thread = jit_geom_get_linear_th_idx();
1694 jit_value r_out_of_range = jit_ins_gt( r_idx_thread , jit_ins_sub( r_hi , r_lo ) );
1695 jit_ins_exit( r_out_of_range );
1697 jit_value r_idx = jit_ins_add( r_lo , r_idx_thread );
1699 ParamLeaf param_leaf( r_idx );
1702 typedef typename LeafFunctor<T, ParamLeaf>::Type_t TJIT;
1703 TJIT chi_jit(forEach(
chi, param_leaf, TreeCombine()));
1704 TJIT psi_jit(forEach(
psi, param_leaf, TreeCombine()));
1705 auto& chi_j = chi_jit.elem(JitDeviceLayout::Coalesced);
1707 typename REGType< typename TJIT::Subtype_t >::Type_t psi_r;
1708 psi_r.setup( psi_jit.elem(JitDeviceLayout::Coalesced) );
1710 typename REGType< typename TJIT::Subtype_t >::Type_t chi_r;
1713 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
1714 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
1716 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
1717 tri_dia_r.setup( tri_dia_jit.elem(JitDeviceLayout::Coalesced) );
1722 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
1723 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
1724 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced);
1726 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
1727 tri_off_r.setup( tri_off_jit.elem(JitDeviceLayout::Coalesced) );
1737 for(
int i = 0;
i <
n; ++
i)
1739 chi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3) = tri_dia_r.elem(0).elem(
i) * psi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3);
1742 chi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3) = tri_dia_r.elem(1).elem(
i) * psi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3);
1747 for(
int i = 0;
i <
n; ++
i)
1749 for(
int j = 0;
j <
i;
j++)
1751 chi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3) += tri_off_r.elem(0).elem(kij) * psi_r.elem((0*
n+
j)/3).elem((0*
n+
j)%3);
1754 chi_r.elem((0*
n+
j)/3).elem((0*
n+
j)%3) += conj(tri_off_r.elem(0).elem(kij)) * psi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3);
1757 chi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3) += tri_off_r.elem(1).elem(kij) * psi_r.elem((1*
n+
j)/3).elem((1*
n+
j)%3);
1760 chi_r.elem((1*
n+
j)/3).elem((1*
n+
j)%3) += conj(tri_off_r.elem(1).elem(kij)) * psi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3);
1769 return jit_get_cufunction(
"ptx_apply_clov.ptx");
1795 template<
typename T,
typename U>
1802 QDPIO::cerr << __func__ <<
": CloverTerm::apply requires Ns==4" << std::endl;
1808 static CUfunction
function;
1810 if (
function == NULL)
1816 (*this).getFermBC().modifyF(
chi, QDP::rb[
cb]);
1825 namespace QDPCloverEnv {
1826 template<
typename R,
typename TD,
typename TO>
1827 struct QUDAPackArgs {
1834 template<
typename R,
typename TD,
typename TO>
1839 multi1d<QUDAPackedClovSite<R> >& quda_array =
a->quda_array;
1841 const TD& tri_dia =
a->tri_dia;
1842 const TO& tri_off =
a->tri_off;
1844 const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14};
1846 for(
int ssite=lo; ssite < hi; ++ssite) {
1847 int site = rb[
cb].siteTable()[ssite];
1849 for(
int i=0;
i < 6;
i++) {
1850 quda_array[site].diag1[
i] = tri_dia.elem(site).comp[0].diag[
i].elem().elem();
1855 for(
int col=0; col < Nc*Ns2-1; col++) {
1856 for(
int row=col+1; row < Nc*Ns2; row++) {
1858 int source_index = row*(row-1)/2 + col;
1860 quda_array[site].offDiag1[target_index][0] = tri_off.elem(site).comp[0].offd[source_index].real().elem();
1861 quda_array[site].offDiag1[target_index][1] = tri_off.elem(site).comp[0].offd[source_index].imag().elem();
1866 for(
int i=0;
i < 6;
i++) {
1867 quda_array[site].diag2[
i] = tri_dia.elem(site).comp[1].diag[
i].elem().elem();
1871 for(
int col=0; col < Nc*Ns2-1; col++) {
1872 for(
int row=col+1; row < Nc*Ns2; row++) {
1874 int source_index = row*(row-1)/2 + col;
1876 quda_array[site].offDiag2[target_index][0] = tri_off.elem(site).comp[1].offd[source_index].real().elem();
1877 quda_array[site].offDiag2[target_index][1] = tri_off.elem(site).comp[1].offd[source_index].imag().elem();
1882 QDPIO::cout <<
"\n";
1886 template<
typename T,
typename U>
1889 typedef typename WordType<T>::Type_t
REALT;
1890 int num_sites = rb[
cb].siteTable().size();
1892 typedef OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >
TD;
1893 typedef OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > > TO;
1899 dispatch_to_threads(num_sites, args, QDPCloverEnv::qudaPackSiteLoop<REALT,TD,TO>);
1907 template<
typename T,
typename U>
1911 QDP_error_exit(
"PTXCloverTermT<T,U>::applySite(T& chi, const T& psi,..) not implemented ");
Base class for all fermion action boundary conditions.
Support class for fermion actions and linear operators.
Class for counted reference semantics.
void choles(int cb)
Computes the inverse of the term on cb using Cholesky.
OLattice< PScalar< PScalar< RScalar< Word< REALT > > > > > LatticeREAL
const multi1d< U > & getU() const
Get the u field.
multi1d< bool > choles_done
CloverFermActParams param
void applySite(T &chi, const T &psi, enum PlusMinus isign, int site) const
void makeClov(const multi1d< U > &f, const RealT &diag_mass)
Create the clover term on cb.
Real getCloverCoeff(int mu, int nu) const
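Returns the clover coefficient for the (mu,nu) plane: param.clovCoeffT when anisotropy is enabled and mu or nu is the time direction, otherwise param.clovCoeffR.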
Handle< FermBC< T, multi1d< U >, multi1d< U > > > fbc
OLattice< PComp< PTriDia< RScalar< Word< REALT > > > > > tri_dia
const FermBC< T, multi1d< U >, multi1d< U > > & getFermBC() const
Return the fermion BC object for this linear operator.
Double cholesDet(int cb) const
Returns the sum of tr_log_diag_ over subset rb[cb]; reports an error if the Cholesky (choles) has not been done on cb first.
WordType< T >::Type_t REALT
void ldagdlinv(LatticeREAL &tr_log_diag, int cb)
Invert the clover term on cb.
OScalar< PScalar< PScalar< RScalar< Word< REALT > > > > > RealT
void packForQUDA(multi1d< QUDAPackedClovSite< REALT > > &quda_pack, int cb) const
Pack up the clover term for the QUDA library.
~PTXCloverTermT()
No real need for cleanup here.
OLattice< PComp< PTriOff< RComplex< Word< REALT > > > > > tri_off
void create(Handle< FermState< T, multi1d< U >, multi1d< U > > > fs, const CloverFermActParams &param_)
Creation routine.
void apply(T &chi, const T &psi, enum PlusMinus isign, int cb) const
void triacntr(U &B, int mat, int cb) const
Calculates Tr_D ( Gamma_mat L )
PTXCloverTermT()
Empty constructor. Must use create later.
const double & get() const
static PackForQUDATimer & Instance()
Parameters for Clover fermion action.
Clover term linear operator.
void block(LatticeColorMatrix &u_block, const multi1d< LatticeColorMatrix > &u, int mu, int bl_level, const Real &BlkAccu, int BlkMax, int j_decay)
Construct block links.
void function_triacntr_exec(const JitFunction &function, U &B, const X &tri_dia, const Y &tri_off, int mat, const Subset &s)
TRIACNTR.
Calculates the antihermitian field strength tensor iF(mu,nu)
void qudaPackSiteLoop(int lo, int hi, int myId, QUDAPackArgs< R, TD, TO > *a)
Asqtad Staggered-Dirac operator.
QDP_error_exit("too many BiCG iterations", n_count, rsd_sq, cp, c, re_rvr, im_rvr, re_a, im_a, re_b, im_b)
void function_make_clov_exec(const JitFunction &function, const RealT &diag_mass, const U &f0, const U &f1, const U &f2, const U &f3, const U &f4, const U &f5, X &tri_dia, Y &tri_off)
static multi1d< LatticeColorMatrix > u
void function_ldagdlinv_exec(const JitFunction &function, T &tr_log_diag, X &tri_dia, Y &tri_off, const Subset &s)
void function_apply_clov_exec(const JitFunction &function, T &chi, const T &psi, const X &tri_dia, const Y &tri_off, const Subset &s)
LinOpSysSolverMGProtoClover::T T
PTXCloverTermT< LatticeFermion, LatticeColorMatrix > PTXCloverTerm
void function_triacntr_build(JitFunction &func, const U &B, const X &tri_dia, const Y &tri_off, int mat, const Subset &s)
void function_ldagdlinv_build(JitFunction &func, const T &tr_log_diag, const X &tri_dia, const Y &tri_off, const Subset &s)
multi1d< LatticeFermion > chi(Ncb)
PTXCloverTermT< LatticeFermionD, LatticeColorMatrixD > PTXCloverTermD
void mesField(multi1d< LatticeColorMatrixF > &f, const multi1d< LatticeColorMatrixF > &u)
Calculates the antihermitian field strength tensor iF(mu,nu)
void function_make_clov_build(JitFunction &func, const RealT &diag_mass, const U &f0, const U &f1, const U &f2, const U &f3, const U &f4, const U &f5, const X &tri_dia, const Y &tri_off)
PTXCloverTermT< LatticeFermionF, LatticeColorMatrixF > PTXCloverTermF
void function_apply_clov_build(JitFunction &func, const T &chi, const T &psi, const X &tri_dia, const Y &tri_off, const Subset &s)
multi1d< LatticeFermion > s(Ncb)
const T1 const T2 const T3 & f3
const T1 const T2 const T3 const T4 const T5 & f5
const T1 const T2 const T3 const T4 & f4
FloatingPoint< double > Double
Support class for fermion actions and linear operators.
Params for clover ferm acts.
multi1d< QUDAPackedClovSite< R > > & quda_array
PCompJIT< typename JITType< T >::Type_t > Type_t
PCompJIT< typename JITType< T >::Type_t > Type_t
PTriDiaJIT< typename JITType< T >::Type_t > Type_t
PTriDiaJIT< typename JITType< T >::Type_t > Type_t
PTriOffJIT< typename JITType< T >::Type_t > Type_t
PTriOffJIT< typename JITType< T >::Type_t > Type_t
PCompJIT & operator=(const PCompREG< T1 > &rhs)
const T & elem(int i) const
void setup(const PCompJIT< typename JITType< T >::Type_t > &rhs)
const T & elem(int i) const
const T & elem(int i) const
PTriDiaJIT & operator=(const PTriDiaREG< T1 > &rhs)
const T & elem(int i) const
void setup(const PTriDiaJIT< typename JITType< T >::Type_t > &rhs)
PTriOffJIT & operator=(const PTriOffREG< T1 > &rhs)
const T & elem(int i) const
const T & elem(int i) const
void setup(const PTriOffJIT< typename JITType< T >::Type_t > &rhs)
PCompREG< typename REGType< T >::Type_t > Type_t
PTriDiaREG< typename REGType< T >::Type_t > Type_t
PTriOffREG< typename REGType< T >::Type_t > Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
multi1d< LatticeColorMatrix > U