4 #ifdef QDPJIT_IS_QDPJITPTX
// Host-side launcher for the kernel handle `function`.
// Visible steps: walk Q, QQ and components 0..2 of f, b1 and b2 with an
// AddressLeaf, build a vector of argument pointers (lo, hi, the dobs flag, then
// every entry of addr_leaf.addr) and hand it to jit_launch.
// NOTE(review): this excerpt does not show the end of the parameter list, the
// function braces, the return statement, or the declarations of `lo` and `dobs`.
9 CUfunction function_get_fs_bs_exec(CUfunction
function,
10 const LatticeColorMatrix&
Q,
11 const LatticeColorMatrix& QQ,
12 multi1d<LatticeComplex>& f,
13 multi1d<LatticeComplex>& b1,
14 multi1d<LatticeComplex>& b2,
17 AddressLeaf addr_leaf;
// The fields are visited in the same order as the ParamLeaf walks in
// function_get_fs_bs_build (Q, QQ, f[0..2], b1[0..2], b2[0..2]).
// Presumably the two orders must agree so arguments line up with kernel
// parameters -- TODO confirm. The junk_N results are not read in the visible lines.
19 int junk_0 = forEach(
Q, addr_leaf, NullCombine());
20 int junk_1 = forEach(QQ, addr_leaf, NullCombine());
21 int junk_2 = forEach(f[0], addr_leaf, NullCombine());
22 int junk_3 = forEach(f[1], addr_leaf, NullCombine());
23 int junk_4 = forEach(f[2], addr_leaf, NullCombine());
24 int junk_5 = forEach(b1[0], addr_leaf, NullCombine());
25 int junk_6 = forEach(b1[1], addr_leaf, NullCombine());
26 int junk_7 = forEach(b1[2], addr_leaf, NullCombine());
27 int junk_8 = forEach(b2[0], addr_leaf, NullCombine());
28 int junk_9 = forEach(b2[1], addr_leaf, NullCombine());
29 int junk_10= forEach(b2[2], addr_leaf, NullCombine());
// Upper bound of the site range handled on this node.
34 int hi = Layout::sitesOnNode();
// NOTE(review): the name says u8 but the type is unsigned short, and the build
// side declares this parameter as jit_ptx_type::pred. Confirm the width that
// jit_launch / the generated kernel expects for a predicate argument.
35 unsigned short dobs_u8 = dobs ? 1 : 0;
// Argument list: pointers to lo, hi and the flag first, in the same order the
// build function adds its three leading parameters.
37 std::vector<void*> addr;
39 addr.push_back( &lo );
42 addr.push_back( &hi );
45 addr.push_back( &dobs_u8 );
// Index of the first field argument; not used again in the visible lines.
47 int addr_dest=addr.size();
// Append a pointer to each entry collected by the AddressLeaf walks.
// NOTE(review): `int i` is compared against addr_leaf.addr.size(); this is a
// signed/unsigned comparison if size() returns an unsigned type.
48 for(
int i=0;
i < addr_leaf.addr.size(); ++
i) {
49 addr.push_back( &addr_leaf.addr[
i] );
// Launch with hi-lo as the second argument and the pointer list as the third.
53 jit_launch(
function,hi-lo,addr);
// Converts a host-side double into a WordREG<REAL> by constructing one from f.
// Used throughout the kernel builder to write numeric literals in JIT expressions.
// NOTE(review): the function braces are not visible in this excerpt.
58 WordREG<REAL> jit_constant(
double f )
60 return WordREG<REAL>(f);
// Builds the kernel that exec launches and returns the handle obtained from
// jit_get_cufunction("ptx_get_fs_bs.ptx").
//
// Per site it derives c0 = realTrace(QQ*Q)/3 and c1 = realTrace(QQ)/2, then
// writes f[0..2] and -- unless the negated dobs predicate is set -- b1[0..2]
// and b2[0..2]. Two code paths exist: a polynomial path for c1 < 4.0e-3 and a
// trigonometric path otherwise.
//
// NOTE(review): this excerpt omits lines of the original (end of the parameter
// list, braces, declarations of cont_0..cont_8, `theta`, the xi0 assignments
// and others). Names that appear to be redeclared (subexp, subexp1, ...)
// presumably live in separate brace scopes that are not shown -- TODO confirm.
64 CUfunction function_get_fs_bs_build(
const LatticeColorMatrix&
Q,
65 const LatticeColorMatrix& QQ,
66 multi1d<LatticeComplex>& f,
67 multi1d<LatticeComplex>& b1,
68 multi1d<LatticeComplex>& b2,
75 jit_start_new_function();
// Leading kernel parameters, in the order exec pushes them: lo, hi, dobs.
// r_lo is not referenced again in the visible lines.
77 jit_value r_lo = jit_add_param( jit_ptx_type::s32 );
78 jit_value r_hi = jit_add_param( jit_ptx_type::s32 );
79 jit_value r_dobs = jit_add_param( jit_ptx_type::pred );
// Negated flag: used below as the predicate for skipping the b1/b2 work.
80 jit_value r_nobs = jit_ins_not( r_dobs );
// Linear thread index; threads with index >= hi exit immediately.
82 jit_value r_idx = jit_geom_get_linear_th_idx();
84 jit_value r_out_of_range = jit_ins_ge( r_idx , r_hi );
85 jit_ins_exit( r_out_of_range );
// Walk each field with a ParamLeaf keyed on the thread index to obtain its
// JIT-side counterpart (same field order as the AddressLeaf walks in exec).
87 ParamLeaf param_leaf( r_idx );
89 typedef typename LeafFunctor<LatticeColorMatrix, ParamLeaf>::Type_t LCMJIT;
90 typedef typename LeafFunctor<LatticeComplex , ParamLeaf>::Type_t LCJIT;
92 LCMJIT Q_jit(forEach(
Q, param_leaf, TreeCombine()));
93 LCMJIT QQ_jit(forEach(QQ, param_leaf, TreeCombine()));
94 LCJIT f0_jit(forEach(f[0], param_leaf, TreeCombine()));
95 LCJIT f1_jit(forEach(f[1], param_leaf, TreeCombine()));
96 LCJIT f2_jit(forEach(f[2], param_leaf, TreeCombine()));
97 LCJIT b10_jit(forEach(b1[0], param_leaf, TreeCombine()));
98 LCJIT b11_jit(forEach(b1[1], param_leaf, TreeCombine()));
99 LCJIT b12_jit(forEach(b1[2], param_leaf, TreeCombine()));
100 LCJIT b20_jit(forEach(b2[0], param_leaf, TreeCombine()));
101 LCJIT b21_jit(forEach(b2[1], param_leaf, TreeCombine()));
102 LCJIT b22_jit(forEach(b2[2], param_leaf, TreeCombine()));
// Element accessors for every field, all requested with the Coalesced layout.
104 auto& Q_j = Q_jit.elem(JitDeviceLayout::Coalesced);
105 auto& QQ_j = QQ_jit.elem(JitDeviceLayout::Coalesced);
107 auto& f0_j = f0_jit.elem(JitDeviceLayout::Coalesced);
108 auto& f1_j = f1_jit.elem(JitDeviceLayout::Coalesced);
109 auto& f2_j = f2_jit.elem(JitDeviceLayout::Coalesced);
111 auto& b10_j = b10_jit.elem(JitDeviceLayout::Coalesced);
112 auto& b11_j = b11_jit.elem(JitDeviceLayout::Coalesced);
113 auto& b12_j = b12_jit.elem(JitDeviceLayout::Coalesced);
114 auto& b20_j = b20_jit.elem(JitDeviceLayout::Coalesced);
115 auto& b21_j = b21_jit.elem(JitDeviceLayout::Coalesced);
116 auto& b22_j = b22_jit.elem(JitDeviceLayout::Coalesced);
// Register copies of the Q and QQ colour matrices, and their product QQ*Q.
121 PColorMatrixREG< RComplexREG< WordREG<REAL> >, Nc> Q_site = Q_j.elem();
122 PColorMatrixREG< RComplexREG< WordREG<REAL> >, Nc> QQ_site = QQ_j.elem();
123 PColorMatrixREG< RComplexREG< WordREG<REAL> >, Nc> QQQ = QQ_site*Q_site;
// c0 = realTrace(QQ*Q)/3 and c1 = realTrace(QQ)/2.
133 PScalarREG< RScalarREG< WordREG<REAL> > > trQQQ = realTrace(QQQ);
134 PScalarREG< RScalarREG< WordREG<REAL> > > trQQ = realTrace(QQ_site);
136 WordREG<REAL> c0 = jit_constant((REAL)1/(REAL)3) * trQQQ.elem().elem();
137 WordREG<REAL> c1 = jit_constant((REAL)1/(REAL)2) * trQQ.elem().elem();
// Jump to the trigonometric path when c1 >= 4.0e-3; the fall-through code
// below therefore handles c1 < 4.0e-3 with polynomial expressions in c0, c1.
139 jit_label_t not_c1_lt;
140 jit_label_t label_exit;
141 jit_ins_branch( not_c1_lt , jit_ins_ge( c1.get_val() , jit_value( 4.0e-3 ) ) );
// Small-c1 path: f0, f1, f2 are written straight to the output fields.
144 f0_j.elem().elem().real() = jit_constant(1.0) - c0 * c0 / jit_constant(720.0);
145 f0_j.elem().elem().imag() = -( c0 / jit_constant(6.0) )*( jit_constant(1.0) -(c1/jit_constant(20.0))*(jit_constant(1.0)-(c1/jit_constant(42.0)))) ;
147 f1_j.elem().elem().real() = c0/jit_constant(24.0)*(jit_constant(1.0)-c1/jit_constant(15.0)*(jit_constant(1.0)-jit_constant(3.0)*c1/jit_constant(112.0))) ;
148 f1_j.elem().elem().imag() = jit_constant(1.0)-c1/jit_constant(6.0)*(jit_constant(1.0)-c1/jit_constant(20.0)*(jit_constant(1.0)-c1/jit_constant(42.0)))-c0*c0/jit_constant(5040.0);
150 f2_j.elem().elem().real() = jit_constant(0.5)*(jit_constant(-1.0)+c1/jit_constant(12.0)*(jit_constant(1.0)-c1/jit_constant(30.0)*(jit_constant(1.0)-c1/jit_constant(56.0)))+c0*c0/jit_constant(20160.0));
151 f2_j.elem().elem().imag() = jit_constant(0.5)*(c0/jit_constant(60.0)*(jit_constant(1.0)-c1/jit_constant(21.0)*(jit_constant(1.0)-c1/jit_constant(48.0))));
// Skip the b1/b2 stores on this path when the negated dobs predicate is set.
// (The declaration of cont_0 is not visible in this excerpt.)
155 jit_ins_branch( cont_0 , r_nobs );
159 b20_j.elem().elem().real() = -c0/jit_constant(360.0);
160 b20_j.elem().elem().imag() = -jit_constant(1.0/6.0)*(jit_constant(1.0)-(c1/jit_constant(20.0))*(jit_constant(1.0)-c1/jit_constant(42.0)));
164 b10_j.elem().elem().real() = jit_constant(0);
165 b10_j.elem().elem().imag() = (c0/jit_constant(120.0))*(jit_constant(1.0)-c1/jit_constant(21.0));
169 b21_j.elem().elem().real() = jit_constant(1.0/24.0)*(jit_constant(1.0)-c1/jit_constant(15.0)*(jit_constant(1.0)-jit_constant(3.0)*c1/jit_constant(112.0)));
170 b21_j.elem().elem().imag() = -c0/jit_constant(2520.0);
174 b11_j.elem().elem().real() = -c0/jit_constant(360.0)*(jit_constant(1.0) - jit_constant(3.0)*c1/jit_constant(56.0) );
175 b11_j.elem().elem().imag() = -jit_constant(1.0/6.0)*(jit_constant(1.0)-c1/jit_constant(10.0)*(jit_constant(1.0)-c1/jit_constant(28.0)));
178 b22_j.elem().elem().real() = jit_constant(0.5)*c0/jit_constant(10080.0);
179 b22_j.elem().elem().imag() = jit_constant(0.5)*( jit_constant(1.0/60.0)*(jit_constant(1.0)-c1/jit_constant(21.0)*(jit_constant(1.0)-c1/jit_constant(48.0))) );
182 b12_j.elem().elem().real() = jit_constant(0.5)*( jit_constant(1.0/12.0)*(jit_constant(1.0)-(jit_constant(2.0)*c1/jit_constant(30.0))*(jit_constant(1.0)-jit_constant(3.0)*c1/jit_constant(112.0))) );
183 b12_j.elem().elem().imag() = jit_constant(0.5)*( -c0/jit_constant(1260.0)*(jit_constant(1.0)-c1/jit_constant(24.0)) );
// Small-c1 path is finished: jump past the trigonometric path to the exit label.
185 jit_ins_label( cont_0 );
186 jit_ins_branch( label_exit );
// ---- Trigonometric path (c1 >= 4.0e-3) ----
188 jit_ins_label( not_c1_lt );
// Remember the sign of c0, then continue with |c0|; the sign is applied to
// selected components further down (see the cont_7 / cont_8 sections).
191 jit_value c0_negativeP = jit_ins_lt( c0.get_val() , jit_value(0.0) );
192 WordREG<REAL> c0abs = fabs(c0);
// c0max = 2*(c1/3)^1.5 and eps = (c0max - |c0|)/c0max.
193 WordREG<REAL> c0max = jit_constant(2.0) * pow( c1 / jit_constant(3.0) , jit_constant(1.5) );
195 WordREG<REAL>
eps = (c0max - c0abs)/c0max;
// theta is chosen by three cases on eps (the declaration of theta is not
// visible in this excerpt):
//   eps <  0        -> theta = 0
//   0 <= eps < 1e-3 -> a series in eps (only partly visible below)
//   eps >= 1e-3     -> theta = acos(|c0|/c0max)
199 jit_label_t label_theta_exit;
200 jit_ins_branch( cont_1 , jit_ins_ge(
eps.get_val() , jit_value( 0.0 ) ) );
207 theta = jit_constant(0.0);
210 jit_ins_branch( label_theta_exit );
211 jit_ins_label( cont_1 );
213 jit_ins_branch( cont_2 , jit_ins_ge(
eps.get_val() , jit_value( 1.0e-3 ) ) );
// Nested series in eps; the start of the expression and its closing factors
// are not part of this excerpt.
223 WordREG<REAL> sqtwo = sqrt( jit_constant(2.0) );
228 ( jit_constant(1.0) +
229 ( jit_constant(1/(REAL)12) +
230 ( jit_constant(3/(REAL)160) +
231 ( jit_constant(5/(REAL)896) +
232 ( jit_constant(35/(REAL)18432) +
233 jit_constant(63/(REAL)90112) *
eps ) *
239 jit_ins_branch( label_theta_exit );
241 jit_ins_label( cont_2 );
244 theta = acos( c0abs/c0max );
247 jit_ins_label( label_theta_exit );
// Scratch storage for the real and imaginary parts of f, b1 and b2 (3 entries each).
249 multi1d<WordREG<REAL> > f_site_re(3);
250 multi1d<WordREG<REAL> > f_site_im(3);
252 multi1d<WordREG<REAL> > b1_site_re(3);
253 multi1d<WordREG<REAL> > b1_site_im(3);
255 multi1d<WordREG<REAL> > b2_site_re(3);
256 multi1d<WordREG<REAL> > b2_site_im(3);
// u = sqrt(c1/3)*cos(theta/3), w = sqrt(c1)*sin(theta/3), and their squares.
260 WordREG<REAL>
u = sqrt(c1/jit_constant(3.0))*cos(theta/jit_constant(3.0));
261 WordREG<REAL> w = sqrt(c1)*sin(theta/jit_constant(3.0));
263 WordREG<REAL> u_sq =
u*
u;
264 WordREG<REAL> w_sq = w*w;
266 WordREG<REAL> xi0,xi1;
// xi0: the predicate is true when |w| >= 0.05 (i.e. w is NOT small) and sends
// control to label_90. The fall-through (|w| < 0.05) uses a polynomial in
// w_sq. The assignment targets for both branches are not visible in this
// excerpt -- presumably xi0; TODO confirm against the full file.
270 jit_label_t label_90;
272 jit_value Nw_smallP = jit_ins_ge( (fabs( w )).get_val() , jit_value( 0.05 ) );
273 jit_ins_branch( label_90 , Nw_smallP );
278 (jit_constant(1.0/6.0)*w_sq*( jit_constant(1.0) -
279 (jit_constant(1.0/20.)*w_sq*( jit_constant(1.0) -
280 (jit_constant(1.0/42.0)*w_sq ) ))));
282 jit_ins_branch( cont_4 );
284 jit_ins_label( label_90 );
288 jit_ins_label( cont_4 );
// xi1 is skipped entirely when the negated dobs predicate is set. Otherwise
// the same |w| >= 0.05 predicate picks between a polynomial in w_sq
// (fall-through) and the closed form at label_91.
291 jit_ins_branch( cont_3 , r_nobs );
293 jit_label_t label_91;
295 jit_ins_branch( label_91 , Nw_smallP );
299 ( jit_constant((REAL)1/(REAL)3) -
300 jit_constant((REAL)1/(REAL)30)*w_sq*( jit_constant((REAL)1) -
301 jit_constant((REAL)1/(REAL)28)*w_sq*( jit_constant((REAL)1) -
302 jit_constant((REAL)1/(REAL)54)*w_sq ) ) );
303 jit_ins_branch( cont_5 );
306 jit_ins_label( label_91 );
307 xi1 = cos(w)/w_sq - sin(w)/(w_sq*w);
309 jit_ins_label( cont_5 );
310 jit_ins_label( cont_3 );
// Trigonometric terms shared by the f and b formulas below.
315 WordREG<REAL> cosu = cos(
u);
316 WordREG<REAL> sinu = sin(
u);
317 WordREG<REAL> cosw = cos(w);
318 WordREG<REAL> sinw = sin(w);
319 WordREG<REAL> sin2u = sin(jit_constant(2.0)*
u);
320 WordREG<REAL> cos2u = cos(jit_constant(2.0)*
u);
321 WordREG<REAL> ucosu =
u*cosu;
322 WordREG<REAL> usinu =
u*sinu;
323 WordREG<REAL> ucos2u =
u*cos2u;
324 WordREG<REAL> usin2u =
u*sin2u;
// Common denominator for f_site: 9*u^2 - w^2.
326 WordREG<REAL> denum = jit_constant(9.0) * u_sq - w_sq;
// f_site[0]
330 WordREG<REAL> subexp1 = u_sq - w_sq;
331 WordREG<REAL> subexp2 = jit_constant(8.0)*u_sq*cosw;
332 WordREG<REAL> subexp3 = (jit_constant(3.0)*u_sq + w_sq)*xi0;
334 f_site_re[0] = ( (subexp1)*cos2u + cosu*subexp2 + jit_constant(2.0)*usinu*subexp3 ) / denum ;
335 f_site_im[0] = ( (subexp1)*sin2u - sinu*subexp2 + jit_constant(2.0)*ucosu*subexp3 ) / denum ;
// f_site[1]
339 WordREG<REAL> subexp = (jit_constant(3.0)*u_sq -w_sq)*xi0;
341 f_site_re[1] = (jit_constant(2.0)*(ucos2u - ucosu*cosw)+subexp*sinu)/denum;
342 f_site_im[1] = (jit_constant(2.0)*(usin2u + usinu*cosw)+subexp*cosu)/denum;
// f_site[2]
347 WordREG<REAL> subexp=jit_constant(3.0)*xi0;
349 f_site_re[2] = (cos2u - cosu*cosw -usinu*subexp) /denum ;
350 f_site_im[2] = (sin2u + sinu*cosw -ucosu*subexp) /denum ;
// Everything from here to the cont_6 label (r_1, r_2, b1, b2 and their stores)
// is skipped when the negated dobs predicate is set.
354 jit_ins_branch( cont_6 , r_nobs );
// Intermediate quantities r_1[0..2] and r_2[0..2], real and imaginary parts.
357 multi1d<WordREG<REAL> > r_1_re(3);
358 multi1d<WordREG<REAL> > r_1_im(3);
359 multi1d<WordREG<REAL> > r_2_re(3);
360 multi1d<WordREG<REAL> > r_2_im(3);
// r_1[0]
366 WordREG<REAL> subexp1 = u_sq - w_sq;
367 WordREG<REAL> subexp2 = jit_constant(8.0)*cosw + (jit_constant(3.0)*u_sq + w_sq)*xi0 ;
368 WordREG<REAL> subexp3 = jit_constant(4.0)*u_sq*cosw - (jit_constant(9.0)*u_sq + w_sq)*xi0 ;
370 r_1_re[0] = jit_constant(2.0)*(ucos2u - sin2u *(subexp1)+ucosu*( subexp2 )- sinu*( subexp3 ) );
371 r_1_im[0] = jit_constant(2.0)*(usin2u + cos2u *(subexp1)-usinu*( subexp2 )- cosu*( subexp3 ) );
// r_1[1]
376 WordREG<REAL> subexp1 = cosw + jit_constant(3.0) * xi0;
377 WordREG<REAL> subexp2 = jit_constant(2.0)*cosw + xi0*(w_sq - jit_constant(3.0)*u_sq);
379 r_1_re[1] = jit_constant(2.0)*((cos2u - jit_constant(2.0)*usin2u) + usinu*( subexp1 )) - cosu*( subexp2 );
380 r_1_im[1] = jit_constant(2.0)*((sin2u + jit_constant(2.0)*ucos2u) + ucosu*( subexp1 )) + sinu*( subexp2 );
// r_1[2]
386 WordREG<REAL> subexp = cosw - jit_constant(3.0)*xi0;
387 r_1_re[2] = -jit_constant(2.0)*sin2u -jit_constant(3.0)*ucosu*xi0 + sinu*( subexp );
388 r_1_im[2] = jit_constant(2.0)*cos2u +jit_constant(3.0)*usinu*xi0 + cosu*( subexp );
// r_2[0]
395 WordREG<REAL> subexp = cosw + xi0 + jit_constant(3.0)*u_sq*xi1;
396 r_2_re[0] = -jit_constant(2.0)*(cos2u +
u*( jit_constant(4.0)*ucosu*xi0 - sinu*(subexp )) );
397 r_2_im[0] = -jit_constant(2.0)*(sin2u -
u*( jit_constant(4.0)*usinu*xi0 + cosu*(subexp )) );
// r_2[1]
404 WordREG<REAL> subexp = cosw + xi0 - jit_constant(3.0)*u_sq*xi1;
405 r_2_re[1] = jit_constant(2.0)*ucosu*xi0 - sinu*( subexp ) ;
406 r_2_im[1] = jit_constant(-2.0)*usinu*xi0 - cosu*( subexp ) ;
// r_2[2]
411 WordREG<REAL> subexp = jit_constant(3.0)*xi1;
413 r_2_re[2] = cosu*xi0 - usinu*subexp ;
414 r_2_im[2] = -( sinu*xi0 + ucosu*subexp ) ;
// Common denominator for b1 and b2: 2*denum^2.
417 WordREG<REAL> b_denum=jit_constant(2.0)*denum*denum;
// For j = 0..2, combine r_1[j], r_2[j] and f_site[j] into b1_site[j] and b2_site[j].
420 for(
int j=0;
j < 3;
j++) {
// b1[j] = ( 2u*r_1[j] + (3u^2 - w^2)*r_2[j] - 2(15u^2 + w^2)*f[j] ) / b_denum
423 WordREG<REAL> subexp1 = jit_constant(2.0)*
u;
424 WordREG<REAL> subexp2 = jit_constant(3.0)*u_sq - w_sq;
425 WordREG<REAL> subexp3 = jit_constant(2.0)*(jit_constant(15.0)*u_sq + w_sq);
427 b1_site_re[
j]=( subexp1*r_1_re[
j] + subexp2*r_2_re[
j] - subexp3*f_site_re[
j] )/b_denum;
428 b1_site_im[
j]=( subexp1*r_1_im[
j] + subexp2*r_2_im[
j] - subexp3*f_site_im[
j] )/b_denum;
// b2[j] = ( r_1[j] - 3u*r_2[j] - 24u*f[j] ) / b_denum
432 WordREG<REAL> subexp1 = jit_constant(3.0)*
u;
433 WordREG<REAL> subexp2 = jit_constant(24.0)*
u;
435 b2_site_re[
j]=( r_1_re[
j]- subexp1*r_2_re[
j] - subexp2 * f_site_re[
j] )/b_denum;
436 b2_site_im[
j]=( r_1_im[
j] -subexp1*r_2_im[
j] - subexp2 * f_site_im[
j] )/b_denum;
// Sign fix-up for negative c0: the branch skips this section when c0 is not
// negative; otherwise the listed b1/b2 components are multiplied by -1.
442 jit_ins_branch( cont_7 , jit_ins_not(c0_negativeP) );
446 b1_site_im[0] *= jit_constant(-1.0);
449 b1_site_re[1] *= jit_constant(-1.0);
452 b1_site_im[2] *= jit_constant(-1.0);
455 b2_site_re[0] *= jit_constant(-1.0);
458 b2_site_im[1] *= jit_constant(-1.0);
461 b2_site_re[2] *= jit_constant(-1.0);
463 jit_ins_label( cont_7 );
// Store b1 and b2 into the output fields.
468 b10_j.elem().elem().real() = b1_site_re[0];
469 b10_j.elem().elem().imag() = b1_site_im[0];
471 b20_j.elem().elem().real() = b2_site_re[0];
472 b20_j.elem().elem().imag() = b2_site_im[0];
474 b11_j.elem().elem().real() = b1_site_re[1];
475 b11_j.elem().elem().imag() = b1_site_im[1];
477 b21_j.elem().elem().real() = b2_site_re[1];
478 b21_j.elem().elem().imag() = b2_site_im[1];
480 b12_j.elem().elem().real() = b1_site_re[2];
481 b12_j.elem().elem().imag() = b1_site_im[2];
483 b22_j.elem().elem().real() = b2_site_re[2];
484 b22_j.elem().elem().imag() = b2_site_im[2];
// Target of the "no b's requested" branch above.
489 jit_ins_label( cont_6 );
// Same sign fix-up for f. It comes after the b computation, which reads
// f_site before these sign flips are applied.
495 jit_ins_branch( cont_8 , jit_ins_not(c0_negativeP) );
499 f_site_im[0] *= jit_constant(-1.0);
502 f_site_re[1] *= jit_constant(-1.0);
505 f_site_im[2] *= jit_constant(-1.0);
508 jit_ins_label( cont_8 );
// Store f into the output fields.
512 f0_j.elem().elem().real() = f_site_re[0];
513 f0_j.elem().elem().imag() = f_site_im[0];
515 f1_j.elem().elem().real() = f_site_re[1];
516 f1_j.elem().elem().imag() = f_site_im[1];
518 f2_j.elem().elem().real() = f_site_re[2];
519 f2_j.elem().elem().imag() = f_site_im[2];
// Join point for the small-c1 path; then finalize and return the kernel handle.
521 jit_ins_label(label_exit);
524 return jit_get_cufunction(
"ptx_get_fs_bs.ptx");
static multi1d< LatticeColorMatrix > u
multi1d< LatticeColorMatrix > Q