6 #ifndef __clover_term_nvvm_w_h__
7 #define __clover_term_nvvm_w_h__
9 #warning "Using QDP-JIT/NVVM clover"
// Mutable element access: returns a reference to component i by delegating
// to this->arrayF(i). No bounds check is performed in this accessor.
56 inline T&
elem(
int i) {
return this->arrayF(
i); }
// Read-only element access: const overload returning a const reference to
// component i via this->arrayF(i). No bounds check is performed here.
57 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
65 F[0].setup( rhs.elem(0) );
66 F[1].setup( rhs.elem(1) );
// Read-only element access: returns a const reference to entry i of the
// member F (indexed directly, without a bounds check).
69 inline const T&
elem(
int i)
const {
return F[
i]; }
93 typedef typename WordType<T>::Type_t
Type_t;
99 typedef typename WordType<T>::Type_t
Type_t;
123 for (
int i = 0 ;
i < 2 * Nc ;
i++ )
// Mutable element access: returns a reference to component i by delegating
// to this->arrayF(i). No bounds check is performed in this accessor.
128 inline T&
elem(
int i) {
return this->arrayF(
i); }
// Read-only element access: const overload returning a const reference to
// component i via this->arrayF(i). No bounds check is performed here.
129 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
137 for (
int i=0;
i<2*Nc;++
i)
165 typedef typename WordType<T>::Type_t
Type_t;
171 typedef typename WordType<T>::Type_t
Type_t;
190 struct PTriOffJIT:
public BaseJIT<T,2*Nc*Nc-Nc>
195 for (
int i = 0 ;
i < 2*Nc*Nc-Nc ;
i++ )
// Mutable element access: returns a reference to component i by delegating
// to this->arrayF(i). No bounds check is performed in this accessor.
// NOTE(review): presumably a member of PTriOffJIT, whose BaseJIT base is
// declared with 2*Nc*Nc-Nc entries -- confirm against the full class.
200 inline T&
elem(
int i) {
return this->arrayF(
i); }
// Read-only element access: const overload returning a const reference to
// component i via this->arrayF(i). No bounds check is performed here.
201 inline const T&
elem(
int i)
const {
return this->arrayF(
i); }
209 for (
int i=0;
i<2*Nc*Nc-Nc;++
i)
237 typedef typename WordType<T>::Type_t
Type_t;
243 typedef typename WordType<T>::Type_t
Type_t;
247 #if defined(QDP_USE_PROFILING)
249 struct LeafFunctor<
PComp<
T>, PrintTag>
252 static int apply(
const PComp<T> &
s,
const PrintTag &f)
259 struct LeafFunctor<
PTriDia<
T>, PrintTag>
262 static int apply(
const PTriDia<T> &
s,
const PrintTag &f)
269 struct LeafFunctor<
PTriOff<
T>, PrintTag>
272 static int apply(
const PTriOff<T> &
s,
const PrintTag &f)
289 struct QUDAPackedClovSite {
297 template<
typename T,
typename U>
302 typedef typename WordType<T>::Type_t
REALT;
304 typedef OLattice< PScalar< PScalar< RScalar< Word< REALT> > > > >
LatticeREAL;
305 typedef OScalar< PScalar< PScalar< RScalar< Word< REALT> > > > >
RealT;
// Returns a const reference to the stored field u (no copy is made).
382 const multi1d<U>&
getU()
const {
return u;}
397 OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >
tri_dia;
398 OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > >
tri_off;
404 template<
typename T,
typename U>
408 template<
typename T,
typename U>
420 fbc = fs->getFermBC();
424 if (fbc.operator->() == 0) {
425 QDPIO::cerr <<
"NVVMCloverTerm: error: fbc is null" << std::endl;
430 RealT ff = where(param.anisoParam.anisoP, Real(1) / param.anisoParam.xi_0, Real(1));
431 param.clovCoeffR *= Real(0.5) * ff;
432 param.clovCoeffT *= Real(0.5);
442 RealT ff = where(param.anisoParam.anisoP, param.anisoParam.nu / param.anisoParam.xi_0, Real(1));
443 diag_mass = 1 + (
Nd-1)*ff + param.Mass;
452 choles_done.resize(rb.numSubsets());
453 for(
int i=0;
i < rb.numSubsets();
i++) {
467 template<
typename T,
typename U>
478 fbc = fs->getFermBC();
482 if (fbc.operator->() == 0) {
483 QDPIO::cerr <<
"NVVMCloverTerm: error: fbc is null" << std::endl;
488 RealT ff = where(param.anisoParam.anisoP, Real(1) / param.anisoParam.xi_0, Real(1));
489 param.clovCoeffR *=
RealT(0.5) * ff;
490 param.clovCoeffT *=
RealT(0.5);
500 RealT ff = where(param.anisoParam.anisoP, param.anisoParam.nu / param.anisoParam.xi_0, Real(1));
501 diag_mass = 1 + (
Nd-1)*ff + param.Mass;
507 makeClov(f, diag_mass);
509 choles_done.resize(rb.numSubsets());
510 for(
int i=0;
i < rb.numSubsets();
i++) {
511 choles_done[
i] =
false;
600 template<
typename RealT,
typename U,
typename X,
typename Y>
602 const RealT& diag_mass,
612 AddressLeaf addr_leaf(all);
614 forEach(diag_mass, addr_leaf, NullCombine());
615 forEach(f0, addr_leaf, NullCombine());
616 forEach(
f1, addr_leaf, NullCombine());
617 forEach(
f2, addr_leaf, NullCombine());
618 forEach(
f3, addr_leaf, NullCombine());
619 forEach(
f4, addr_leaf, NullCombine());
620 forEach(
f5, addr_leaf, NullCombine());
621 forEach(tri_dia, addr_leaf, NullCombine());
622 forEach(tri_off, addr_leaf, NullCombine());
626 int hi = Layout::sitesOnNode();
628 #ifndef QDP_JIT_NVVM_USE_LEGACY_LAUNCH
629 JitParam jit_lo( QDP_get_global_cache().addJitParamInt( lo ) );
630 JitParam jit_hi( QDP_get_global_cache().addJitParamInt( hi ) );
632 std::vector<QDPCache::ArgKey> ids;
633 ids.push_back( jit_lo.get_id() );
634 ids.push_back( jit_hi.get_id() );
635 for(
unsigned i=0;
i < addr_leaf.ids.size(); ++
i)
636 ids.push_back( addr_leaf.ids[
i] );
637 jit_launch(
function,Layout::sitesOnNode(),ids);
639 std::vector<void*> addr;
640 addr.push_back( &lo );
641 addr.push_back( &hi );
642 for(
unsigned i=0;
i < addr_leaf.addr.size(); ++
i) {
643 addr.push_back( &addr_leaf.addr[
i] );
645 jit_launch(
function,Layout::sitesOnNode(),addr);
651 template<
typename RealT,
typename U,
typename X,
typename Y>
663 if (ptx_db::db_enabled) {
664 CUfunction
func = llvm_ptx_db( __PRETTY_FUNCTION__ );
669 typedef typename WordType<RealT>::Type_t REALT;
671 llvm_start_new_function();
673 ParamRef p_lo = llvm_add_param<int>();
674 ParamRef p_hi = llvm_add_param<int>();
676 ParamLeaf param_leaf;
678 typedef typename LeafFunctor<RealT, ParamLeaf>::Type_t RealTJIT;
679 RealTJIT diag_mass_jit(forEach(diag_mass, param_leaf, TreeCombine()));
681 typedef typename LeafFunctor<U, ParamLeaf>::Type_t UJIT;
682 UJIT f0_jit(forEach(f0, param_leaf, TreeCombine()));
683 UJIT f1_jit(forEach(
f1, param_leaf, TreeCombine()));
684 UJIT f2_jit(forEach(
f2, param_leaf, TreeCombine()));
685 UJIT f3_jit(forEach(
f3, param_leaf, TreeCombine()));
686 UJIT f4_jit(forEach(
f4, param_leaf, TreeCombine()));
687 UJIT f5_jit(forEach(
f5, param_leaf, TreeCombine()));
689 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
690 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
692 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
693 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
696 llvm_derefParam( p_lo );
697 llvm::Value * r_hi = llvm_derefParam( p_hi );
698 llvm::Value * r_idx = llvm_thread_idx();
700 llvm_cond_exit( llvm_ge( r_idx , r_hi ) );
702 auto& f0_j = f0_jit.elem(JitDeviceLayout::Coalesced , r_idx );
703 auto& f1_j = f1_jit.elem(JitDeviceLayout::Coalesced , r_idx );
704 auto& f2_j = f2_jit.elem(JitDeviceLayout::Coalesced , r_idx );
705 auto& f3_j = f3_jit.elem(JitDeviceLayout::Coalesced , r_idx );
706 auto& f4_j = f4_jit.elem(JitDeviceLayout::Coalesced , r_idx );
707 auto& f5_j = f5_jit.elem(JitDeviceLayout::Coalesced , r_idx );
709 auto& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::Coalesced , r_idx );
710 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced , r_idx );
712 typename REGType< typename RealTJIT::Subtype_t >::Type_t diag_mass_reg;
713 diag_mass_reg.setup( diag_mass_jit.elem() );
718 for(
int jj = 0; jj < 2; jj++) {
719 for(
int ii = 0; ii < 2*Nc; ii++) {
720 tri_dia_j.elem(jj).elem(ii) = diag_mass_reg.elem().elem();
726 RComplexREG<WordREG<REALT> > E_minus;
727 RComplexREG<WordREG<REALT> > B_minus;
728 RComplexREG<WordREG<REALT> > ctmp_0;
729 RComplexREG<WordREG<REALT> > ctmp_1;
730 RScalarREG<WordREG<REALT> > rtmp_0;
731 RScalarREG<WordREG<REALT> > rtmp_1;
734 for(
int i = 0;
i < Nc; ++
i) {
735 ctmp_0 = f5_j.elem().elem(
i,
i);
736 ctmp_0 -= f0_j.elem().elem(
i,
i);
737 rtmp_0 = imag(ctmp_0);
738 tri_dia_j.elem(0).elem(
i) += rtmp_0;
740 tri_dia_j.elem(0).elem(
i+Nc) -= rtmp_0;
742 ctmp_1 = f5_j.elem().elem(
i,
i);
743 ctmp_1 += f0_j.elem().elem(
i,
i);
744 rtmp_1 = imag(ctmp_1);
745 tri_dia_j.elem(1).elem(
i) -= rtmp_1;
747 tri_dia_j.elem(1).elem(
i+Nc) += rtmp_1;
750 for(
int i = 1;
i < Nc; ++
i) {
751 for(
int j = 0;
j <
i; ++
j) {
754 int elem_tmp = (
i+Nc)*(
i+Nc-1)/2 +
j+Nc;
756 ctmp_0 = f0_j.elem().elem(
i,
j);
757 ctmp_0 -= f5_j.elem().elem(
i,
j);
758 tri_off_j.elem(0).elem(
elem_ij) = timesI(ctmp_0);
760 zero_rep( tri_off_j.elem(0).elem(elem_tmp) );
761 tri_off_j.elem(0).elem(elem_tmp) -= tri_off_j.elem(0).elem(
elem_ij);
763 ctmp_1 = f5_j.elem().elem(
i,
j);
764 ctmp_1 += f0_j.elem().elem(
i,
j);
765 tri_off_j.elem(1).elem(
elem_ij) = timesI(ctmp_1);
767 zero_rep( tri_off_j.elem(1).elem(elem_tmp) );
768 tri_off_j.elem(1).elem(elem_tmp) -= tri_off_j.elem(1).elem(
elem_ij);
772 for(
int i = 0;
i < Nc; ++
i) {
773 for(
int j = 0;
j < Nc; ++
j) {
778 E_minus = f2_j.elem().elem(
i,
j);
779 E_minus = timesI( E_minus );
781 E_minus += f4_j.elem().elem(
i,
j);
784 B_minus = f3_j.elem().elem(
i,
j);
785 B_minus = timesI( B_minus );
787 B_minus -= f1_j.elem().elem(
i,
j);
789 tri_off_j.elem(0).elem(
elem_ij) = B_minus - E_minus;
791 tri_off_j.elem(1).elem(
elem_ij) = E_minus + B_minus;
797 return jit_function_epilogue_get_cuf(
"jit_make_clov.ptx" , __PRETTY_FUNCTION__ );
804 template<
typename T,
typename U>
810 QDPIO::cerr << __func__ <<
": expecting Nd==4" << std::endl;
815 QDPIO::cerr << __func__ <<
": expecting Ns==4" << std::endl;
819 U f0 = f[0] * getCloverCoeff(0,1);
820 U f1 = f[1] * getCloverCoeff(0,2);
821 U f2 = f[2] * getCloverCoeff(0,3);
822 U f3 = f[3] * getCloverCoeff(1,2);
823 U f4 = f[4] * getCloverCoeff(1,3);
824 U f5 = f[5] * getCloverCoeff(2,3);
828 static CUfunction
function;
830 if (
function == NULL)
844 template<
typename T,
typename U>
852 ldagdlinv(tr_log_diag_,
cb);
864 template<
typename T,
typename U>
869 if( choles_done[
cb] ==
false )
871 QDPIO::cout << __func__ <<
": Error: you have not done the Cholesky.on this operator on this subset" << std::endl;
872 QDPIO::cout <<
"You sure you should not be asking invclov?" << std::endl;
883 return sum(tr_log_diag_, rb[
cb]);
888 template<
typename T,
typename X,
typename Y>
895 if (!
s.hasOrderedRep())
896 QDP_error_exit(
"ldagdlinv on subset with unordered representation not implemented");
898 AddressLeaf addr_leaf(
s);
900 forEach(tr_log_diag, addr_leaf, NullCombine());
901 forEach(tri_dia, addr_leaf, NullCombine());
902 forEach(tri_off, addr_leaf, NullCombine());
908 #ifndef QDP_JIT_NVVM_USE_LEGACY_LAUNCH
909 JitParam jit_lo( QDP_get_global_cache().addJitParamInt( lo ) );
910 JitParam jit_hi( QDP_get_global_cache().addJitParamInt( hi ) );
911 std::vector<QDPCache::ArgKey> ids;
912 ids.push_back( jit_lo.get_id() );
913 ids.push_back( jit_hi.get_id() );
914 for(
unsigned i=0;
i < addr_leaf.ids.size(); ++
i)
915 ids.push_back( addr_leaf.ids[
i] );
916 jit_launch(
function,
s.numSiteTable(),ids);
918 std::vector<void*> addr;
919 addr.push_back( &lo );
920 addr.push_back( &hi );
921 for(
unsigned i=0;
i < addr_leaf.addr.size(); ++
i) {
922 addr.push_back( &addr_leaf.addr[
i] );
924 jit_launch(
function,
s.numSiteTable(),addr);
932 template<
typename U,
typename T,
typename X,
typename Y>
938 typedef typename WordType<U>::Type_t REALT;
940 if (ptx_db::db_enabled) {
941 CUfunction
func = llvm_ptx_db( __PRETTY_FUNCTION__ );
949 llvm_start_new_function();
951 ParamRef p_lo = llvm_add_param<int>();
952 ParamRef p_hi = llvm_add_param<int>();
954 ParamLeaf param_leaf;
956 typedef typename LeafFunctor<T, ParamLeaf>::Type_t TJIT;
957 TJIT tr_log_diag_jit(forEach(tr_log_diag, param_leaf, TreeCombine()));
959 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
960 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
962 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
963 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
965 llvm::Value * r_lo = llvm_derefParam( p_lo );
966 llvm::Value * r_hi = llvm_derefParam( p_hi );
967 llvm::Value * r_idx_thread = llvm_thread_idx();
969 llvm_cond_exit( llvm_gt( r_idx_thread , llvm_sub( r_hi , r_lo ) ) );
971 llvm::Value * r_idx = llvm_add( r_lo , r_idx_thread );
973 auto& tr_log_diag_j = tr_log_diag_jit.elem(JitDeviceLayout::Coalesced,r_idx);
974 auto& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::Coalesced,r_idx);
975 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced,r_idx);
977 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
978 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
980 tri_dia_r.setup( tri_dia_j );
981 tri_off_r.setup( tri_off_j );
984 RScalarREG<WordREG<REALT> > zip;
992 RScalarREG<WordREG<REALT> > inv_d[6] ;
993 RComplexREG<WordREG<REALT> > inv_offd[15] ;
994 RComplexREG<WordREG<REALT> > v[6] ;
995 RScalarREG<WordREG<REALT> > diag_g[6] ;
997 for(
int i=0;
i < N;
i++) {
998 inv_d[
i] = tri_dia_r.elem(
block).elem(
i);
1001 for(
int i=0;
i < 15;
i++) {
1002 inv_offd[
i] = tri_off_r.elem(
block).elem(
i);
1006 for(
int j=0;
j < N; ++
j) {
1008 for(
int i=0;
i <
j;
i++) {
1011 RComplexREG<WordREG<REALT> > A_ii = cmplx( inv_d[
i], zip );
1012 v[
i] = A_ii*adj(inv_offd[
elem_ji]);
1016 v[
j] = cmplx(inv_d[
j],zip);
1018 for(
int k=0;
k <
j;
k++) {
1019 int elem_jk =
j*(
j-1)/2 +
k;
1020 v[
j] -= inv_offd[elem_jk]*v[
k];
1023 inv_d[
j] = real( v[
j] );
1025 for(
int k=
j+1;
k < N;
k++) {
1026 int elem_kj =
k*(
k-1)/2 +
j;
1027 for(
int l=0;
l <
j;
l++) {
1028 int elem_kl =
k*(
k-1)/2 +
l;
1029 inv_offd[elem_kj] -= inv_offd[elem_kl] * v[
l];
1031 inv_offd[elem_kj] /= v[
j];
1037 RScalarREG<WordREG<REALT> >
one(1.0);
1040 for(
int i=0;
i < N;
i++) {
1041 diag_g[
i] =
one/inv_d[
i];
1047 tr_log_diag_j.elem().elem() += log(fabs(inv_d[
i]));
1051 if( inv_d[
i].elem() < 0 ) {
1071 RComplexREG<WordREG<REALT> >
sum;
1072 for(
int k = 0;
k < N; ++
k) {
1074 for(
int i = 0;
i <
k; ++
i) {
1081 v[
k] = cmplx(diag_g[
k],zip);
1083 for(
int i =
k+1;
i < N; ++
i) {
1086 for(
int j =
k;
j <
i; ++
j) {
1102 for(
int i = N-2; (int)
i >= (
int)
k; --
i) {
1103 for(
int j =
i+1;
j < N; ++
j) {
1111 inv_d[
k] = real(v[
k]);
1112 for(
int i =
k+1;
i < N; ++
i) {
1114 int elem_ik =
i*(
i-1)/2+
k;
1115 inv_offd[elem_ik] = v[
i];
1120 for(
int i=0;
i < N;
i++) {
1121 tri_dia_j.elem(
block).elem(
i) = inv_d[
i];
1123 for(
int i=0;
i < 15;
i++) {
1124 tri_off_j.elem(
block).elem(
i) = inv_offd[
i];
1130 return jit_function_epilogue_get_cuf(
"jit_ldagdlinv.ptx" , __PRETTY_FUNCTION__ );
1138 template<
typename T,
typename U>
1145 QDPIO::cerr << __func__ <<
": Matrix is too small" << std::endl;
1150 tr_log_diag[rb[
cb]] =
zero;
1154 static CUfunction
function;
1156 if (
function == NULL)
1157 function = function_ldagdlinv_build<U>(tr_log_diag, tri_dia, tri_off, rb[
cb] );
1163 choles_done[
cb] =
true;
1217 template<
typename U,
typename X,
typename Y>
1225 if (!
s.hasOrderedRep())
1226 QDP_error_exit(
"triacntr on subset with unordered representation not implemented");
1228 AddressLeaf addr_leaf(
s);
1230 forEach(B, addr_leaf, NullCombine());
1231 forEach(tri_dia, addr_leaf, NullCombine());
1232 forEach(tri_off, addr_leaf, NullCombine());
1238 #ifndef QDP_JIT_NVVM_USE_LEGACY_LAUNCH
1239 JitParam jit_lo( QDP_get_global_cache().addJitParamInt( lo ) );
1240 JitParam jit_hi( QDP_get_global_cache().addJitParamInt( hi ) );
1241 JitParam jit_mat( QDP_get_global_cache().addJitParamInt( mat ) );
1243 std::vector<QDPCache::ArgKey> ids;
1244 ids.push_back( jit_lo.get_id() );
1245 ids.push_back( jit_hi.get_id() );
1246 ids.push_back( jit_mat.get_id() );
1247 for(
unsigned i=0;
i < addr_leaf.ids.size(); ++
i)
1248 ids.push_back( addr_leaf.ids[
i] );
1249 jit_launch(
function,
s.numSiteTable(),ids);
1251 std::vector<void*> addr;
1252 addr.push_back( &lo );
1253 addr.push_back( &hi );
1254 addr.push_back( &mat );
1255 for(
unsigned i=0;
i < addr_leaf.addr.size(); ++
i) {
1256 addr.push_back( &addr_leaf.addr[
i] );
1258 jit_launch(
function,
s.numSiteTable(),addr);
1265 template<
typename U,
typename X,
typename Y>
1272 if (ptx_db::db_enabled) {
1273 CUfunction
func = llvm_ptx_db( __PRETTY_FUNCTION__ );
1281 typedef typename WordType<U>::Type_t REALT;
1283 llvm_start_new_function();
1285 ParamRef p_lo = llvm_add_param<int>();
1286 ParamRef p_hi = llvm_add_param<int>();
1288 ParamRef p_mat = llvm_add_param<int>();
1290 ParamLeaf param_leaf;
1293 typedef typename LeafFunctor<U, ParamLeaf>::Type_t UJIT;
1294 UJIT B_jit(forEach(B, param_leaf, TreeCombine()));
1296 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
1297 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
1299 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
1300 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
1302 llvm::Value * r_mat = llvm_derefParam( p_mat );
1303 llvm::Value * r_lo = llvm_derefParam( p_lo );
1304 llvm::Value * r_hi = llvm_derefParam( p_hi );
1305 llvm::Value * r_idx_thread = llvm_thread_idx();
1307 llvm_cond_exit( llvm_gt( r_idx_thread , llvm_sub( r_hi , r_lo ) ) );
1309 llvm::Value * r_idx = llvm_add( r_lo , r_idx_thread );
1311 auto& B_j = B_jit.elem(JitDeviceLayout::Coalesced,r_idx);
1312 auto& tri_dia_j = tri_dia_jit.elem(JitDeviceLayout::Coalesced,r_idx);
1313 auto& tri_off_j = tri_off_jit.elem(JitDeviceLayout::Coalesced,r_idx);
1315 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
1316 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
1318 tri_dia_r.setup( tri_dia_j );
1319 tri_off_r.setup( tri_off_j );
1322 llvm::BasicBlock * case_0 = llvm_new_basic_block();
1323 llvm::BasicBlock * case_3 = llvm_new_basic_block();
1324 llvm::BasicBlock * case_5 = llvm_new_basic_block();
1325 llvm::BasicBlock * case_6 = llvm_new_basic_block();
1326 llvm::BasicBlock * case_9 = llvm_new_basic_block();
1327 llvm::BasicBlock * case_10 = llvm_new_basic_block();
1328 llvm::BasicBlock * case_12 = llvm_new_basic_block();
1329 llvm::BasicBlock * case_default = llvm_new_basic_block();
1331 llvm::SwitchInst * mat_sw = llvm_switch( r_mat , case_default );
1333 mat_sw->addCase( llvm_create_const_int(0) , case_0 );
1334 mat_sw->addCase( llvm_create_const_int(3) , case_3 );
1335 mat_sw->addCase( llvm_create_const_int(5) , case_5 );
1336 mat_sw->addCase( llvm_create_const_int(6) , case_6 );
1337 mat_sw->addCase( llvm_create_const_int(9) , case_9 );
1338 mat_sw->addCase( llvm_create_const_int(10) , case_10 );
1339 mat_sw->addCase( llvm_create_const_int(12) , case_12 );
1346 llvm_set_insert_point( case_0 );
1348 RComplexREG<WordREG<REALT> > lctmp0;
1349 RScalarREG< WordREG<REALT> > lr_zero0;
1350 RScalarREG< WordREG<REALT> > lrtmp0;
1354 for(
int i0 = 0; i0 < Nc; ++i0) {
1356 lrtmp0 = tri_dia_r.elem(0).elem(i0);
1357 lrtmp0 += tri_dia_r.elem(0).elem(i0+Nc);
1358 lrtmp0 += tri_dia_r.elem(1).elem(i0);
1359 lrtmp0 += tri_dia_r.elem(1).elem(i0+Nc);
1360 B_j.elem().elem(i0,i0) = cmplx(lrtmp0,lr_zero0);
1365 for(
int i0 = 1; i0 < Nc; ++i0) {
1367 int elem_ijb0 = (i0+Nc)*(i0+Nc-1)/2 + Nc;
1369 for(
int j0 = 0; j0 < i0; ++j0) {
1371 lctmp0 = tri_off_r.elem(0).elem(elem_ij0);
1372 lctmp0 += tri_off_r.elem(0).elem(elem_ijb0);
1373 lctmp0 += tri_off_r.elem(1).elem(elem_ij0);
1374 lctmp0 += tri_off_r.elem(1).elem(elem_ijb0);
1376 B_j.elem().elem(j0,i0) = lctmp0;
1377 B_j.elem().elem(i0,j0) = adj(lctmp0);
1387 llvm_set_insert_point( case_3 );
1395 RComplexREG<WordREG<REALT> > lctmp3;
1396 RScalarREG<WordREG<REALT> > lr_zero3;
1397 RScalarREG<WordREG<REALT> > lrtmp3;
1401 for(
int i3 = 0; i3 < Nc; ++i3) {
1403 lrtmp3 = tri_dia_r.elem(0).elem(i3+Nc);
1404 lrtmp3 -= tri_dia_r.elem(0).elem(i3);
1405 lrtmp3 -= tri_dia_r.elem(1).elem(i3);
1406 lrtmp3 += tri_dia_r.elem(1).elem(i3+Nc);
1407 B_j.elem().elem(i3,i3) = cmplx(lr_zero3,lrtmp3);
1412 for(
int i3 = 1; i3 < Nc; ++i3) {
1414 int elem_ijb3 = (i3+Nc)*(i3+Nc-1)/2 + Nc;
1416 for(
int j3 = 0; j3 < i3; ++j3) {
1418 lctmp3 = tri_off_r.elem(0).elem(elem_ijb3);
1419 lctmp3 -= tri_off_r.elem(0).elem(elem_ij3);
1420 lctmp3 -= tri_off_r.elem(1).elem(elem_ij3);
1421 lctmp3 += tri_off_r.elem(1).elem(elem_ijb3);
1423 B_j.elem().elem(j3,i3) = timesI(adj(lctmp3));
1424 B_j.elem().elem(i3,j3) = timesI(lctmp3);
1434 llvm_set_insert_point( case_5 );
1441 RComplexREG<WordREG<REALT> > lctmp5;
1442 RScalarREG<WordREG<REALT> > lrtmp5;
1444 for(
int i5 = 0; i5 < Nc; ++i5) {
1446 int elem_ij5 = (i5+Nc)*(i5+Nc-1)/2;
1448 for(
int j5 = 0; j5 < Nc; ++j5) {
1450 int elem_ji5 = (j5+Nc)*(j5+Nc-1)/2 + i5;
1453 lctmp5 = adj(tri_off_r.elem(0).elem(elem_ji5));
1454 lctmp5 -= tri_off_r.elem(0).elem(elem_ij5);
1455 lctmp5 += adj(tri_off_r.elem(1).elem(elem_ji5));
1456 lctmp5 -= tri_off_r.elem(1).elem(elem_ij5);
1458 B_j.elem().elem(i5,j5) = lctmp5;
1467 llvm_set_insert_point( case_6 );
1474 RComplexREG<WordREG<REALT> > lctmp6;
1475 RScalarREG<WordREG<REALT> > lrtmp6;
1477 for(
int i6 = 0; i6 < Nc; ++i6) {
1479 int elem_ij6 = (i6+Nc)*(i6+Nc-1)/2;
1481 for(
int j6 = 0; j6 < Nc; ++j6) {
1483 int elem_ji6 = (j6+Nc)*(j6+Nc-1)/2 + i6;
1485 lctmp6 = adj(tri_off_r.elem(0).elem(elem_ji6));
1486 lctmp6 += tri_off_r.elem(0).elem(elem_ij6);
1487 lctmp6 += adj(tri_off_r.elem(1).elem(elem_ji6));
1488 lctmp6 += tri_off_r.elem(1).elem(elem_ij6);
1490 B_j.elem().elem(i6,j6) = timesMinusI(lctmp6);
1499 llvm_set_insert_point( case_9 );
1506 RComplexREG<WordREG<REALT> > lctmp9;
1507 RScalarREG<WordREG<REALT> > lrtmp9;
1509 for(
int i9 = 0; i9 < Nc; ++i9) {
1511 int elem_ij9 = (i9+Nc)*(i9+Nc-1)/2;
1513 for(
int j9 = 0; j9 < Nc; ++j9) {
1515 int elem_ji9 = (j9+Nc)*(j9+Nc-1)/2 + i9;
1517 lctmp9 = adj(tri_off_r.elem(0).elem(elem_ji9));
1518 lctmp9 += tri_off_r.elem(0).elem(elem_ij9);
1519 lctmp9 -= adj(tri_off_r.elem(1).elem(elem_ji9));
1520 lctmp9 -= tri_off_r.elem(1).elem(elem_ij9);
1522 B_j.elem().elem(i9,j9) = timesI(lctmp9);
1531 llvm_set_insert_point( case_10 );
1538 RComplexREG<WordREG<REALT> > lctmp10;
1539 RScalarREG<WordREG<REALT> > lrtmp10;
1541 for(
int i10 = 0; i10 < Nc; ++i10) {
1543 int elem_ij10 = (i10+Nc)*(i10+Nc-1)/2;
1545 for(
int j10 = 0; j10 < Nc; ++j10) {
1547 int elem_ji10 = (j10+Nc)*(j10+Nc-1)/2 + i10;
1549 lctmp10 = adj(tri_off_r.elem(0).elem(elem_ji10));
1550 lctmp10 -= tri_off_r.elem(0).elem(elem_ij10);
1551 lctmp10 -= adj(tri_off_r.elem(1).elem(elem_ji10));
1552 lctmp10 += tri_off_r.elem(1).elem(elem_ij10);
1554 B_j.elem().elem(i10,j10) = lctmp10;
1563 llvm_set_insert_point( case_12 );
1571 RComplexREG<WordREG<REALT> > lctmp12;
1572 RScalarREG<WordREG<REALT> > lr_zero12;
1573 RScalarREG<WordREG<REALT> > lrtmp12;
1577 for(
int i12 = 0; i12 < Nc; ++i12) {
1579 lrtmp12 = tri_dia_r.elem(0).elem(i12);
1580 lrtmp12 -= tri_dia_r.elem(0).elem(i12+Nc);
1581 lrtmp12 -= tri_dia_r.elem(1).elem(i12);
1582 lrtmp12 += tri_dia_r.elem(1).elem(i12+Nc);
1583 B_j.elem().elem(i12,i12) = cmplx(lr_zero12,lrtmp12);
1588 for(
int i12 = 1; i12 < Nc; ++i12) {
1590 int elem_ijb12 = (i12+Nc)*(i12+Nc-1)/2 + Nc;
1592 for(
int j12 = 0; j12 < i12; ++j12) {
1594 lctmp12 = tri_off_r.elem(0).elem(elem_ij12);
1595 lctmp12 -= tri_off_r.elem(0).elem(elem_ijb12);
1596 lctmp12 -= tri_off_r.elem(1).elem(elem_ij12);
1597 lctmp12 += tri_off_r.elem(1).elem(elem_ijb12);
1599 B_j.elem().elem(i12,j12) = timesI(lctmp12);
1600 B_j.elem().elem(j12,i12) = timesI(adj(lctmp12));
1609 llvm_set_insert_point( case_default );
1611 return jit_function_epilogue_get_cuf(
"jit_triacntr.ptx" , __PRETTY_FUNCTION__ );
1618 template<
typename T,
typename U>
1625 if ( mat < 0 || mat > 15 )
1627 QDPIO::cerr << __func__ <<
": Gamma out of range: mat = " << mat << std::endl;
1633 static CUfunction
function;
1635 if (
function == NULL)
1636 function = function_triacntr_build<U>( B, tri_dia, tri_off, mat, rb[
cb] );
1645 template<
typename T,
typename U>
1651 if( param.anisoParam.anisoP ) {
1652 if (
mu==param.anisoParam.t_dir ||
nu == param.anisoParam.t_dir) {
1653 return param.clovCoeffT;
1657 return param.clovCoeffR;
1663 return param.clovCoeffR;
1671 template<
typename T,
typename X,
typename Y>
1679 if (!
s.hasOrderedRep())
1680 QDP_error_exit(
"clover on subset with unordered representation not implemented");
1684 AddressLeaf addr_leaf(
s);
1686 forEach(
chi, addr_leaf, NullCombine());
1687 forEach(
psi, addr_leaf, NullCombine());
1688 forEach(tri_dia, addr_leaf, NullCombine());
1689 forEach(tri_off, addr_leaf, NullCombine());
1695 #ifndef QDP_JIT_NVVM_USE_LEGACY_LAUNCH
1696 JitParam jit_lo( QDP_get_global_cache().addJitParamInt( lo ) );
1697 JitParam jit_hi( QDP_get_global_cache().addJitParamInt( hi ) );
1698 std::vector<QDPCache::ArgKey> ids;
1699 ids.push_back( jit_lo.get_id() );
1700 ids.push_back( jit_hi.get_id() );
1701 for(
unsigned i=0;
i < addr_leaf.ids.size(); ++
i)
1702 ids.push_back( addr_leaf.ids[
i] );
1703 jit_launch(
function,
s.numSiteTable(),ids);
1705 std::vector<void*> addr;
1706 addr.push_back( &lo );
1707 addr.push_back( &hi );
1708 for(
unsigned i=0;
i < addr_leaf.addr.size(); ++
i) {
1709 addr.push_back( &addr_leaf.addr[
i] );
1711 jit_launch(
function,
s.numSiteTable(),addr);
1718 template<
typename T,
typename X,
typename Y>
1725 if (ptx_db::db_enabled) {
1726 CUfunction
func = llvm_ptx_db( __PRETTY_FUNCTION__ );
1736 llvm_start_new_function();
1740 ParamRef p_lo = llvm_add_param<int>();
1741 ParamRef p_hi = llvm_add_param<int>();
1746 ParamLeaf param_leaf;
1748 typedef typename LeafFunctor<T, ParamLeaf>::Type_t TJIT;
1749 TJIT chi_jit(forEach(
chi, param_leaf, TreeCombine()));
1750 TJIT psi_jit(forEach(
psi, param_leaf, TreeCombine()));
1751 typename REGType< typename TJIT::Subtype_t >::Type_t psi_r;
1752 typename REGType< typename TJIT::Subtype_t >::Type_t chi_r;
1754 typedef typename LeafFunctor<X, ParamLeaf>::Type_t XJIT;
1755 XJIT tri_dia_jit(forEach(tri_dia, param_leaf, TreeCombine()));
1756 typename REGType< typename XJIT::Subtype_t >::Type_t tri_dia_r;
1758 typedef typename LeafFunctor<Y, ParamLeaf>::Type_t YJIT;
1759 YJIT tri_off_jit(forEach(tri_off, param_leaf, TreeCombine()));
1760 typename REGType< typename YJIT::Subtype_t >::Type_t tri_off_r;
1764 llvm::Value * r_lo = llvm_derefParam( p_lo );
1765 llvm::Value * r_hi = llvm_derefParam( p_hi );
1766 llvm::Value * r_idx_thread = llvm_thread_idx();
1768 llvm_cond_exit( llvm_gt( r_idx_thread , llvm_sub( r_hi , r_lo ) ) );
1770 llvm::Value * r_idx = llvm_add( r_lo , r_idx_thread );
1772 auto& chi_j = chi_jit.elem(JitDeviceLayout::Coalesced,r_idx);
1773 psi_r.setup( psi_jit.elem(JitDeviceLayout::Coalesced,r_idx) );
1774 tri_dia_r.setup( tri_dia_jit.elem(JitDeviceLayout::Coalesced,r_idx) );
1775 tri_off_r.setup( tri_off_jit.elem(JitDeviceLayout::Coalesced,r_idx) );
1782 for(
int i = 0;
i <
n; ++
i)
1784 chi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3) = tri_dia_r.elem(0).elem(
i) * psi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3);
1787 chi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3) = tri_dia_r.elem(1).elem(
i) * psi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3);
1792 for(
int i = 0;
i <
n; ++
i)
1794 for(
int j = 0;
j <
i;
j++)
1796 chi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3) += tri_off_r.elem(0).elem(kij) * psi_r.elem((0*
n+
j)/3).elem((0*
n+
j)%3);
1799 chi_r.elem((0*
n+
j)/3).elem((0*
n+
j)%3) += conj(tri_off_r.elem(0).elem(kij)) * psi_r.elem((0*
n+
i)/3).elem((0*
n+
i)%3);
1802 chi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3) += tri_off_r.elem(1).elem(kij) * psi_r.elem((1*
n+
j)/3).elem((1*
n+
j)%3);
1805 chi_r.elem((1*
n+
j)/3).elem((1*
n+
j)%3) += conj(tri_off_r.elem(1).elem(kij)) * psi_r.elem((1*
n+
i)/3).elem((1*
n+
i)%3);
1814 return jit_function_epilogue_get_cuf(
"jit_apply_clov.ptx" , __PRETTY_FUNCTION__ );
1840 template<
typename T,
typename U>
1847 QDPIO::cerr << __func__ <<
": CloverTerm::apply requires Ns==4" << std::endl;
1853 static CUfunction
function;
1855 if (
function == NULL)
1861 (*this).getFermBC().modifyF(
chi, QDP::rb[
cb]);
1870 namespace QDPCloverEnv {
1871 template<
typename R,
typename TD,
typename TO>
1872 struct QUDAPackArgs {
1879 template<
typename R,
typename TD,
typename TO>
1884 multi1d<QUDAPackedClovSite<R> >& quda_array =
a->quda_array;
1886 const TD& tri_dia =
a->tri_dia;
1887 const TO& tri_off =
a->tri_off;
1889 const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14};
1891 for(
int ssite=lo; ssite < hi; ++ssite) {
1892 int site = rb[
cb].siteTable()[ssite];
1894 for(
int i=0;
i < 6;
i++) {
1895 quda_array[site].diag1[
i] = tri_dia.elem(site).comp[0].diag[
i].elem().elem();
1900 for(
int col=0; col < Nc*Ns2-1; col++) {
1901 for(
int row=col+1; row < Nc*Ns2; row++) {
1903 int source_index = row*(row-1)/2 + col;
1905 quda_array[site].offDiag1[target_index][0] = tri_off.elem(site).comp[0].offd[source_index].real().elem();
1906 quda_array[site].offDiag1[target_index][1] = tri_off.elem(site).comp[0].offd[source_index].imag().elem();
1911 for(
int i=0;
i < 6;
i++) {
1912 quda_array[site].diag2[
i] = tri_dia.elem(site).comp[1].diag[
i].elem().elem();
1916 for(
int col=0; col < Nc*Ns2-1; col++) {
1917 for(
int row=col+1; row < Nc*Ns2; row++) {
1919 int source_index = row*(row-1)/2 + col;
1921 quda_array[site].offDiag2[target_index][0] = tri_off.elem(site).comp[1].offd[source_index].real().elem();
1922 quda_array[site].offDiag2[target_index][1] = tri_off.elem(site).comp[1].offd[source_index].imag().elem();
1927 QDPIO::cout <<
"\n";
1931 template<
typename T,
typename U>
1934 typedef typename WordType<T>::Type_t
REALT;
1935 int num_sites = rb[
cb].siteTable().size();
1937 typedef OLattice<PComp<PTriDia<RScalar <Word<REALT> > > > >
TD;
1938 typedef OLattice<PComp<PTriOff<RComplex<Word<REALT> > > > > TO;
1944 dispatch_to_threads(num_sites, args, QDPCloverEnv::qudaPackSiteLoop<REALT,TD,TO>);
1952 template<
typename T,
typename U>
1956 QDP_error_exit(
"NVVMCloverTermT<T,U>::applySite(T& chi, const T& psi,..) not implemented ");
1968 #undef QDP_JIT_NVVM_USE_LEGACY_LAUNCH
Base class for all fermion action boundary conditions.
Support class for fermion actions and linear operators.
Class for counted reference semantics.
void create(Handle< FermState< T, multi1d< U >, multi1d< U > > > fs, const CloverFermActParams &param_)
Creation routine.
Double cholesDet(int cb) const
Returns the sum of tr_log_diag_ over subset cb; reports an error if the Cholesky step (choles) has not yet been done on cb.
OLattice< PComp< PTriDia< RScalar< Word< REALT > > > > > tri_dia
~NVVMCloverTermT()
No real need for cleanup here.
Real getCloverCoeff(int mu, int nu) const
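Returns the clover coefficient for the (mu,nu) plane: clovCoeffT when anisotropy is enabled and mu or nu is the time direction, otherwise clovCoeffR.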
const multi1d< U > & getU() const
Get the u field.
Handle< FermBC< T, multi1d< U >, multi1d< U > > > fbc
void triacntr(U &B, int mat, int cb) const
Calculates Tr_D ( Gamma_mat L )
WordType< T >::Type_t REALT
void choles(int cb)
Computes the inverse of the term on cb using Cholesky.
OLattice< PComp< PTriOff< RComplex< Word< REALT > > > > > tri_off
OLattice< PScalar< PScalar< RScalar< Word< REALT > > > > > LatticeREAL
void applySite(T &chi, const T &psi, enum PlusMinus isign, int site) const
void apply(T &chi, const T &psi, enum PlusMinus isign, int cb) const
void packForQUDA(multi1d< QUDAPackedClovSite< REALT > > &quda_pack, int cb) const
PACK UP the Clover term for QUDA library:
NVVMCloverTermT()
Empty constructor. Must use create later.
void makeClov(const multi1d< U > &f, const RealT &diag_mass)
Create the clover term on cb.
const FermBC< T, multi1d< U >, multi1d< U > > & getFermBC() const
Return the fermion BC object for this linear operator.
CloverFermActParams param
multi1d< bool > choles_done
OScalar< PScalar< PScalar< RScalar< Word< REALT > > > > > RealT
void ldagdlinv(LatticeREAL &tr_log_diag, int cb)
Invert the clover term on cb.
const double & get() const
static PackForQUDATimer & Instance()
Parameters for Clover fermion action.
Clover term linear operator.
void block(LatticeColorMatrix &u_block, const multi1d< LatticeColorMatrix > &u, int mu, int bl_level, const Real &BlkAccu, int BlkMax, int j_decay)
Construct block links.
void function_triacntr_exec(const JitFunction &function, U &B, const X &tri_dia, const Y &tri_off, int mat, const Subset &s)
TRIACNTR.
Calculates the antihermitian field strength tensor iF(mu,nu)
void qudaPackSiteLoop(int lo, int hi, int myId, QUDAPackArgs< R, TD, TO > *a)
Asqtad Staggered-Dirac operator.
QDP_error_exit("too many BiCG iterations", n_count, rsd_sq, cp, c, re_rvr, im_rvr, re_a, im_a, re_b, im_b)
void function_make_clov_exec(const JitFunction &function, const RealT &diag_mass, const U &f0, const U &f1, const U &f2, const U &f3, const U &f4, const U &f5, X &tri_dia, Y &tri_off)
static multi1d< LatticeColorMatrix > u
void function_ldagdlinv_exec(const JitFunction &function, T &tr_log_diag, X &tri_dia, Y &tri_off, const Subset &s)
void function_apply_clov_exec(const JitFunction &function, T &chi, const T &psi, const X &tri_dia, const Y &tri_off, const Subset &s)
NVVMCloverTermT< LatticeFermionF, LatticeColorMatrixF > NVVMCloverTermF
LinOpSysSolverMGProtoClover::T T
void function_triacntr_build(JitFunction &func, const U &B, const X &tri_dia, const Y &tri_off, int mat, const Subset &s)
NVVMCloverTermT< LatticeFermion, LatticeColorMatrix > NVVMCloverTerm
void function_ldagdlinv_build(JitFunction &func, const T &tr_log_diag, const X &tri_dia, const Y &tri_off, const Subset &s)
multi1d< LatticeFermion > chi(Ncb)
void mesField(multi1d< LatticeColorMatrixF > &f, const multi1d< LatticeColorMatrixF > &u)
Calculates the antihermitian field strength tensor iF(mu,nu)
NVVMCloverTermT< LatticeFermionD, LatticeColorMatrixD > NVVMCloverTermD
void function_make_clov_build(JitFunction &func, const RealT &diag_mass, const U &f0, const U &f1, const U &f2, const U &f3, const U &f4, const U &f5, const X &tri_dia, const Y &tri_off)
void function_apply_clov_build(JitFunction &func, const T &chi, const T &psi, const X &tri_dia, const Y &tri_off, const Subset &s)
multi1d< LatticeFermion > s(Ncb)
const T1 const T2 const T3 & f3
const T1 const T2 const T3 const T4 const T5 & f5
const T1 const T2 const T3 const T4 & f4
FloatingPoint< double > Double
Support class for fermion actions and linear operators.
Params for clover ferm acts.
multi1d< QUDAPackedClovSite< R > > & quda_array
PCompJIT< typename JITType< T >::Type_t > Type_t
PCompJIT< typename JITType< T >::Type_t > Type_t
PTriDiaJIT< typename JITType< T >::Type_t > Type_t
PTriDiaJIT< typename JITType< T >::Type_t > Type_t
PTriOffJIT< typename JITType< T >::Type_t > Type_t
PTriOffJIT< typename JITType< T >::Type_t > Type_t
PCompJIT & operator=(const PCompREG< T1 > &rhs)
const T & elem(int i) const
void setup(const PCompJIT< typename JITType< T >::Type_t > &rhs)
const T & elem(int i) const
const T & elem(int i) const
PTriDiaJIT & operator=(const PTriDiaREG< T1 > &rhs)
const T & elem(int i) const
void setup(const PTriDiaJIT< typename JITType< T >::Type_t > &rhs)
PTriOffJIT & operator=(const PTriOffREG< T1 > &rhs)
const T & elem(int i) const
const T & elem(int i) const
void setup(const PTriOffJIT< typename JITType< T >::Type_t > &rhs)
PCompREG< typename REGType< T >::Type_t > Type_t
PTriDiaREG< typename REGType< T >::Type_t > Type_t
PTriOffREG< typename REGType< T >::Type_t > Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
WordType< T >::Type_t Type_t
multi1d< LatticeColorMatrix > U