6 #ifndef __lwldslash_llvm_h__
7 #define __lwldslash_llvm_h__
14 #include "dslash_sig_0.h"
15 #include "dslash_sig_1.h"
50 template<
typename T,
typename P,
typename Q>
67 const multi1d<Real>& coeffs_);
78 const multi1d<Real>& coeffs_);
166 template<
typename T,
typename P,
typename Q>
168 for (
int i=0;
i<8;
i++)
169 comms[
i].do_comms =
false;
173 template<
typename T,
typename P,
typename Q>
180 template<
typename T,
typename P,
typename Q>
184 create(
state, aniso_);
188 template<
typename T,
typename P,
typename Q>
190 const multi1d<Real>& coeffs_)
192 create(
state, coeffs_);
196 template<
typename T,
typename P,
typename Q>
199 multi1d<Real> cf(
Nd);
205 template<
typename T,
typename P,
typename Q>
218 template<
typename T,
typename P,
typename Q>
221 for (
int i = 0 ;
i < 8 ;
i++ )
222 if (comms[
i].do_comms)
224 int dstnum = comms[
i].snd_sz;
225 int srcnum = comms[
i].rcv_sz;
226 int dstnode = comms[
i].snd_nd;
227 int srcnode = comms[
i].rcv_nd;
229 send_buf_mem[
i] = QMP_allocate_aligned_memory(dstnum,QDP_ALIGNMENT_SIZE, (QMP_MEM_COMMS|QMP_MEM_FAST) );
230 if( send_buf_mem[
i] == 0x0 ) {
231 send_buf_mem[
i] = QMP_allocate_aligned_memory(dstnum, QDP_ALIGNMENT_SIZE, QMP_MEM_COMMS);
232 if( send_buf_mem[
i] == 0x0 ) {
236 recv_buf_mem[
i] = QMP_allocate_aligned_memory(srcnum,QDP_ALIGNMENT_SIZE, (QMP_MEM_COMMS|QMP_MEM_FAST));
237 if( recv_buf_mem[
i] == 0x0 ) {
238 recv_buf_mem[
i] = QMP_allocate_aligned_memory(srcnum, QDP_ALIGNMENT_SIZE, QMP_MEM_COMMS);
239 if( recv_buf_mem[
i] == 0x0 ) {
243 send_buf[
i]=(
double*)QMP_get_memory_pointer(send_buf_mem[
i]);
244 recv_buf[
i]=(
double*)QMP_get_memory_pointer(recv_buf_mem[
i]);
247 msg[
i][0] = QMP_declare_msgmem( recv_buf[
i] , srcnum );
249 if( msg[
i][0] == (QMP_msgmem_t)NULL ) {
250 QDP_error_exit(
"QMP_declare_msgmem for msg[0] failed in Map::operator()\n");
253 msg[
i][1] = QMP_declare_msgmem( send_buf[
i] , dstnum );
255 if( msg[
i][1] == (QMP_msgmem_t)NULL ) {
256 QDP_error_exit(
"QMP_declare_msgmem for msg[1] failed in Map::operator()\n");
259 mh_a[
i][0] = QMP_declare_receive_from(msg[
i][0], srcnode, 0);
260 if( mh_a[
i][0] == (QMP_msghandle_t)NULL ) {
261 QDP_error_exit(
"QMP_declare_receive_from for mh_a[0] failed in Map::operator()\n");
264 mh_a[
i][1] = QMP_declare_send_to(msg[
i][1], dstnode , 0);
265 if( mh_a[
i][1] == (QMP_msghandle_t)NULL ) {
266 QDP_error_exit(
"QMP_declare_send_to for mh_a[1] failed in Map::operator()\n");
269 mh[
i] = QMP_declare_multiple(mh_a[
i], 2);
270 if( mh[
i] == (QMP_msghandle_t)NULL ) {
271 QDP_error_exit(
"QMP_declare_multiple for mh failed in Map::operator()\n");
277 template<
typename T,
typename P,
typename Q>
280 if ((err = QMP_start(mh[
i])) != QMP_SUCCESS)
285 template<
typename T,
typename P,
typename Q>
288 for (
int i = 0 ;
i < 8 ;
i++ )
289 if (comms[
i].do_comms)
290 if ((err = QMP_wait(mh[
i])) != QMP_SUCCESS)
294 QDP_info(
"Map: calling free msgs");
299 template<
typename T,
typename P,
typename Q>
301 for (
int i = 0 ;
i < 8 ;
i++ )
302 if (comms[
i].do_comms)
304 QMP_free_msghandle(mh[
i]);
307 QMP_free_msgmem(msg[
i][1]);
308 QMP_free_msgmem(msg[
i][0]);
310 QMP_free_memory(recv_buf_mem[
i]);
311 QMP_free_memory(send_buf_mem[
i]);
316 template<
typename T,
typename P,
typename Q>
324 for(
int dir = 3 ; dir >= 0 ; --dir ) {
328 const Map& map = shift.getMap(
isign,dir);
332 if (map.hasOffnode())
334 offnode_maps |= map.getId();
336 int dstnum = shift.getMap(
isign,dir).get_destnodes_num()[rb[0].getId()][0]*
sizeof(double)*12;
337 int srcnum = shift.getMap(
isign,dir).get_srcenodes_num()[rb[0].getId()][0]*
sizeof(double)*12;
342 int dstnode = shift.getMap(
isign,dir).get_destnodes()[0];
343 int srcnode = shift.getMap(
isign,dir).get_srcenodes()[0];
348 comms[comm_no].do_comms=
true;
349 comms[comm_no].snd_nd=dstnode;
350 comms[comm_no].snd_sz=dstnum;
351 comms[comm_no].rcv_nd=srcnode;
352 comms[comm_no].rcv_sz=srcnum;
356 comms[comm_no].do_comms=
false;
365 for(
int cb=0 ;
cb<2 ; ++
cb) {
366 if (offnode_maps > 0)
368 innerCount[
cb] = MasterMap::Instance().getCountInner(rb[
cb],offnode_maps);
369 faceCount[
cb] = MasterMap::Instance().getCountFace(rb[
cb],offnode_maps);
370 innerSites[
cb] = MasterMap::Instance().getInnerSites(rb[
cb],offnode_maps).slice();
371 faceSites[
cb] = MasterMap::Instance().getFaceSites(rb[
cb],offnode_maps).slice();
375 innerCount[
cb] = rb[
cb].numSiteTable();
377 innerSites[
cb] = rb[
cb].siteTable().slice();
378 faceSites[
cb] = NULL;
384 comm_thread = omp_get_max_threads() >=4 ? 3 : 0;
392 template<
typename T,
typename P,
typename Q>
394 const multi1d<Real>& coeffs_)
400 fbc =
state->getFermBC();
403 if (fbc.operator->() == 0)
405 QDPIO::cerr <<
"LLVMWilsonDslash: error: fbc is null" << std::endl;
412 for(
int mu=0;
mu <
u.size(); ++
mu) {
417 for(
int mu=0;
mu <
u.size(); ++
mu)
439 template<
typename T,
typename P,
typename Q>
444 double* psi_ptr = (
double*)
psi.getFjit();
445 double* chi_ptr = (
double*)
chi.getFjit();
446 double* u0_ptr = (
double*)
u[0].getFjit();
447 double* u1_ptr = (
double*)
u[1].getFjit();
448 double* u2_ptr = (
double*)
u[2].getFjit();
449 double* u3_ptr = (
double*)
u[3].getFjit();
454 #pragma omp parallel default(shared)
456 int threads_num = omp_get_num_threads();
457 int myId = omp_get_thread_num();
459 if (comms[0].do_comms)
462 const Map& map = shift.getMap(-1,3);
463 const int* soffset_slice = map.soffset(rb[
cb]).slice();
464 int soffset_num = map.soffset(rb[
cb]).size();
466 int low = soffset_num*myId/threads_num;
467 int high = soffset_num*(myId+1)/threads_num;
469 func_gather_M_3_0(low,high,0,
true,0,soffset_slice,send_buf[0],psi_ptr,u3_ptr);
473 if (myId == comm_thread)
474 comms_send_receive(0);
477 const Map& map = shift.getMap(+1,3);
478 const int* soffset_slice = map.soffset(rb[
cb]).slice();
479 int soffset_num = map.soffset(rb[
cb]).size();
481 int low = soffset_num*myId/threads_num;
482 int high = soffset_num*(myId+1)/threads_num;
484 func_gather_P_3_0(low,high,0,
true,0,soffset_slice,send_buf[1],psi_ptr);
488 if (myId == comm_thread)
489 comms_send_receive(1);
492 if (comms[2].do_comms)
495 const Map& map = shift.getMap(-1,2);
496 const int* soffset_slice = map.soffset(rb[
cb]).slice();
497 int soffset_num = map.soffset(rb[
cb]).size();
499 int low = soffset_num*myId/threads_num;
500 int high = soffset_num*(myId+1)/threads_num;
502 func_gather_M_2_0(low,high,0,
true,0,soffset_slice,send_buf[2],psi_ptr,u2_ptr);
506 if (myId == comm_thread)
507 comms_send_receive(2);
510 const Map& map = shift.getMap(+1,2);
511 const int* soffset_slice = map.soffset(rb[
cb]).slice();
512 int soffset_num = map.soffset(rb[
cb]).size();
514 int low = soffset_num*myId/threads_num;
515 int high = soffset_num*(myId+1)/threads_num;
517 func_gather_P_2_0(low,high,0,
true,0,soffset_slice,send_buf[3],psi_ptr);
521 if (myId == comm_thread)
522 comms_send_receive(3);
525 if (comms[4].do_comms)
528 const Map& map = shift.getMap(-1,1);
529 const int* soffset_slice = map.soffset(rb[
cb]).slice();
530 int soffset_num = map.soffset(rb[
cb]).size();
532 int low = soffset_num*myId/threads_num;
533 int high = soffset_num*(myId+1)/threads_num;
535 func_gather_M_1_0(low,high,0,
true,0,soffset_slice,send_buf[4],psi_ptr,u1_ptr);
539 if (myId == comm_thread)
540 comms_send_receive(4);
543 const Map& map = shift.getMap(+1,1);
544 const int* soffset_slice = map.soffset(rb[
cb]).slice();
545 int soffset_num = map.soffset(rb[
cb]).size();
547 int low = soffset_num*myId/threads_num;
548 int high = soffset_num*(myId+1)/threads_num;
550 func_gather_P_1_0(low,high,0,
true,0,soffset_slice,send_buf[5],psi_ptr);
554 if (myId == comm_thread)
555 comms_send_receive(5);
558 if (comms[6].do_comms)
561 const Map& map = shift.getMap(-1,0);
562 const int* soffset_slice = map.soffset(rb[
cb]).slice();
563 int soffset_num = map.soffset(rb[
cb]).size();
565 int low = soffset_num*myId/threads_num;
566 int high = soffset_num*(myId+1)/threads_num;
568 func_gather_M_0_0(low,high,0,
true,0,soffset_slice,send_buf[6],psi_ptr,u0_ptr);
572 if (myId == comm_thread)
573 comms_send_receive(6);
576 const Map& map = shift.getMap(+1,0);
577 const int* soffset_slice = map.soffset(rb[
cb]).slice();
578 int soffset_num = map.soffset(rb[
cb]).size();
580 int low = soffset_num*myId/threads_num;
581 int high = soffset_num*(myId+1)/threads_num;
583 func_gather_P_0_0(low,high,0,
true,0,soffset_slice,send_buf[7],psi_ptr);
587 if (myId == comm_thread)
588 comms_send_receive(7);
592 int low = innerCount[
cb]*myId/threads_num;
593 int high = innerCount[
cb]*(myId+1)/threads_num;
595 func_dslash_____0( low , high , 0 ,
false , 0 , innerSites[
cb] , chi_ptr ,
596 shift.getMap(-1,3).goffset(rb[
cb]).slice(), NULL, psi_ptr, u3_ptr ,
597 shift.getMap(+1,3).goffset(rb[
cb]).slice(), NULL, psi_ptr, u3_ptr ,
598 shift.getMap(-1,2).goffset(rb[
cb]).slice(), NULL, psi_ptr, u2_ptr ,
599 shift.getMap(+1,2).goffset(rb[
cb]).slice(), NULL, psi_ptr, u2_ptr ,
600 shift.getMap(-1,1).goffset(rb[
cb]).slice(), NULL, psi_ptr, u1_ptr ,
601 shift.getMap(+1,1).goffset(rb[
cb]).slice(), NULL, psi_ptr, u1_ptr ,
602 shift.getMap(-1,0).goffset(rb[
cb]).slice(), NULL, psi_ptr, u0_ptr ,
603 shift.getMap(+1,0).goffset(rb[
cb]).slice(), NULL, psi_ptr, u0_ptr );
606 if (myId == comm_thread)
611 int low = faceCount[
cb]*myId/threads_num;
612 int high = faceCount[
cb]*(myId+1)/threads_num;
614 func_dslash_____0( low , high , 0 ,
false , 0 , faceSites[
cb] , chi_ptr ,
615 shift.getMap(-1,3).goffset(rb[
cb]).slice(),recv_buf[0],psi_ptr,u3_ptr,
616 shift.getMap(+1,3).goffset(rb[
cb]).slice(),recv_buf[1],psi_ptr,u3_ptr,
617 shift.getMap(-1,2).goffset(rb[
cb]).slice(),recv_buf[2],psi_ptr,u2_ptr,
618 shift.getMap(+1,2).goffset(rb[
cb]).slice(),recv_buf[3],psi_ptr,u2_ptr,
619 shift.getMap(-1,1).goffset(rb[
cb]).slice(),recv_buf[4],psi_ptr,u1_ptr,
620 shift.getMap(+1,1).goffset(rb[
cb]).slice(),recv_buf[5],psi_ptr,u1_ptr,
621 shift.getMap(-1,0).goffset(rb[
cb]).slice(),recv_buf[6],psi_ptr,u0_ptr,
622 shift.getMap(+1,0).goffset(rb[
cb]).slice(),recv_buf[7],psi_ptr,u0_ptr);
632 #pragma omp parallel default(shared)
634 int threads_num = omp_get_num_threads();
635 int myId = omp_get_thread_num();
637 if (comms[0].do_comms)
640 const Map& map = shift.getMap(-1,3);
641 const int* soffset_slice = map.soffset(rb[
cb]).slice();
642 int soffset_num = map.soffset(rb[
cb]).size();
644 int low = soffset_num*myId/threads_num;
645 int high = soffset_num*(myId+1)/threads_num;
647 func_gather_M_3_1(low,high,0,
true,0,soffset_slice,send_buf[0],psi_ptr,u3_ptr);
651 if (myId == comm_thread)
652 comms_send_receive(0);
655 const Map& map = shift.getMap(+1,3);
656 const int* soffset_slice = map.soffset(rb[
cb]).slice();
657 int soffset_num = map.soffset(rb[
cb]).size();
659 int low = soffset_num*myId/threads_num;
660 int high = soffset_num*(myId+1)/threads_num;
662 func_gather_P_3_1(low,high,0,
true,0,soffset_slice,send_buf[1],psi_ptr);
666 if (myId == comm_thread)
667 comms_send_receive(1);
670 if (comms[2].do_comms)
673 const Map& map = shift.getMap(-1,2);
674 const int* soffset_slice = map.soffset(rb[
cb]).slice();
675 int soffset_num = map.soffset(rb[
cb]).size();
677 int low = soffset_num*myId/threads_num;
678 int high = soffset_num*(myId+1)/threads_num;
680 func_gather_M_2_1(low,high,0,
true,0,soffset_slice,send_buf[2],psi_ptr,u2_ptr);
684 if (myId == comm_thread)
685 comms_send_receive(2);
688 const Map& map = shift.getMap(+1,2);
689 const int* soffset_slice = map.soffset(rb[
cb]).slice();
690 int soffset_num = map.soffset(rb[
cb]).size();
692 int low = soffset_num*myId/threads_num;
693 int high = soffset_num*(myId+1)/threads_num;
695 func_gather_P_2_1(low,high,0,
true,0,soffset_slice,send_buf[3],psi_ptr);
699 if (myId == comm_thread)
700 comms_send_receive(3);
703 if (comms[4].do_comms)
706 const Map& map = shift.getMap(-1,1);
707 const int* soffset_slice = map.soffset(rb[
cb]).slice();
708 int soffset_num = map.soffset(rb[
cb]).size();
710 int low = soffset_num*myId/threads_num;
711 int high = soffset_num*(myId+1)/threads_num;
713 func_gather_M_1_1(low,high,0,
true,0,soffset_slice,send_buf[4],psi_ptr,u1_ptr);
717 if (myId == comm_thread)
718 comms_send_receive(4);
721 const Map& map = shift.getMap(+1,1);
722 const int* soffset_slice = map.soffset(rb[
cb]).slice();
723 int soffset_num = map.soffset(rb[
cb]).size();
725 int low = soffset_num*myId/threads_num;
726 int high = soffset_num*(myId+1)/threads_num;
728 func_gather_P_1_1(low,high,0,
true,0,soffset_slice,send_buf[5],psi_ptr);
732 if (myId == comm_thread)
733 comms_send_receive(5);
736 if (comms[6].do_comms)
739 const Map& map = shift.getMap(-1,0);
740 const int* soffset_slice = map.soffset(rb[
cb]).slice();
741 int soffset_num = map.soffset(rb[
cb]).size();
743 int low = soffset_num*myId/threads_num;
744 int high = soffset_num*(myId+1)/threads_num;
746 func_gather_M_0_1(low,high,0,
true,0,soffset_slice,send_buf[6],psi_ptr,u0_ptr);
750 if (myId == comm_thread)
751 comms_send_receive(6);
754 const Map& map = shift.getMap(+1,0);
755 const int* soffset_slice = map.soffset(rb[
cb]).slice();
756 int soffset_num = map.soffset(rb[
cb]).size();
758 int low = soffset_num*myId/threads_num;
759 int high = soffset_num*(myId+1)/threads_num;
761 func_gather_P_0_1(low,high,0,
true,0,soffset_slice,send_buf[7],psi_ptr);
765 if (myId == comm_thread)
766 comms_send_receive(7);
770 int low = innerCount[
cb]*myId/threads_num;
771 int high = innerCount[
cb]*(myId+1)/threads_num;
773 func_dslash_____1( low , high , 0 ,
false , 0 , innerSites[
cb] , chi_ptr ,
774 shift.getMap(-1,3).goffset(rb[
cb]).slice(), NULL, psi_ptr, u3_ptr ,
775 shift.getMap(+1,3).goffset(rb[
cb]).slice(), NULL, psi_ptr, u3_ptr ,
776 shift.getMap(-1,2).goffset(rb[
cb]).slice(), NULL, psi_ptr, u2_ptr ,
777 shift.getMap(+1,2).goffset(rb[
cb]).slice(), NULL, psi_ptr, u2_ptr ,
778 shift.getMap(-1,1).goffset(rb[
cb]).slice(), NULL, psi_ptr, u1_ptr ,
779 shift.getMap(+1,1).goffset(rb[
cb]).slice(), NULL, psi_ptr, u1_ptr ,
780 shift.getMap(-1,0).goffset(rb[
cb]).slice(), NULL, psi_ptr, u0_ptr ,
781 shift.getMap(+1,0).goffset(rb[
cb]).slice(), NULL, psi_ptr, u0_ptr );
784 if (myId == comm_thread)
789 int low = faceCount[
cb]*myId/threads_num;
790 int high = faceCount[
cb]*(myId+1)/threads_num;
792 func_dslash_____1( low , high , 0 ,
false , 0 , faceSites[
cb] , chi_ptr ,
793 shift.getMap(-1,3).goffset(rb[
cb]).slice(),recv_buf[0],psi_ptr,u3_ptr,
794 shift.getMap(+1,3).goffset(rb[
cb]).slice(),recv_buf[1],psi_ptr,u3_ptr,
795 shift.getMap(-1,2).goffset(rb[
cb]).slice(),recv_buf[2],psi_ptr,u2_ptr,
796 shift.getMap(+1,2).goffset(rb[
cb]).slice(),recv_buf[3],psi_ptr,u2_ptr,
797 shift.getMap(-1,1).goffset(rb[
cb]).slice(),recv_buf[4],psi_ptr,u1_ptr,
798 shift.getMap(+1,1).goffset(rb[
cb]).slice(),recv_buf[5],psi_ptr,u1_ptr,
799 shift.getMap(-1,0).goffset(rb[
cb]).slice(),recv_buf[6],psi_ptr,u0_ptr,
800 shift.getMap(+1,0).goffset(rb[
cb]).slice(),recv_buf[7],psi_ptr,u0_ptr);
826 multi1d<LatticeColorMatrixD>,
Base class for all fermion action boundary conditions.
Support class for fermion actions and linear operators.
Class for counted reference semantics.
General Wilson-Dirac dslash.
const FermBC< T, P, Q > & getFermBC() const
Return the fermion BC object for this linear operator.
void create(Handle< FermState< T, P, Q > > state)
Creation routine.
const multi1d< Real > & getCoeffs() const
Get the anisotropy parameters.
QMP_mem_t * send_buf_mem[8]
Handle< FermBC< T, P, Q > > fbc
QMP_msghandle_t mh_a[8][2]
~LLVMWilsonDslashT()
No real need for cleanup here.
void comms_send_receive(int i) const
const int * innerSites[2]
QMP_mem_t * recv_buf_mem[8]
General Wilson-Dirac dslash.
LLVMWilsonDslashT()
Empty constructor. Must use create later.
void apply(T &chi, const T &psi, enum PlusMinus isign, int cb) const
General Wilson-Dirac dslash.
Wilson Dslash linear operator.
Asqtad Staggered-Dirac operator.
QDP_error_exit("too many BiCG iterations", n_count, rsd_sq, cp, c, re_rvr, im_rvr, re_a, im_a, re_b, im_b)
static multi1d< LatticeColorMatrix > u
LinOpSysSolverMGProtoClover::Q Q
LinOpSysSolverMGProtoClover::T T
LLVMWilsonDslashT< LatticeFermionD, multi1d< LatticeColorMatrixD >, multi1d< LatticeColorMatrixD > > LLVMWilsonDslashD
multi1d< LatticeFermion > chi(Ncb)
const WilsonTypeFermAct< multi1d< LatticeFermion > > Handle< const ConnectState > state
multi1d< Real > makeFermCoeffs(const AnisoParam_t &aniso)
Make fermion coefficients.
Support class for fermion actions and linear operators.
Parameters for anisotropy.