6 #ifndef __clover_term_llvm_w_h__
7 #define __clover_term_llvm_w_h__
9 #warning "Using QPD-JIT/LLVM clover"
//! Mutable access to element i; forwards to this->arrayF(i).
55 inline T&
elem(
int i) {
return this->arrayF(
i); }
//! Read-only access to element i; forwards to this->arrayF(i).
56 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
64 F[0].setup( rhs.elem(0) );
65 F[1].setup( rhs.elem(1) );
//! Read-only access to component i of the member array F (returns F[i] directly, no range check).
68 inline const T&
elem(
int i)
const {
return F[
i]; }
92 typedef typename WordType<T>::Type_t
Type_t;
98 typedef typename WordType<T>::Type_t
Type_t;
122 for (
int i = 0 ;
i < 2 * Nc ;
i++ )
//! Mutable access to entry i; forwards to this->arrayF(i).
127 inline T&
elem(
int i) {
return this->arrayF(
i); }
//! Read-only access to entry i; forwards to this->arrayF(i).
128 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
136 for (
int i=0;
i<2*Nc;++
i)
164 typedef typename WordType<T>::Type_t
Type_t;
170 typedef typename WordType<T>::Type_t
Type_t;
194 for (
int i = 0 ;
i < 2*Nc*Nc-Nc ;
i++ )
//! Mutable access to packed entry i; forwards to this->arrayF(i).
199 inline T&
elem(
int i) {
return this->arrayF(
i); }
//! Read-only access to packed entry i; forwards to this->arrayF(i).
200 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
208 for (
int i=0;
i<2*Nc*Nc-Nc;++
i)
236 typedef typename WordType<T>::Type_t
Type_t;
242 typedef typename WordType<T>::Type_t
Type_t;
246 #if defined(QDP_USE_PROFILING)
248 struct LeafFunctor<
PComp<
T>, PrintTag>
251 static int apply(
const PComp<T> &
s,
const PrintTag &f)
258 struct LeafFunctor<
PTriDia<
T>, PrintTag>
261 static int apply(
const PTriDia<T> &
s,
const PrintTag &f)
268 struct LeafFunctor<
PTriOff<
T>, PrintTag>
271 static int apply(
const PTriOff<T> &
s,
const PrintTag &f)
296 template<
typename T,
typename U>
301 typedef typename WordType<T>::Type_t
REALT;
303 typedef OLattice< PScalar< PScalar< RScalar< Word< REALT> > > > >
LatticeREAL;
304 typedef OScalar< PScalar< PScalar< RScalar< Word< REALT> > > > >
RealT;
367 using DiagType = OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >;
368 using OffDiagType = OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > >;
386 const multi1d<U>&
getU()
const {
return u;}
401 OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >
tri_dia;
402 OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > >
tri_off;
408 template<
typename T,
typename U>
412 template<
typename T,
typename U>
424 fbc = fs->getFermBC();
428 if (fbc.operator->() == 0) {
429 QDPIO::cerr <<
"LLVMCloverTerm: error: fbc is null" << std::endl;
434 RealT ff = where(param.anisoParam.anisoP, Real(1) / param.anisoParam.xi_0, Real(1));
435 param.clovCoeffR *= Real(0.5) * ff;
436 param.clovCoeffT *= Real(0.5);
446 RealT ff = where(param.anisoParam.anisoP, param.anisoParam.nu / param.anisoParam.xi_0, Real(1));
447 diag_mass = 1 + (
Nd-1)*ff + param.Mass;
456 choles_done.resize(rb.numSubsets());
457 for(
int i=0;
i < rb.numSubsets();
i++) {
471 template<
typename T,
typename U>
482 fbc = fs->getFermBC();
486 if (fbc.operator->() == 0) {
487 QDPIO::cerr <<
"LLVMCloverTerm: error: fbc is null" << std::endl;
492 RealT ff = where(param.anisoParam.anisoP, Real(1) / param.anisoParam.xi_0, Real(1));
493 param.clovCoeffR *=
RealT(0.5) * ff;
494 param.clovCoeffT *=
RealT(0.5);
504 RealT ff = where(param.anisoParam.anisoP, param.anisoParam.nu / param.anisoParam.xi_0, Real(1));
505 diag_mass = 1 + (
Nd-1)*ff + param.Mass;
511 makeClov(f, diag_mass);
513 choles_done.resize(rb.numSubsets());
514 for(
int i=0;
i < rb.numSubsets();
i++) {
515 choles_done[
i] =
false;
604 template<
typename RealT,
typename U,
typename X,
typename Y>
606 const RealT& diag_mass,
616 AddressLeaf addr_leaf(all);
618 forEach(diag_mass, addr_leaf, NullCombine());
619 forEach(f0, addr_leaf, NullCombine());
620 forEach(
f1, addr_leaf, NullCombine());
621 forEach(
f2, addr_leaf, NullCombine());
622 forEach(
f3, addr_leaf, NullCombine());
623 forEach(
f4, addr_leaf, NullCombine());
624 forEach(
f5, addr_leaf, NullCombine());
625 forEach(tri_dia, addr_leaf, NullCombine());
626 forEach(tri_off, addr_leaf, NullCombine());
628 jit_dispatch(
function.
func().at(0),Layout::sitesOnNode(),getDataLayoutInnerSize(),
true,0,addr_leaf);
633 template<
typename RealT,
typename U,
typename X,
typename Y>
635 const RealT& diag_mass,
649 ParamLeaf param_leaf;
651 typedef typename WordType<RealT>::Type_t REALT;
653 typedef typename LeafFunctor<RealT, ParamLeaf>::Type_t RealTJIT;
654 RealTJIT diag_mass_jit(forEach(diag_mass, param_leaf, TreeCombine()));
656 typedef typename LeafFunctor<U, ParamLeaf>::Type_t UJIT;
657 UJIT f0_jit(forEach(f0, param_leaf, TreeCombine()));
658 UJIT f1_jit(forEach(
f1, param_leaf, TreeCombine()));
659 UJIT f2_jit(forEach(
f2, param_leaf, TreeCombine()));
660 UJIT f3_jit(forEach(
f3, param_leaf, TreeCombine()));
661 UJIT f4_jit(forEach(
f4, param_leaf, TreeCombine()));
662 UJIT f5_jit(forEach(
f5, param_leaf, TreeCombine()));
664 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
665 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
667 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
668 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
670 IndexDomainVector idx = loop.getIdx();
672 typename UJIT::Subtype_t& f0_j = f0_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
673 typename UJIT::Subtype_t& f1_j = f1_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
674 typename UJIT::Subtype_t& f2_j = f2_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
675 typename UJIT::Subtype_t& f3_j = f3_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
676 typename UJIT::Subtype_t& f4_j = f4_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
677 typename UJIT::Subtype_t& f5_j = f5_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
679 typename XJIT::Subtype_t& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
680 typename YJIT::Subtype_t& tri_off_j = tri_off_jit.elem(JitDeviceLayout::LayoutCoalesced , idx );
682 typename REGType< typename RealTJIT::Subtype_t >::Type_t diag_mass_reg;
683 diag_mass_reg.setup( diag_mass_jit.elem() );
688 for(
int jj = 0; jj < 2; jj++) {
689 for(
int ii = 0; ii < 2*Nc; ii++) {
690 tri_dia_j.elem(jj).elem(ii) = diag_mass_reg.elem().elem();
696 RComplexREG<WordREG<REALT> > E_minus;
697 RComplexREG<WordREG<REALT> > B_minus;
698 RComplexREG<WordREG<REALT> > ctmp_0;
699 RComplexREG<WordREG<REALT> > ctmp_1;
700 RScalarREG<WordREG<REALT> > rtmp_0;
701 RScalarREG<WordREG<REALT> > rtmp_1;
704 for(
int i = 0;
i < Nc; ++
i) {
705 ctmp_0 = f5_j.elem().elem(
i,
i);
706 ctmp_0 -= f0_j.elem().elem(
i,
i);
707 rtmp_0 = imag(ctmp_0);
708 tri_dia_j.elem(0).elem(
i) += rtmp_0;
710 tri_dia_j.elem(0).elem(
i+Nc) -= rtmp_0;
712 ctmp_1 = f5_j.elem().elem(
i,
i);
713 ctmp_1 += f0_j.elem().elem(
i,
i);
714 rtmp_1 = imag(ctmp_1);
715 tri_dia_j.elem(1).elem(
i) -= rtmp_1;
717 tri_dia_j.elem(1).elem(
i+Nc) += rtmp_1;
720 for(
int i = 1;
i < Nc; ++
i) {
721 for(
int j = 0;
j <
i; ++
j) {
724 int elem_tmp = (
i+Nc)*(
i+Nc-1)/2 +
j+Nc;
726 ctmp_0 = f0_j.elem().elem(
i,
j);
727 ctmp_0 -= f5_j.elem().elem(
i,
j);
728 tri_off_j.elem(0).elem(
elem_ij) = timesI(ctmp_0);
730 zero_rep( tri_off_j.elem(0).elem(elem_tmp) );
731 tri_off_j.elem(0).elem(elem_tmp) -= tri_off_j.elem(0).elem(
elem_ij);
733 ctmp_1 = f5_j.elem().elem(
i,
j);
734 ctmp_1 += f0_j.elem().elem(
i,
j);
735 tri_off_j.elem(1).elem(
elem_ij) = timesI(ctmp_1);
737 zero_rep( tri_off_j.elem(1).elem(elem_tmp) );
738 tri_off_j.elem(1).elem(elem_tmp) -= tri_off_j.elem(1).elem(
elem_ij);
742 for(
int i = 0;
i < Nc; ++
i) {
743 for(
int j = 0;
j < Nc; ++
j) {
748 E_minus = f2_j.elem().elem(
i,
j);
749 E_minus = timesI( E_minus );
751 E_minus += f4_j.elem().elem(
i,
j);
754 B_minus = f3_j.elem().elem(
i,
j);
755 B_minus = timesI( B_minus );
757 B_minus -= f1_j.elem().elem(
i,
j);
759 tri_off_j.elem(0).elem(
elem_ij) = B_minus - E_minus;
761 tri_off_j.elem(1).elem(
elem_ij) = E_minus + B_minus;
769 func.func().push_back( jit_function_epilogue_get(
"jit_make_clov.ptx") );
776 template<
typename T,
typename U>
782 QDPIO::cerr << __func__ <<
": expecting Nd==4" << std::endl;
787 QDPIO::cerr << __func__ <<
": expecting Ns==4" << std::endl;
791 U f0 = f[0] * getCloverCoeff(0,1);
792 U f1 = f[1] * getCloverCoeff(0,2);
793 U f2 = f[2] * getCloverCoeff(0,3);
794 U f3 = f[3] * getCloverCoeff(1,2);
795 U f4 = f[4] * getCloverCoeff(1,3);
796 U f5 = f[5] * getCloverCoeff(2,3);
800 static JitFunction
function;
802 if (!
function.built()) {
803 QDPIO::cout <<
"Building JIT make clover function\n";
818 template<
typename T,
typename U>
826 ldagdlinv(tr_log_diag_,
cb);
838 template<
typename T,
typename U>
843 if( choles_done[
cb] ==
false )
845 QDPIO::cout << __func__ <<
": Error: you have not done the Cholesky.on this operator on this subset" << std::endl;
846 QDPIO::cout <<
"You sure you should not be asking invclov?" << std::endl;
855 return sum(tr_log_diag_, rb[
cb]);
860 template<
typename T,
typename X,
typename Y>
867 if (!
s.hasOrderedRep())
868 QDP_error_exit(
"ldagdlinv on subset with unordered representation not implemented");
870 AddressLeaf addr_leaf(
s);
872 forEach(tr_log_diag, addr_leaf, NullCombine());
873 forEach(tri_dia, addr_leaf, NullCombine());
874 forEach(tri_off, addr_leaf, NullCombine());
876 jit_dispatch(
function.
func().at(0),
s.numSiteTable(),getDataLayoutInnerSize(),
s.hasOrderedRep(),
s.start(),addr_leaf);
883 template<
typename U,
typename T,
typename X,
typename Y>
885 const T& tr_log_diag,
890 typedef typename WordType<U>::Type_t REALT;
896 ParamLeaf param_leaf;
898 typedef typename LeafFunctor<T, ParamLeaf>::Type_t TJIT;
899 TJIT tr_log_diag_jit(forEach(tr_log_diag, param_leaf, TreeCombine()));
901 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
902 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
904 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
905 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
907 IndexDomainVector idx = loop.getIdx();
909 typename TJIT::Subtype_t& tr_log_diag_j = tr_log_diag_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
910 typename XJIT::Subtype_t& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
911 typename YJIT::Subtype_t& tri_off_j = tri_off_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
913 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
914 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
916 tri_dia_r.setup( tri_dia_j );
917 tri_off_r.setup( tri_off_j );
920 RScalarREG<WordREG<REALT> > zip;
928 RScalarREG<WordREG<REALT> > inv_d[6] ;
929 RComplexREG<WordREG<REALT> > inv_offd[15] ;
930 RComplexREG<WordREG<REALT> > v[6] ;
931 RScalarREG<WordREG<REALT> > diag_g[6] ;
933 for(
int i=0;
i < N;
i++) {
934 inv_d[
i] = tri_dia_r.elem(
block).elem(
i);
937 for(
int i=0;
i < 15;
i++) {
938 inv_offd[
i] = tri_off_r.elem(
block).elem(
i);
942 for(
int j=0;
j < N; ++
j) {
944 for(
int i=0;
i <
j;
i++) {
947 RComplexREG<WordREG<REALT> > A_ii = cmplx( inv_d[
i], zip );
952 v[
j] = cmplx(inv_d[
j],zip);
954 for(
int k=0;
k <
j;
k++) {
955 int elem_jk =
j*(
j-1)/2 +
k;
956 v[
j] -= inv_offd[elem_jk]*v[
k];
959 inv_d[
j] = real( v[
j] );
961 for(
int k=
j+1;
k < N;
k++) {
962 int elem_kj =
k*(
k-1)/2 +
j;
963 for(
int l=0;
l <
j;
l++) {
964 int elem_kl =
k*(
k-1)/2 +
l;
965 inv_offd[elem_kj] -= inv_offd[elem_kl] * v[
l];
967 inv_offd[elem_kj] /= v[
j];
973 RScalarREG<WordREG<REALT> >
one(1.0);
976 for(
int i=0;
i < N;
i++) {
977 diag_g[
i] =
one/inv_d[
i];
983 tr_log_diag_j.elem().elem() += log(fabs(inv_d[
i]));
987 if( inv_d[
i].elem() < 0 ) {
1007 RComplexREG<WordREG<REALT> >
sum;
1008 for(
int k = 0;
k < N; ++
k) {
1010 for(
int i = 0;
i <
k; ++
i) {
1017 v[
k] = cmplx(diag_g[
k],zip);
1019 for(
int i =
k+1;
i < N; ++
i) {
1022 for(
int j =
k;
j <
i; ++
j) {
1038 for(
int i = N-2; (int)
i >= (
int)
k; --
i) {
1039 for(
int j =
i+1;
j < N; ++
j) {
1047 inv_d[
k] = real(v[
k]);
1048 for(
int i =
k+1;
i < N; ++
i) {
1050 int elem_ik =
i*(
i-1)/2+
k;
1051 inv_offd[elem_ik] = v[
i];
1056 for(
int i=0;
i < N;
i++) {
1057 tri_dia_j.elem(
block).elem(
i) = inv_d[
i];
1059 for(
int i=0;
i < 15;
i++) {
1060 tri_off_j.elem(
block).elem(
i) = inv_offd[
i];
1067 func.func().push_back( jit_function_epilogue_get(
"jit_ldagdlinv.ptx") );
1075 template<
typename T,
typename U>
1082 QDPIO::cerr << __func__ <<
": Matrix is too small" << std::endl;
1087 tr_log_diag[rb[
cb]] =
zero;
1091 static JitFunction
function;
1093 if (!
function.built()) {
1094 QDPIO::cout <<
"Building JIT ldagdlinv\n";
1095 function_ldagdlinv_build<U>(
function, tr_log_diag, tri_dia, tri_off, rb[
cb] );
1102 choles_done[
cb] =
true;
1156 template<
typename U,
typename X,
typename Y>
1164 if (!
s.hasOrderedRep())
1165 QDP_error_exit(
"triacntr on subset with unordered representation not implemented");
1167 AddressLeaf addr_leaf(
s);
1169 addr_leaf.setLit( mat );
1171 forEach(B, addr_leaf, NullCombine());
1172 forEach(tri_dia, addr_leaf, NullCombine());
1173 forEach(tri_off, addr_leaf, NullCombine());
1175 jit_dispatch(
function.
func().at(0),
s.numSiteTable(),getDataLayoutInnerSize(),
s.hasOrderedRep(),
s.start(),addr_leaf);
1181 template<
typename U,
typename X,
typename Y>
1191 typedef typename WordType<U>::Type_t REALT;
1195 ParamRef p_mat = llvm_add_param<int>();
1197 ParamLeaf param_leaf;
1199 typedef typename LeafFunctor<U, ParamLeaf>::Type_t UJIT;
1200 UJIT B_jit(forEach(B, param_leaf, TreeCombine()));
1202 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
1203 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
1205 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
1206 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
1208 llvm::Value * r_mat = llvm_derefParam( p_mat );
1210 IndexDomainVector idx = loop.getIdx();
1212 typename UJIT::Subtype_t& B_j = B_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
1213 typename XJIT::Subtype_t& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
1214 typename YJIT::Subtype_t& tri_off_j = tri_off_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
1216 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
1217 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
1219 tri_dia_r.setup( tri_dia_j );
1220 tri_off_r.setup( tri_off_j );
1222 llvm::BasicBlock * case_0 = llvm_new_basic_block();
1223 llvm::BasicBlock * case_3 = llvm_new_basic_block();
1224 llvm::BasicBlock * case_5 = llvm_new_basic_block();
1225 llvm::BasicBlock * case_6 = llvm_new_basic_block();
1226 llvm::BasicBlock * case_9 = llvm_new_basic_block();
1227 llvm::BasicBlock * case_10 = llvm_new_basic_block();
1228 llvm::BasicBlock * case_12 = llvm_new_basic_block();
1229 llvm::BasicBlock * case_default = llvm_new_basic_block();
1231 llvm::SwitchInst * mat_sw = llvm_switch( r_mat , case_default );
1233 mat_sw->addCase( llvm_create_const_int(0) , case_0 );
1234 mat_sw->addCase( llvm_create_const_int(3) , case_3 );
1235 mat_sw->addCase( llvm_create_const_int(5) , case_5 );
1236 mat_sw->addCase( llvm_create_const_int(6) , case_6 );
1237 mat_sw->addCase( llvm_create_const_int(9) , case_9 );
1238 mat_sw->addCase( llvm_create_const_int(10) , case_10 );
1239 mat_sw->addCase( llvm_create_const_int(12) , case_12 );
1246 llvm_set_insert_point( case_0 );
1248 RComplexREG<WordREG<REALT> > lctmp0;
1249 RScalarREG< WordREG<REALT> > lr_zero0;
1250 RScalarREG< WordREG<REALT> > lrtmp0;
1254 for(
int i0 = 0; i0 < Nc; ++i0) {
1256 lrtmp0 = tri_dia_r.elem(0).elem(i0);
1257 lrtmp0 += tri_dia_r.elem(0).elem(i0+Nc);
1258 lrtmp0 += tri_dia_r.elem(1).elem(i0);
1259 lrtmp0 += tri_dia_r.elem(1).elem(i0+Nc);
1260 B_j.elem().elem(i0,i0) = cmplx(lrtmp0,lr_zero0);
1265 for(
int i0 = 1; i0 < Nc; ++i0) {
1267 int elem_ijb0 = (i0+Nc)*(i0+Nc-1)/2 + Nc;
1269 for(
int j0 = 0; j0 < i0; ++j0) {
1271 lctmp0 = tri_off_r.elem(0).elem(elem_ij0);
1272 lctmp0 += tri_off_r.elem(0).elem(elem_ijb0);
1273 lctmp0 += tri_off_r.elem(1).elem(elem_ij0);
1274 lctmp0 += tri_off_r.elem(1).elem(elem_ijb0);
1276 B_j.elem().elem(j0,i0) = lctmp0;
1277 B_j.elem().elem(i0,j0) = adj(lctmp0);
1283 llvm_branch(case_default);
1287 llvm_set_insert_point( case_3 );
1295 RComplexREG<WordREG<REALT> > lctmp3;
1296 RScalarREG<WordREG<REALT> > lr_zero3;
1297 RScalarREG<WordREG<REALT> > lrtmp3;
1301 for(
int i3 = 0; i3 < Nc; ++i3) {
1303 lrtmp3 = tri_dia_r.elem(0).elem(i3+Nc);
1304 lrtmp3 -= tri_dia_r.elem(0).elem(i3);
1305 lrtmp3 -= tri_dia_r.elem(1).elem(i3);
1306 lrtmp3 += tri_dia_r.elem(1).elem(i3+Nc);
1307 B_j.elem().elem(i3,i3) = cmplx(lr_zero3,lrtmp3);
1312 for(
int i3 = 1; i3 < Nc; ++i3) {
1314 int elem_ijb3 = (i3+Nc)*(i3+Nc-1)/2 + Nc;
1316 for(
int j3 = 0; j3 < i3; ++j3) {
1318 lctmp3 = tri_off_r.elem(0).elem(elem_ijb3);
1319 lctmp3 -= tri_off_r.elem(0).elem(elem_ij3);
1320 lctmp3 -= tri_off_r.elem(1).elem(elem_ij3);
1321 lctmp3 += tri_off_r.elem(1).elem(elem_ijb3);
1323 B_j.elem().elem(j3,i3) = timesI(adj(lctmp3));
1324 B_j.elem().elem(i3,j3) = timesI(lctmp3);
1330 llvm_branch(case_default);
1334 llvm_set_insert_point( case_5 );
1341 RComplexREG<WordREG<REALT> > lctmp5;
1342 RScalarREG<WordREG<REALT> > lrtmp5;
1344 for(
int i5 = 0; i5 < Nc; ++i5) {
1346 int elem_ij5 = (i5+Nc)*(i5+Nc-1)/2;
1348 for(
int j5 = 0; j5 < Nc; ++j5) {
1350 int elem_ji5 = (j5+Nc)*(j5+Nc-1)/2 + i5;
1353 lctmp5 = adj(tri_off_r.elem(0).elem(elem_ji5));
1354 lctmp5 -= tri_off_r.elem(0).elem(elem_ij5);
1355 lctmp5 += adj(tri_off_r.elem(1).elem(elem_ji5));
1356 lctmp5 -= tri_off_r.elem(1).elem(elem_ij5);
1358 B_j.elem().elem(i5,j5) = lctmp5;
1363 llvm_branch(case_default);
1367 llvm_set_insert_point( case_6 );
1374 RComplexREG<WordREG<REALT> > lctmp6;
1375 RScalarREG<WordREG<REALT> > lrtmp6;
1377 for(
int i6 = 0; i6 < Nc; ++i6) {
1379 int elem_ij6 = (i6+Nc)*(i6+Nc-1)/2;
1381 for(
int j6 = 0; j6 < Nc; ++j6) {
1383 int elem_ji6 = (j6+Nc)*(j6+Nc-1)/2 + i6;
1385 lctmp6 = adj(tri_off_r.elem(0).elem(elem_ji6));
1386 lctmp6 += tri_off_r.elem(0).elem(elem_ij6);
1387 lctmp6 += adj(tri_off_r.elem(1).elem(elem_ji6));
1388 lctmp6 += tri_off_r.elem(1).elem(elem_ij6);
1390 B_j.elem().elem(i6,j6) = timesMinusI(lctmp6);
1395 llvm_branch(case_default);
1399 llvm_set_insert_point( case_9 );
1406 RComplexREG<WordREG<REALT> > lctmp9;
1407 RScalarREG<WordREG<REALT> > lrtmp9;
1409 for(
int i9 = 0; i9 < Nc; ++i9) {
1411 int elem_ij9 = (i9+Nc)*(i9+Nc-1)/2;
1413 for(
int j9 = 0; j9 < Nc; ++j9) {
1415 int elem_ji9 = (j9+Nc)*(j9+Nc-1)/2 + i9;
1417 lctmp9 = adj(tri_off_r.elem(0).elem(elem_ji9));
1418 lctmp9 += tri_off_r.elem(0).elem(elem_ij9);
1419 lctmp9 -= adj(tri_off_r.elem(1).elem(elem_ji9));
1420 lctmp9 -= tri_off_r.elem(1).elem(elem_ij9);
1422 B_j.elem().elem(i9,j9) = timesI(lctmp9);
1427 llvm_branch(case_default);
1431 llvm_set_insert_point( case_10 );
1438 RComplexREG<WordREG<REALT> > lctmp10;
1439 RScalarREG<WordREG<REALT> > lrtmp10;
1441 for(
int i10 = 0; i10 < Nc; ++i10) {
1443 int elem_ij10 = (i10+Nc)*(i10+Nc-1)/2;
1445 for(
int j10 = 0; j10 < Nc; ++j10) {
1447 int elem_ji10 = (j10+Nc)*(j10+Nc-1)/2 + i10;
1449 lctmp10 = adj(tri_off_r.elem(0).elem(elem_ji10));
1450 lctmp10 -= tri_off_r.elem(0).elem(elem_ij10);
1451 lctmp10 -= adj(tri_off_r.elem(1).elem(elem_ji10));
1452 lctmp10 += tri_off_r.elem(1).elem(elem_ij10);
1454 B_j.elem().elem(i10,j10) = lctmp10;
1459 llvm_branch(case_default);
1463 llvm_set_insert_point( case_12 );
1471 RComplexREG<WordREG<REALT> > lctmp12;
1472 RScalarREG<WordREG<REALT> > lr_zero12;
1473 RScalarREG<WordREG<REALT> > lrtmp12;
1477 for(
int i12 = 0; i12 < Nc; ++i12) {
1479 lrtmp12 = tri_dia_r.elem(0).elem(i12);
1480 lrtmp12 -= tri_dia_r.elem(0).elem(i12+Nc);
1481 lrtmp12 -= tri_dia_r.elem(1).elem(i12);
1482 lrtmp12 += tri_dia_r.elem(1).elem(i12+Nc);
1483 B_j.elem().elem(i12,i12) = cmplx(lr_zero12,lrtmp12);
1488 for(
int i12 = 1; i12 < Nc; ++i12) {
1490 int elem_ijb12 = (i12+Nc)*(i12+Nc-1)/2 + Nc;
1492 for(
int j12 = 0; j12 < i12; ++j12) {
1494 lctmp12 = tri_off_r.elem(0).elem(elem_ij12);
1495 lctmp12 -= tri_off_r.elem(0).elem(elem_ijb12);
1496 lctmp12 -= tri_off_r.elem(1).elem(elem_ij12);
1497 lctmp12 += tri_off_r.elem(1).elem(elem_ijb12);
1499 B_j.elem().elem(i12,j12) = timesI(lctmp12);
1500 B_j.elem().elem(j12,i12) = timesI(adj(lctmp12));
1506 llvm_branch(case_default);
1509 llvm_set_insert_point( case_default );
1513 func.func().push_back( jit_function_epilogue_get(
"jit_triacntr.ptx") );
1519 template<
typename T,
typename U>
1526 if ( mat < 0 || mat > 15 )
1528 QDPIO::cerr << __func__ <<
": Gamma out of range: mat = " << mat << std::endl;
1534 static JitFunction
function;
1536 if (!
function.built()) {
1537 QDPIO::cout <<
"Building JIT triacntr\n";
1538 function_triacntr_build<U>(
function, B, tri_dia, tri_off, mat, rb[
cb] );
1548 template<
typename T,
typename U>
1554 if( param.anisoParam.anisoP ) {
1555 if (
mu==param.anisoParam.t_dir ||
nu == param.anisoParam.t_dir) {
1556 return param.clovCoeffT;
1560 return param.clovCoeffR;
1566 return param.clovCoeffR;
1574 template<
typename T,
typename X,
typename Y>
1582 if (!
s.hasOrderedRep())
1583 QDP_error_exit(
"clover on subset with unordered representation not implemented");
1587 AddressLeaf addr_leaf(
s);
1589 forEach(
chi, addr_leaf, NullCombine());
1590 forEach(
psi, addr_leaf, NullCombine());
1591 forEach(tri_dia, addr_leaf, NullCombine());
1592 forEach(tri_off, addr_leaf, NullCombine());
1594 jit_dispatch(
function.
func().at(0),
s.numSiteTable(),getDataLayoutInnerSize(),
s.hasOrderedRep(),
s.start(),addr_leaf);
1600 template<
typename T,
typename X,
typename Y>
1613 ParamLeaf param_leaf;
1615 typedef typename LeafFunctor<T, ParamLeaf>::Type_t TJIT;
1616 TJIT chi_jit(forEach(
chi, param_leaf, TreeCombine()));
1617 TJIT psi_jit(forEach(
psi, param_leaf, TreeCombine()));
1618 typename REGType< typename TJIT::Subtype_t >::Type_t psi_r;
1619 typename REGType< typename TJIT::Subtype_t >::Type_t chi_r;
1621 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
1622 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
1623 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
1625 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
1626 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
1627 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
1629 IndexDomainVector idx = loop.getIdx();
1631 typename TJIT::Subtype_t& chi_j = chi_jit.elem(JitDeviceLayout::LayoutCoalesced,idx);
1632 psi_r.setup( psi_jit.elem(JitDeviceLayout::LayoutCoalesced,idx) );
1633 tri_dia_r.setup( tri_dia_jit.elem(JitDeviceLayout::LayoutCoalesced,idx) );
1634 tri_off_r.setup( tri_off_jit.elem(JitDeviceLayout::LayoutCoalesced,idx) );
1641 for(
int i = 0;
i <
n; ++
i)
1643 chi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3) = tri_dia_r.elem(0).elem(
i) * psi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3);
1646 chi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3) = tri_dia_r.elem(1).elem(
i) * psi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3);
1651 for(
int i = 0;
i <
n; ++
i)
1653 for(
int j = 0;
j <
i;
j++)
1655 chi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3) += tri_off_r.elem(0).elem(kij) * psi_r.elem((0*
n+
j)/3).elem((0*
n+
j)%3);
1658 chi_r.elem((0*
n+
j)/3).elem((0*
n+
j)%3) += conj(tri_off_r.elem(0).elem(kij)) * psi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3);
1661 chi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3) += tri_off_r.elem(1).elem(kij) * psi_r.elem((1*
n+
j)/3).elem((1*
n+
j)%3);
1664 chi_r.elem((1*
n+
j)/3).elem((1*
n+
j)%3) += conj(tri_off_r.elem(1).elem(kij)) * psi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3);
1675 func.func().push_back( jit_function_epilogue_get(
"jit_apply_clov.ptx") );
1701 template<
typename T,
typename U>
1708 QDPIO::cerr << __func__ <<
": CloverTerm::apply requires Ns==4" << std::endl;
1714 static JitFunction
function;
1716 if (!
function.built()) {
1717 QDPIO::cout <<
"Building JIT clover apply function\n";
1724 (*this).getFermBC().modifyF(
chi, QDP::rb[
cb]);
1733 namespace QDPCloverEnv {
1734 template<
typename R,
typename TD,
typename TO>
1742 template<
typename R,
typename TD,
typename TO>
1747 multi1d<QUDAPackedClovSite<R> >& quda_array =
a->quda_array;
1749 const TD& tri_dia =
a->tri_dia;
1750 const TO& tri_off =
a->tri_off;
1752 const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14};
1754 for(
int ssite=lo; ssite < hi; ++ssite) {
1755 int site = rb[
cb].siteTable()[ssite];
1757 for(
int i=0;
i < 6;
i++) {
1758 quda_array[site].diag1[
i] = tri_dia.elem(site).comp[0].diag[
i].elem().elem();
1763 for(
int col=0; col < Nc*Ns2-1; col++) {
1764 for(
int row=col+1; row < Nc*Ns2; row++) {
1766 int source_index = row*(row-1)/2 + col;
1768 quda_array[site].offDiag1[target_index][0] = tri_off.elem(site).comp[0].offd[source_index].real().elem();
1769 quda_array[site].offDiag1[target_index][1] = tri_off.elem(site).comp[0].offd[source_index].imag().elem();
1774 for(
int i=0;
i < 6;
i++) {
1775 quda_array[site].diag2[
i] = tri_dia.elem(site).comp[1].diag[
i].elem().elem();
1779 for(
int col=0; col < Nc*Ns2-1; col++) {
1780 for(
int row=col+1; row < Nc*Ns2; row++) {
1782 int source_index = row*(row-1)/2 + col;
1784 quda_array[site].offDiag2[target_index][0] = tri_off.elem(site).comp[1].offd[source_index].real().elem();
1785 quda_array[site].offDiag2[target_index][1] = tri_off.elem(site).comp[1].offd[source_index].imag().elem();
1790 QDPIO::cout <<
"\n";
1794 template<
typename T,
typename U>
1797 typedef typename WordType<T>::Type_t
REALT;
1798 int num_sites = rb[
cb].siteTable().size();
1800 typedef OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >
TD;
1801 typedef OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > > TO;
1807 dispatch_to_threads(num_sites, args, QDPCloverEnv::qudaPackSiteLoop<REALT,TD,TO>);
1815 template<
typename T,
typename U>
1819 QDP_error_exit(
"LLVMCloverTermT<T,U>::applySite(T& chi, const T& psi,..) not implemented ");
Base class for all fermion action boundary conditions.
Support class for fermion actions and linear operators.
Class for counted reference semantics.
OScalar< PScalar< PScalar< RScalar< Word< REALT > > > > > RealT
~LLVMCloverTermT()
No real need for cleanup here.
OLattice< PComp< PTriDia< RScalar< Word< REALT > > > > > DiagType
const FermBC< T, multi1d< U >, multi1d< U > > & getFermBC() const
Return the fermion BC object for this linear operator.
void ldagdlinv(LatticeREAL &tr_log_diag, int cb)
Invert the clover term on cb.
void makeClov(const multi1d< U > &f, const RealT &diag_mass)
Create the clover term on all sites.
void apply(T &chi, const T &psi, enum PlusMinus isign, int cb) const
OLattice< PComp< PTriDia< RScalar< Word< REALT > > > > > tri_dia
OLattice< PScalar< PScalar< RScalar< Word< REALT > > > > > LatticeREAL
multi1d< bool > choles_done
CloverFermActParams param
OLattice< PComp< PTriOff< RComplex< Word< REALT > > > > > OffDiagType
void create(Handle< FermState< T, multi1d< U >, multi1d< U > > > fs, const CloverFermActParams ¶m_)
Creation routine.
Handle< FermBC< T, multi1d< U >, multi1d< U > > > fbc
const OffDiagType & getOffDiagBuffer() const
void applySite(T &chi, const T &psi, enum PlusMinus isign, int site) const
OLattice< PComp< PTriOff< RComplex< Word< REALT > > > > > tri_off
void choles(int cb)
Computes the inverse of the term on cb using Cholesky.
const DiagType & getDiagBuffer() const
void triacntr(U &B, int mat, int cb) const
Calculates Tr_D ( Gamma_mat L )
Double cholesDet(int cb) const
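Returns the log of the determinant of the term on cb: the sum over rb[cb] of log|d_i| from the Cholesky diagonal. Requires choles(cb) to have been done first.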
Real getCloverCoeff(int mu, int nu) const
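Returns the clover coefficient for the (mu, nu) plane: clovCoeffT if anisotropic and mu or nu is the time direction, otherwise clovCoeffR.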
LLVMCloverTermT()
Empty constructor. Must use create later.
WordType< T >::Type_t REALT
const multi1d< U > & getU() const
Get the u field.
void packForQUDA(multi1d< QUDAPackedClovSite< REALT > > &quda_pack, int cb) const
Packs the clover term on checkerboard cb into the QUDA packed clover-site format.
const double & get() const
static PackForQUDATimer & Instance()
Parameters for Clover fermion action.
Clover term linear operator.
void block(LatticeColorMatrix &u_block, const multi1d< LatticeColorMatrix > &u, int mu, int bl_level, const Real &BlkAccu, int BlkMax, int j_decay)
Construct block links.
void function_triacntr_exec(const JitFunction &function, U &B, const X &tri_dia, const Y &tri_off, int mat, const Subset &s)
TRIACNTR.
Calculates the antihermitian field strength tensor iF(mu,nu)
void qudaPackSiteLoop(int lo, int hi, int myId, QUDAPackArgs< R, TD, TO > *a)
Asqtad Staggered-Dirac operator.
LLVMCloverTermT< LatticeFermionF, LatticeColorMatrixF > LLVMCloverTermF
QDP_error_exit("too many BiCG iterations", n_count, rsd_sq, cp, c, re_rvr, im_rvr, re_a, im_a, re_b, im_b)
void function_make_clov_exec(const JitFunction &function, const RealT &diag_mass, const U &f0, const U &f1, const U &f2, const U &f3, const U &f4, const U &f5, X &tri_dia, Y &tri_off)
static multi1d< LatticeColorMatrix > u
void function_ldagdlinv_exec(const JitFunction &function, T &tr_log_diag, X &tri_dia, Y &tri_off, const Subset &s)
void function_apply_clov_exec(const JitFunction &function, T &chi, const T &psi, const X &tri_dia, const Y &tri_off, const Subset &s)
LinOpSysSolverMGProtoClover::T T
void function_triacntr_build(JitFunction &func, const U &B, const X &tri_dia, const Y &tri_off, int mat, const Subset &s)
LLVMCloverTermT< LatticeFermionD, LatticeColorMatrixD > LLVMCloverTermD
void function_ldagdlinv_build(JitFunction &func, const T &tr_log_diag, const X &tri_dia, const Y &tri_off, const Subset &s)
multi1d< LatticeFermion > chi(Ncb)
void mesField(multi1d< LatticeColorMatrixF > &f, const multi1d< LatticeColorMatrixF > &u)
Calculates the antihermitian field strength tensor iF(mu,nu)
void function_make_clov_build(JitFunction &func, const RealT &diag_mass, const U &f0, const U &f1, const U &f2, const U &f3, const U &f4, const U &f5, const X &tri_dia, const Y &tri_off)
void function_apply_clov_build(JitFunction &func, const T &chi, const T &psi, const X &tri_dia, const Y &tri_off, const Subset &s)
LLVMCloverTermT< LatticeFermion, LatticeColorMatrix > LLVMCloverTerm
multi1d< LatticeFermion > s(Ncb)
const T1 const T2 const T3 & f3
const T1 const T2 const T3 const T4 const T5 & f5
const T1 const T2 const T3 const T4 & f4
FloatingPoint< double > Double
Support class for fermion actions and linear operators.
Params for clover ferm acts.
multi1d< QUDAPackedClovSite< R > > & quda_array
PCompJIT< typename JITType< T >::Type_t > Type_t
PCompJIT< typename JITType< T >::Type_t > Type_t
PTriDiaJIT< typename JITType< T >::Type_t > Type_t
PTriDiaJIT< typename JITType< T >::Type_t > Type_t
PTriOffJIT< typename JITType< T >::Type_t > Type_t
PTriOffJIT< typename JITType< T >::Type_t > Type_t
PCompJIT & operator=(const PCompREG< T1 > &rhs)
const T & elem(int i) const
void setup(const PCompJIT< typename JITType< T >::Type_t > &rhs)
const T & elem(int i) const
const T & elem(int i) const
PTriDiaJIT & operator=(const PTriDiaREG< T1 > &rhs)
const T & elem(int i) const
void setup(const PTriDiaJIT< typename JITType< T >::Type_t > &rhs)
PTriOffJIT & operator=(const PTriOffREG< T1 > &rhs)
const T & elem(int i) const
const T & elem(int i) const
void setup(const PTriOffJIT< typename JITType< T >::Type_t > &rhs)
PCompREG< typename REGType< T >::Type_t > Type_t
PTriDiaREG< typename REGType< T >::Type_t > Type_t
PTriOffREG< typename REGType< T >::Type_t > Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
multi1d< LatticeColorMatrix > U