3 #include <color_spinor_field.h>
4 #include <clover_field.h>
5 #include <dslash_quda.h>
6 #include <color_spinor_field_order.h>
7 #include <clover_field_order.h>
8 #include <index_helper.cuh>
9 #include <color_spinor.h>
11 #include <dslash_policy.cuh>
12 #include <instantiate.h>
14 #include <cuda/atomic>
19 // these should not be namespaced!!
20 // determines whether the temporal ghost zones are packed with a gather kernel,
21 // as opposed to multiple memcpys
// File-local (static) flag, initially false.
22 static bool kernelPackT = false;
// Overwrite the current kernelPackT setting with packT.
24 void setKernelPackT(bool packT) { kernelPackT = packT; }
// Return the current kernelPackT setting.
26 bool getKernelPackT() { return kernelPackT; }
// Stack of previously active kernelPackT values, filled by pushKernelPackT().
28 static std::stack<bool> kptstack;
// Save the current kernelPackT setting on kptstack, then switch to packT.
30 void pushKernelPackT(bool packT)
32 kptstack.push(getKernelPackT());
33 setKernelPackT(packT);
// More than 10 saved entries is only warned about, not treated as an error.
35 if (kptstack.size() > 10)
37 warningQuda("KernelPackT stack contains %u elements. Is there a missing popKernelPackT() somewhere?",
38 static_cast<unsigned int>(kptstack.size()));
46 errorQuda("popKernelPackT() called with empty stack");
48 setKernelPackT(kptstack.top());
// Event handles used by the dslash code. All six arrays are created in
// createDslashEvents() and destroyed in destroyDslashEvents(): packEnd and
// dslashStart hold 2 events each, the gather/scatter arrays hold Nstream each.
55 cudaEvent_t packEnd[2];
56 cudaEvent_t gatherStart[Nstream];
57 cudaEvent_t gatherEnd[Nstream];
58 cudaEvent_t scatterStart[Nstream];
59 cudaEvent_t scatterEnd[Nstream];
60 cudaEvent_t dslashStart[2];
62 // for shmem lightweight sync
// Counter starts at 10; createDslashEvents() fills sync_arr with 9, i.e. one below it.
63 shmem_sync_t sync_counter = 10;
// Read the current counter value.
64 shmem_sync_t get_shmem_sync_counter() { return sync_counter; }
// Assign a new counter value and return that new value.
65 shmem_sync_t set_shmem_sync_counter(shmem_sync_t count) { return sync_counter = count; }
// Post-increment: returns the value held before the increment.
66 shmem_sync_t inc_shmem_sync_counter() { return sync_counter++; }
// Workspace pointers, null by default. They are set up in createDslashEvents()
// and released in destroyDslashEvents(); the getters below return the raw pointers.
68 shmem_sync_t *sync_arr = nullptr;
69 shmem_retcount_intra_t *_retcount_intra = nullptr;
70 shmem_retcount_inter_t *_retcount_inter = nullptr;
71 shmem_interior_done_t *_interior_done = nullptr;
72 shmem_interior_count_t *_interior_count = nullptr;
73 shmem_sync_t *get_shmem_sync_arr() { return sync_arr; }
74 shmem_retcount_intra_t *get_shmem_retcount_intra() { return _retcount_intra; }
75 shmem_retcount_inter_t *get_shmem_retcount_inter() { return _retcount_inter; }
76 shmem_interior_done_t *get_shmem_interior_done() { return _interior_done; }
77 shmem_interior_count_t *get_shmem_interior_count() { return _interior_count; }
// The state below is declared without initializers here; createDslashEvents()
// assigns the three compute flags, dslash_policy_init, both first_active_*
// indices and both policy vectors, and appends to policy_string.
80 // these variables are used for benchmarking the dslash components in isolation
81 bool dslash_pack_compute;
82 bool dslash_interior_compute;
83 bool dslash_exterior_compute;
87 // whether the dslash policy tuner has been enabled
88 bool dslash_policy_init;
90 // used to keep track of which policy to start the autotuning
91 int first_active_policy;
92 int first_active_p2p_policy;
94 // list of dslash policies that are enabled
95 std::vector<QudaDslashPolicy> policies;
97 // list of p2p policies that are enabled
98 std::vector<QudaP2PPolicy> p2p_policies;
100 // string used as a tunekey to ensure we retune if the dslash policy env changes
// Fixed-size buffer of TuneKey::aux_n characters.
101 char policy_string[TuneKey::aux_n];
103 // FIX this is a hack from hell
104 // Auxiliary work that can be done while waiting on comms to finish
108 // need to use placement new constructor to initialize the atomic counters
// Constructs `max` objects of type T in place at counter[0], ..., counter[max-1],
// each brace-initialized with 0. The loop is serial and has no thread-index
// guard; createDslashEvents() launches it with a 1x1x1 grid and 1x1x1 block.
109 template <typename T> __global__ void init_dslash_atomic(T *counter, int max)
111 for (int i = 0; i < max; i++) new (counter + i) T {0};
113 // fills arr[0], ..., arr[max-1] with val by plain assignment (no placement new here)
// The loop is serial and has no thread-index guard; createDslashEvents()
// launches it with a 1x1x1 grid and 1x1x1 block.
114 template <typename T> __global__ void init_sync_arr(T *arr, T val, int max)
116 for (int i = 0; i < max; i++) *(arr + i) = val;
// Creates the dslash events, allocates and initializes the shmem
// synchronization workspaces, and resets the benchmarking flags and
// policy-tuning state to their starting values.
// NOTE(review): the return codes of cudaEventCreateWithFlags are not checked
// in the lines shown here.
119 void createDslashEvents()
121 using namespace dslash;
122 // add cudaEventDisableTiming for lower sync overhead
123 for (int i=0; i<Nstream; i++) {
124 cudaEventCreateWithFlags(&gatherStart[i], cudaEventDisableTiming);
125 cudaEventCreateWithFlags(&gatherEnd[i], cudaEventDisableTiming);
126 cudaEventCreateWithFlags(&scatterStart[i], cudaEventDisableTiming);
127 cudaEventCreateWithFlags(&scatterEnd[i], cudaEventDisableTiming);
129 for (int i=0; i<2; i++) {
130 cudaEventCreateWithFlags(&packEnd[i], cudaEventDisableTiming);
131 cudaEventCreateWithFlags(&dslashStart[i], cudaEventDisableTiming);
// sync_arr holds 2 * QUDA_MAX_DIM entries and comes from the comms-pinned allocator.
134 sync_arr = static_cast<shmem_sync_t *>(device_comms_pinned_malloc(2 * QUDA_MAX_DIM * sizeof(shmem_sync_t)));
// Single-thread launch configuration: the init kernels below loop serially.
136 tp.grid = dim3(1, 1, 1);
137 tp.block = dim3(1, 1, 1);
139 /* initialize to 9 here so in cases where we need to do tuning we can skip the wait if necessary
140 by using smaller values */
141 qudaLaunchKernel(init_sync_arr<shmem_sync_t>, tp, 0, sync_arr, static_cast<shmem_sync_t>(9), 2 * QUDA_MAX_DIM);
144 // atomic for controlling signaling in nvshmem packing
// The assignment targets of the next two allocations sit on lines not shown;
// presumably _retcount_intra and _retcount_inter, which the init kernels receive.
146 = static_cast<shmem_retcount_intra_t *>(device_pinned_malloc(2 * QUDA_MAX_DIM * sizeof(shmem_retcount_intra_t)));
147 qudaLaunchKernel(init_dslash_atomic<shmem_retcount_intra_t>, tp, 0, _retcount_intra, 2 * QUDA_MAX_DIM);
149 = static_cast<shmem_retcount_inter_t *>(device_pinned_malloc(2 * QUDA_MAX_DIM * sizeof(shmem_retcount_inter_t)));
150 qudaLaunchKernel(init_dslash_atomic<shmem_retcount_inter_t>, tp, 0, _retcount_inter, 2 * QUDA_MAX_DIM);
151 // workspace for interior done sync in uber kernel
// Single-element workspaces, each zero-initialized by init_dslash_atomic.
152 _interior_done = static_cast<shmem_interior_done_t *>(device_pinned_malloc(sizeof(shmem_interior_done_t)));
153 qudaLaunchKernel(init_dslash_atomic<shmem_interior_done_t>, tp, 0, _interior_done, 1);
154 _interior_count = static_cast<shmem_interior_count_t *>(device_pinned_malloc(sizeof(shmem_interior_count_t)));
155 qudaLaunchKernel(init_dslash_atomic<shmem_interior_count_t>, tp, 0, _interior_count, 1);
// All three dslash components are enabled by default.
162 dslash_pack_compute = true;
163 dslash_interior_compute = true;
164 dslash_exterior_compute = true;
168 dslash_policy_init = false;
169 first_active_policy = 0;
170 first_active_p2p_policy = 0;
172 // list of dslash policies that are enabled
// The vector length is the integer value of the DISABLED enumerator and every
// entry starts out as DISABLED.
173 policies = std::vector<QudaDslashPolicy>(
174 static_cast<int>(QudaDslashPolicy::QUDA_DSLASH_POLICY_DISABLED), QudaDslashPolicy::QUDA_DSLASH_POLICY_DISABLED);
176 // list of p2p policies that are enabled
177 p2p_policies = std::vector<QudaP2PPolicy>(
178 static_cast<int>(QudaP2PPolicy::QUDA_P2P_POLICY_DISABLED), QudaP2PPolicy::QUDA_P2P_POLICY_DISABLED);
// NOTE(review): strcat appends to whatever policy_string already holds. Confirm
// the buffer is reset before this point (not visible here); otherwise repeated
// calls would accumulate ",pol=" prefixes.
180 strcat(policy_string, ",pol=");
// Releases what createDslashEvents() set up: destroys every event and frees
// the five synchronization workspaces with the allocator-matching free call.
184 void destroyDslashEvents()
186 using namespace dslash;
188 for (int i=0; i<Nstream; i++) {
189 cudaEventDestroy(gatherStart[i]);
190 cudaEventDestroy(gatherEnd[i]);
191 cudaEventDestroy(scatterStart[i]);
192 cudaEventDestroy(scatterEnd[i]);
195 for (int i=0; i<2; i++) {
196 cudaEventDestroy(packEnd[i]);
197 cudaEventDestroy(dslashStart[i]);
// sync_arr came from device_comms_pinned_malloc, the rest from device_pinned_malloc.
200 device_comms_pinned_free(sync_arr);
201 device_pinned_free(_retcount_intra);
202 device_pinned_free(_retcount_inter);
203 device_pinned_free(_interior_done);
204 device_pinned_free(_interior_count);
210 @brief Parameter structure for driving the Gamma operator
212 template <typename Float, int nColor>
// Field accessor type (4 spin components) and the type used for the coefficients.
214 typedef typename colorspinor_mapper<Float,4,nColor>::type F;
215 typedef typename mapper<Float>::type RegType;
217 F out; // output vector field
218 const F in; // input vector field
219 const int d; // which gamma matrix are we applying
220 const int nParity; // number of parities we're working on
221 bool doublet; // whether we applying the operator to a doublet
222 const int volumeCB; // checkerboarded volume
223 RegType a; // scale factor
224 RegType b; // chiral twist
225 RegType c; // flavor twist
// Validates the fields and derives the coefficients a, b, c from kappa, mu,
// epsilon, twist and dagger. doublet is true for either doublet twist flavor,
// in which case volumeCB is half of in.VolumeCB(). a, b and c start at 0.0 and
// are only changed inside the twist branches below.
227 GammaArg(ColorSpinorField &out, const ColorSpinorField &in, int d,
228 RegType kappa=0.0, RegType mu=0.0, RegType epsilon=0.0,
229 bool dagger=false, QudaTwistGamma5Type twist=QUDA_TWIST_GAMMA5_INVALID)
230 : out(out), in(in), d(d), nParity(in.SiteSubset()),
231 doublet(in.TwistFlavor() == QUDA_TWIST_DEG_DOUBLET || in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET),
232 volumeCB(doublet ? in.VolumeCB()/2 : in.VolumeCB()), a(0.0), b(0.0), c(0.0)
// Precision and location of out and in must agree; d must lie in [0, 4];
// only nSpin = 4 fields in native order are accepted.
234 checkPrecision(out, in);
235 checkLocation(out, in);
236 if (d < 0 || d > 4) errorQuda("Undefined gamma matrix %d", d);
237 if (in.Nspin() != 4) errorQuda("Cannot apply gamma5 to nSpin=%d field", in.Nspin());
238 if (!in.isNative() || !out.isNative()) errorQuda("Unsupported field order out=%d in=%d\n", out.FieldOrder(), in.FieldOrder());
// Singlet: b = +/- 2*kappa*mu; the inverse also sets a = 1 / (1 + b^2).
240 if (in.TwistFlavor() == QUDA_TWIST_SINGLET) {
241 if (twist == QUDA_TWIST_GAMMA5_DIRECT) {
242 b = 2.0 * kappa * mu;
244 } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) {
245 b = -2.0 * kappa * mu;
246 a = 1.0 / (1.0 + b * b);
// dagger flips the sign of the chiral twist.
249 if (dagger) b *= -1.0;
// Doublet: additionally c = -/+ 2*kappa*epsilon; the inverse sets
// a = 1 / (1 + b^2 - c^2) and rejects parameters that make a non-positive.
250 } else if (doublet) {
251 if (twist == QUDA_TWIST_GAMMA5_DIRECT) {
252 b = 2.0 * kappa * mu;
253 c = -2.0 * kappa * epsilon;
255 } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) {
256 b = -2.0 * kappa * mu;
257 c = 2.0 * kappa * epsilon;
258 a = 1.0 / (1.0 + b * b - c * c);
259 if (a <= 0) errorQuda("Invalid twisted mass parameters (kappa=%e, mu=%e, epsilon=%e)\n", kappa, mu, epsilon);
261 if (dagger) b *= -1.0;
266 // CPU kernel for applying the gamma matrix to a colorspinor
// Serial host loop over every parity and every checkerboard site; the gamma
// index is read from arg.d at run time. arg is taken by value.
267 template <typename Float, int nColor, typename Arg>
268 void gammaCPU(Arg arg)
270 typedef typename mapper<Float>::type RegType;
271 for (int parity= 0; parity < arg.nParity; parity++) {
273 for (int x_cb = 0; x_cb < arg.volumeCB; x_cb++) { // 4-d volume
// Load the site, apply gamma(arg.d), store to the same site of out.
274 ColorSpinor<RegType,nColor,4> in = arg.in(x_cb, parity);
275 arg.out(x_cb, parity) = in.gamma(arg.d);
281 // GPU Kernel for applying the gamma matrix to a colorspinor
// One thread per (site, parity): the x launch dimension indexes the
// checkerboard site and the y dimension the parity. The gamma index d is a
// compile-time template parameter here, unlike gammaCPU.
282 template <typename Float, int nColor, int d, typename Arg>
283 __global__ void gammaGPU(Arg arg)
285 typedef typename mapper<Float>::type RegType;
286 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
287 int parity = blockDim.y*blockIdx.y + threadIdx.y;
// Threads beyond the site or parity range do nothing.
289 if (x_cb >= arg.volumeCB) return;
290 if (parity >= arg.nParity) return;
292 ColorSpinor<RegType,nColor,4> in = arg.in(x_cb, parity);
293 arg.out(x_cb, parity) = in.gamma(d);
// Tunable launcher for the gamma kernels. Constructing an instance runs the
// operator straight away (the constructor calls apply()).
296 template <typename Float, int nColor>
297 class Gamma : public TunableVectorY {
299 GammaArg<Float, nColor> arg;
300 const ColorSpinorField &meta;
// Reported as zero flops; traffic is counted as one read of in plus one write of out.
302 long long flops() const { return 0; }
303 long long bytes() const { return arg.out.Bytes() + arg.in.Bytes(); }
304 bool tuneGridDim() const { return false; }
305 unsigned int minThreads() const { return arg.volumeCB; }
// The y vector length handed to TunableVectorY is in.SiteSubset().
308 Gamma(ColorSpinorField &out, const ColorSpinorField &in, int d) :
309 TunableVectorY(in.SiteSubset()),
313 strcpy(aux, meta.AuxString());
// Runs on the last entry of the streams array.
315 apply(streams[Nstream-1]);
// Host fields use the serial CPU loop; otherwise the tuned GPU kernel is
// launched. In the lines shown only d == 4 has a GPU instantiation; other
// values reach the errorQuda default.
318 void apply(const qudaStream_t &stream) {
319 if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
320 gammaCPU<Float,nColor>(arg);
322 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
324 case 4: qudaLaunchKernel(gammaGPU<Float,nColor,4,decltype(arg)>, tp, stream, arg); break;
325 default: errorQuda("%d not instantiated", arg.d);
330 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// out is saved before tuning and restored afterwards, unconditionally.
332 void preTune() { arg.out.save(); }
333 void postTune() { arg.out.load(); }
336 //Apply the Gamma matrix to a colorspinor field
337 //out(x) = gamma_d*in
// Forwards to instantiate<Gamma>(out, in, d); GammaArg rejects d outside [0, 4].
338 void ApplyGamma(ColorSpinorField &out, const ColorSpinorField &in, int d)
340 instantiate<Gamma>(out, in, d);
343 // CPU kernel for applying the twisted gamma operator to a colorspinor
// Serial host loop over parities and checkerboard sites. Two forms appear
// below: a single-field form, out = a * (in + b * in.igamma(d)), and a
// two-flavor form that also mixes the flavors through c. The branch choosing
// between them is on lines not shown (presumably the `doublet` template flag).
344 template <bool doublet, typename Float, int nColor, typename Arg>
345 void twistGammaCPU(Arg arg)
347 typedef typename mapper<Float>::type RegType;
348 for (int parity= 0; parity < arg.nParity; parity++) {
349 for (int x_cb = 0; x_cb < arg.volumeCB; x_cb++) { // 4-d volume
351 ColorSpinor<RegType,nColor,4> in = arg.in(x_cb, parity);
352 arg.out(x_cb, parity) = arg.a * (in + arg.b * in.igamma(arg.d));
// The two flavors are stored volumeCB sites apart; the b term changes sign
// between them and each flavor picks up c times the other.
354 ColorSpinor<RegType,nColor,4> in_1 = arg.in(x_cb+0*arg.volumeCB, parity);
355 ColorSpinor<RegType,nColor,4> in_2 = arg.in(x_cb+1*arg.volumeCB, parity);
356 arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.igamma(arg.d) + arg.c * in_2);
357 arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.igamma(arg.d) + arg.c * in_1);
364 // GPU Kernel for applying the twisted gamma operator to a colorspinor
// One thread per (site, parity), with the same arithmetic as twistGammaCPU but
// with the gamma index d fixed at compile time.
365 template <bool doublet, typename Float, int nColor, int d, typename Arg>
366 __global__ void twistGammaGPU(Arg arg)
368 typedef typename mapper<Float>::type RegType;
369 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
370 int parity = blockDim.y*blockIdx.y + threadIdx.y;
// NOTE(review): only the site index is range-checked in the lines shown;
// gammaGPU also checks parity against arg.nParity. Confirm whether that guard
// is needed here.
371 if (x_cb >= arg.volumeCB) return;
374 ColorSpinor<RegType,nColor,4> in = arg.in(x_cb, parity);
375 arg.out(x_cb, parity) = arg.a * (in + arg.b * in.igamma(d));
// Two-flavor form: flavors are volumeCB sites apart and are mixed through c.
377 ColorSpinor<RegType,nColor,4> in_1 = arg.in(x_cb+0*arg.volumeCB, parity);
378 ColorSpinor<RegType,nColor,4> in_2 = arg.in(x_cb+1*arg.volumeCB, parity);
379 arg.out(x_cb + 0 * arg.volumeCB, parity) = arg.a * (in_1 + arg.b * in_1.igamma(d) + arg.c * in_2);
380 arg.out(x_cb + 1 * arg.volumeCB, parity) = arg.a * (in_2 - arg.b * in_2.igamma(d) + arg.c * in_1);
// Tunable launcher for the twisted gamma kernels. Constructing an instance
// runs the operator straight away (the constructor calls apply()).
384 template <typename Float, int nColor>
385 class TwistGamma : public TunableVectorY {
387 GammaArg<Float, nColor> arg;
388 const ColorSpinorField &meta;
// Reported as zero flops; traffic is counted as one read of in plus one write of out.
390 long long flops() const { return 0; }
391 long long bytes() const { return arg.out.Bytes() + arg.in.Bytes(); }
392 bool tuneGridDim() const { return false; }
393 unsigned int minThreads() const { return arg.volumeCB; }
// kappa, mu, epsilon, dagger and type are forwarded to GammaArg, which turns
// them into the coefficients a, b, c used by the kernels.
396 TwistGamma(ColorSpinorField &out, const ColorSpinorField &in, int d, double kappa, double mu, double epsilon, int dagger, QudaTwistGamma5Type type) :
397 TunableVectorY(in.SiteSubset()),
398 arg(out, in, d, kappa, mu, epsilon, dagger, type),
401 strcpy(aux, meta.AuxString());
// Runs on the last entry of the streams array.
403 apply(streams[Nstream-1]);
// Dispatches on field location and on arg.doublet. In the lines shown only
// d == 4 has a GPU instantiation; other values reach the errorQuda default.
406 void apply(const qudaStream_t &stream) {
407 if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
// Exactly one of the two CPU variants must run. Without the `else`, a doublet
// field would be processed by the doublet kernel and then processed again by
// the non-doublet kernel.
408 if (arg.doublet) twistGammaCPU<true,Float,nColor>(arg);
409 else twistGammaCPU<false,Float,nColor>(arg);
411 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
414 case 4: qudaLaunchKernel(twistGammaGPU<true,Float,nColor,4,decltype(arg)>, tp, stream, arg); break;
415 default: errorQuda("%d not instantiated", arg.d);
419 case 4: qudaLaunchKernel(twistGammaGPU<false,Float,nColor,4,decltype(arg)>, tp, stream, arg); break;
420 default: errorQuda("%d not instantiated", arg.d);
425 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
// out is only saved/restored around tuning when it aliases in.
426 void preTune() { if (arg.out.field == arg.in.field) arg.out.save(); }
427 void postTune() { if (arg.out.field == arg.in.field) arg.out.load(); }
430 //Apply the twisted gamma operator to a colorspinor field
431 //out(x) = a * (in + b * in.igamma(d)) in the single-field form (see twistGammaCPU / twistGammaGPU)
// Only available when GPU_TWISTED_MASS_DIRAC is defined; the errorQuda call
// below is presumably the fallback branch (the separating line is not shown).
432 void ApplyTwistGamma(ColorSpinorField &out, const ColorSpinorField &in, int d, double kappa, double mu, double epsilon, int dagger, QudaTwistGamma5Type type)
434 #ifdef GPU_TWISTED_MASS_DIRAC
435 instantiate<TwistGamma>(out, in, d, kappa, mu, epsilon, dagger, type);
437 errorQuda("Twisted mass dslash has not been built");
438 #endif // GPU_TWISTED_MASS_DIRAC
441 // Applies a gamma5 matrix to a spinor (wrapper to ApplyGamma)
// Equivalent to ApplyGamma(out, in, 4): index 4 is the gamma5 slot.
442 void gamma5(ColorSpinorField &out, const ColorSpinorField &in) { ApplyGamma(out,in,4); }
445 @brief Parameter structure for driving the clover and twist-clover application kernels
446 @tparam Float Underlying storage precision
447 @tparam nSpin Number of spin components
448 @tparam nColor Number of colors
449 @tparam dynamic_clover Whether we are inverting the clover field on the fly
451 template <typename Float, int nSpin, int nColor>
// Length parameter passed to clover_mapper; evaluates to 72 for nSpin = 4, nColor = 3.
453 static constexpr int length = (nSpin / (nSpin/2)) * 2 * nColor * nColor * (nSpin/2) * (nSpin/2) / 2;
// Compile-time constant taken from dynamic_clover_inverse(), not a template parameter.
454 static constexpr bool dynamic_clover = dynamic_clover_inverse();
456 typedef typename colorspinor_mapper<Float,nSpin,nColor>::type F;
457 typedef typename clover_mapper<Float,length>::type C;
458 typedef typename mapper<Float>::type RegType;
460 F out; // output vector field
461 const F in; // input vector field
462 const C clover; // clover field
463 const C cloverInv; // inverse clover field (only set if not dynamic clover and doing twisted clover)
464 const int nParity; // number of parities we're working on
465 const int parity; // which parity we're acting on (if nParity=1)
466 bool inverse; // whether we are applying the inverse
467 bool doublet; // whether we applying the operator to a doublet
468 const int volumeCB; // checkerboarded volume
472 QudaTwistGamma5Type twist;
// Builds the accessors and derives the twist coefficients.
//  - clover wraps the inverse field only for the untwisted operator
//    (twist == QUDA_TWIST_GAMMA5_INVALID) with inverse == true.
//  - cloverInv wraps the inverse field only for a twisted operator when
//    dynamic_clover is false.
// NOTE(review): the initializer list names clover/cloverInv before in, while
// the declaration order above is out, in, clover, cloverInv. Members are
// initialized in declaration order regardless, so this only affects -Wreorder.
474 CloverArg(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover,
475 bool inverse, int parity, RegType kappa=0.0, RegType mu=0.0, RegType epsilon=0.0,
476 bool dagger = false, QudaTwistGamma5Type twist=QUDA_TWIST_GAMMA5_INVALID)
477 : out(out), clover(clover, twist == QUDA_TWIST_GAMMA5_INVALID ? inverse : false),
478 cloverInv(clover, (twist != QUDA_TWIST_GAMMA5_INVALID && !dynamic_clover) ? true : false),
479 in(in), nParity(in.SiteSubset()), parity(parity), inverse(inverse),
480 doublet(in.TwistFlavor() == QUDA_TWIST_DEG_DOUBLET || in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET),
481 volumeCB(doublet ? in.VolumeCB()/2 : in.VolumeCB()), a(0.0), b(0.0), c(0.0), twist(twist)
483 checkPrecision(out, in, clover);
484 checkLocation(out, in, clover);
// Singlet: a = +/- 2*kappa*mu; the inverse also sets b = 1 / (1 + a^2);
// dagger flips the sign of a. Doublet fields are rejected outright.
485 if (in.TwistFlavor() == QUDA_TWIST_SINGLET) {
486 if (twist == QUDA_TWIST_GAMMA5_DIRECT) {
487 a = 2.0 * kappa * mu;
489 } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) {
490 a = -2.0 * kappa * mu;
491 b = 1.0 / (1.0 + a*a);
494 if (dagger) a *= -1.0;
495 } else if (doublet) {
496 errorQuda("ERROR: Non-degenerated twisted-mass not supported in this regularization\n");
// Per-site clover application, callable from host and device. The site spinor
// is moved to the chiral basis, each of the two chiral halves is processed
// with its own clover block, and the summed result is moved back and stored.
501 template <typename Float, int nSpin, int nColor, typename Arg>
502 __device__ __host__ inline void cloverApply(Arg &arg, int x_cb, int parity) {
503 using namespace linalg; // for Cholesky
504 typedef typename mapper<Float>::type RegType;
505 typedef ColorSpinor<RegType, nColor, nSpin> Spinor;
506 typedef ColorSpinor<RegType, nColor, nSpin / 2> HalfSpinor;
// Single-parity spinor fields are always indexed with parity 0; the clover
// field is still indexed with the real parity.
507 int spinor_parity = arg.nParity == 2 ? parity : 0;
508 Spinor in = arg.in(x_cb, spinor_parity);
511 in.toRel(); // change to chiral basis here
514 for (int chirality=0; chirality<2; chirality++) {
516 HMatrix<RegType,nColor*nSpin/2> A = arg.clover(x_cb, parity, chirality);
517 HalfSpinor chi = in.chiral_project(chirality);
// dynamic_clover path: go through a Cholesky factorization of A
// (forward then backward pass) and scale the result by 0.25.
519 if (arg.dynamic_clover) {
520 Cholesky<HMatrix, RegType, nColor * nSpin / 2> cholesky(A);
521 chi = static_cast<RegType>(0.25) * cholesky.backward(cholesky.forward(chi));
526 out += chi.chiral_reconstruct(chirality);
529 out.toNonRel(); // change basis back
531 arg.out(x_cb, spinor_parity) = out;
// Host driver: calls cloverApply for every site of every parity being processed.
534 template <typename Float, int nSpin, int nColor, typename Arg>
535 void cloverCPU(Arg &arg) {
536 for (int parity=0; parity<arg.nParity; parity++) {
// The loop variable itself is overwritten: for a single-parity field it becomes
// arg.parity, after which the increment makes the loop condition false.
537 parity = (arg.nParity == 2) ? parity : arg.parity;
538 for (int x_cb=0; x_cb<arg.volumeCB; x_cb++) cloverApply<Float,nSpin,nColor>(arg, x_cb, parity);
// Device driver: one thread per site. Parity comes from the y launch
// dimension for two-parity fields and from arg.parity otherwise.
542 template <typename Float, int nSpin, int nColor, typename Arg>
543 __global__ void cloverGPU(Arg arg) {
544 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
545 int parity = (arg.nParity == 2) ? blockDim.y*blockIdx.y + threadIdx.y : arg.parity;
// Threads beyond the last site do nothing.
546 if (x_cb >= arg.volumeCB) return;
547 cloverApply<Float,nSpin,nColor>(arg, x_cb, parity);
// Tunable launcher for the clover kernels (nSpin fixed to 4). Constructing an
// instance runs the operator straight away (the constructor calls apply()).
550 template <typename Float, int nColor>
551 class Clover : public TunableVectorY {
553 static constexpr int nSpin = 4;
554 CloverArg<Float, nSpin, nColor> arg;
555 const ColorSpinorField &meta;
// Cost model: 504 flops per site; bytes add the clover term per site to the spinor traffic.
557 long long flops() const { return arg.nParity*arg.volumeCB*504ll; }
558 long long bytes() const { return arg.out.Bytes() + arg.in.Bytes() + arg.nParity*arg.volumeCB*arg.clover.Bytes(); }
559 bool tuneGridDim() const { return false; }
560 unsigned int minThreads() const { return arg.volumeCB; }
// Rejects fields that are not nSpin = 4 and rejects inverse == false.
563 Clover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover, bool inverse, int parity) :
564 TunableVectorY(in.SiteSubset()),
565 arg(out, in, clover, inverse, parity),
568 if (in.Nspin() != 4 || out.Nspin() != 4) errorQuda("Unsupported nSpin=%d %d", out.Nspin(), in.Nspin());
569 if (!inverse) errorQuda("Unsupported direct application");
570 strcpy(aux, meta.AuxString());
// Runs on the last entry of the streams array.
572 apply(streams[Nstream-1]);
// tuneLaunch is called before the location check here, so it also runs for
// host fields.
575 void apply(const qudaStream_t &stream)
577 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
578 if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
579 cloverCPU<Float,nSpin,nColor>(arg);
581 qudaLaunchKernel(cloverGPU<Float,nSpin,nColor,decltype(arg)>, tp, stream, arg);
585 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
586 void preTune() { if (arg.out.field == arg.in.field) arg.out.save(); } // Need to save the out field if it aliases the in field
587 void postTune() { if (arg.out.field == arg.in.field) arg.out.load(); } // Restore if the in and out fields alias
590 //Apply the clover matrix field to a colorspinor field
// Only available when GPU_CLOVER_DIRAC is defined; the errorQuda call below is
// presumably the fallback branch (the separating line is not shown).
592 void ApplyClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover, bool inverse, int parity)
594 #ifdef GPU_CLOVER_DIRAC
595 instantiate<Clover>(out, in, clover, inverse, parity);
597 errorQuda("Clover dslash has not been built");
598 #endif // GPU_CLOVER_DIRAC
601 // if (!inverse) apply (Clover + i*a*gamma_5) to the input spinor
602 // else apply (Clover + i*a*gamma_5)/(Clover^2 + a^2) to the input spinor
// Per-site routine, callable from host and device. It works in the chiral
// basis, one chiral half at a time, and writes the summed result to out.
603 template <bool inverse, typename Float, int nSpin, int nColor, typename Arg>
604 __device__ __host__ inline void twistCloverApply(Arg &arg, int x_cb, int parity) {
605 using namespace linalg; // for Cholesky
606 constexpr int N = nColor*nSpin/2;
607 typedef typename mapper<Float>::type RegType;
608 typedef ColorSpinor<RegType,nColor,nSpin> Spinor;
609 typedef ColorSpinor<RegType,nColor,nSpin/2> HalfSpinor;
610 typedef HMatrix<RegType,N> Mat;
// Single-parity spinor fields are always indexed with parity 0; the clover
// fields are still indexed with the real parity.
611 int spinor_parity = arg.nParity == 2 ? parity : 0;
612 Spinor in = arg.in(x_cb, spinor_parity);
615 in.toRel(); // change to chiral basis here
618 for (int chirality=0; chirality<2; chirality++) {
619 // factor of 2 comes from clover normalization we need to correct for
// j is +0.5i for chirality 0 and -0.5i for chirality 1.
620 const complex<RegType> j(0.0, chirality == 0 ? static_cast<RegType>(0.5) : -static_cast<RegType>(0.5));
622 Mat A = arg.clover(x_cb, parity, chirality);
624 HalfSpinor in_chi = in.chiral_project(chirality);
625 HalfSpinor out_chi = A*in_chi + j*arg.a*in_chi;
// Two ways of applying the inverse follow (presumably guarded by the `inverse`
// template flag on a line not shown): with dynamic_clover a Cholesky
// factorization of A2 shifted by a^2 / 4 is used (A2 is defined on a line not
// shown); otherwise the stored inverse cloverInv is multiplied in, scaled by 2.
628 if (arg.dynamic_clover) {
630 A2 += arg.a*arg.a*static_cast<RegType>(0.25);
631 Cholesky<HMatrix,RegType,N> cholesky(A2);
632 out_chi = static_cast<RegType>(0.25)*cholesky.backward(cholesky.forward(out_chi));
634 Mat Ainv = arg.cloverInv(x_cb, parity, chirality);
635 out_chi = static_cast<RegType>(2.0)*(Ainv*out_chi);
639 out += (out_chi).chiral_reconstruct(chirality);
642 out.toNonRel(); // change basis back
644 arg.out(x_cb, spinor_parity) = out;
// Host driver: calls twistCloverApply for every site of every parity being processed.
647 template <bool inverse, typename Float, int nSpin, int nColor, typename Arg>
648 void twistCloverCPU(Arg &arg) {
649 for (int parity=0; parity<arg.nParity; parity++) {
// The loop variable itself is overwritten: for a single-parity field it becomes
// arg.parity, after which the increment makes the loop condition false.
650 parity = (arg.nParity == 2) ? parity : arg.parity;
651 for (int x_cb=0; x_cb<arg.volumeCB; x_cb++) twistCloverApply<inverse,Float,nSpin,nColor>(arg, x_cb, parity);
// Device driver: one thread per site. Parity comes from the y launch
// dimension for two-parity fields and from arg.parity otherwise.
655 template <bool inverse, typename Float, int nSpin, int nColor, typename Arg>
656 __global__ void twistCloverGPU(Arg arg) {
657 int x_cb = blockIdx.x*blockDim.x + threadIdx.x;
658 int parity = (arg.nParity == 2) ? blockDim.y*blockIdx.y + threadIdx.y : arg.parity;
// Threads beyond the last site do nothing.
659 if (x_cb >= arg.volumeCB) return;
660 twistCloverApply<inverse,Float,nSpin,nColor>(arg, x_cb, parity);
// Tunable launcher for the twisted-clover kernels (nSpin fixed to 4).
// Constructing an instance runs the operator straight away (the constructor
// calls apply()).
663 template <typename Float, int nColor>
664 class TwistClover : public TunableVectorY {
666 static constexpr int nSpin = 4;
667 CloverArg<Float,nSpin,nColor> arg;
668 const ColorSpinorField &meta;
// Cost model: 1056 flops per site for the inverse, 552 for the direct operator.
670 long long flops() const { return (arg.inverse ? 1056ll : 552ll) * arg.nParity*arg.volumeCB; }
// The stored inverse clover field is only counted when it is actually read:
// inverse twist without dynamic clover.
671 long long bytes() const {
672 long long rtn = arg.out.Bytes() + arg.in.Bytes() + arg.nParity*arg.volumeCB*arg.clover.Bytes();
673 if (arg.twist == QUDA_TWIST_GAMMA5_INVERSE && !arg.dynamic_clover)
674 rtn += arg.nParity*arg.volumeCB*arg.cloverInv.Bytes();
677 bool tuneGridDim() const { return false; }
678 unsigned int minThreads() const { return arg.volumeCB; }
// Every twist value other than QUDA_TWIST_GAMMA5_DIRECT is treated as an
// inverse application. The choice is also appended to aux, so the direct and
// inverse variants get distinct tune keys.
681 TwistClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover,
682 double kappa, double mu, double epsilon, int parity, int dagger, QudaTwistGamma5Type twist) :
683 TunableVectorY(in.SiteSubset()),
684 arg(out, in, clover, twist != QUDA_TWIST_GAMMA5_DIRECT, parity, kappa, mu, epsilon, dagger, twist),
687 if (in.Nspin() != 4 || out.Nspin() != 4) errorQuda("Unsupported nSpin=%d %d", out.Nspin(), in.Nspin());
688 strcpy(aux, meta.AuxString());
689 strcat(aux, arg.inverse ? ",inverse" : ",direct");
// Runs on the last entry of the streams array.
691 apply(streams[Nstream-1]);
// tuneLaunch is called before the location check, so it also runs for host
// fields. Both locations dispatch on arg.inverse.
694 void apply(const qudaStream_t &stream)
696 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
697 if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
698 if (arg.inverse) twistCloverCPU<true,Float,nSpin,nColor>(arg);
699 else twistCloverCPU<false,Float,nSpin,nColor>(arg);
701 if (arg.inverse) qudaLaunchKernel(twistCloverGPU<true,Float,nSpin,nColor,decltype(arg)>, tp, stream, arg);
702 else qudaLaunchKernel(twistCloverGPU<false,Float,nSpin,nColor,decltype(arg)>, tp, stream, arg);
706 TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }
707 void preTune() { if (arg.out.field == arg.in.field) arg.out.save(); } // Need to save the out field if it aliases the in field
708 void postTune() { if (arg.out.field == arg.in.field) arg.out.load(); } // Restore if the in and out fields alias
711 //Apply the twisted-clover matrix field to a colorspinor field
// Only available when GPU_CLOVER_DIRAC is defined; the errorQuda call below is
// presumably the fallback branch (the separating line is not shown).
712 void ApplyTwistClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover,
713 double kappa, double mu, double epsilon, int parity, int dagger, QudaTwistGamma5Type twist)
715 #ifdef GPU_CLOVER_DIRAC
716 instantiate<TwistClover>(out, in, clover, kappa, mu, epsilon, parity, dagger, twist);
718 errorQuda("Clover dslash has not been built");
719 #endif // GPU_CLOVER_DIRAC