1 #include <quda_internal.h>
2 #include <quda_matrix.h>
4 #include <clover_field.h>
5 #include <gauge_field.h>
6 #include <gauge_field_order.h>
7 #include <clover_field_order.h>
11 #ifdef GPU_CLOVER_DIRAC
// Argument bundle consumed by cloverSigmaTraceKernel (device) and
// cloverSigmaTrace (host).  clover1 is the clover accessor read for sites of
// parity 0 and clover2 the one read for parity 1; gauge is the output accessor
// that receives six colour matrices per site; coeff is a scalar coefficient
// supplied by the caller.
// NOTE(review): the member declarations are not visible in this view.  Because
// the kernel takes this struct by value, the accessors are presumably stored by
// value (copied from the constructor's reference arguments) -- confirm.
13 template<typename Float, typename Clover1, typename Clover2, typename Gauge>
14 struct CloverTraceArg {
20 CloverTraceArg(Clover1 &clover1, Clover2 &clover2, Gauge &gauge, Float coeff)
21 : clover1(clover1), clover2(clover2), gauge(gauge), coeff(coeff) {}
// Per-site worker shared by the CUDA kernel and the host loop.
// For site x of the given parity it loads the clover term into A (from
// arg.clover1 when parity == 0, otherwise from arg.clover2).  Then, for each of
// the six (mu, nu) pairs with 0 <= nu < mu < 4, it builds a 3x3 complex colour
// matrix from the clover entries and writes it to slot (mu-1)*mu/2 + nu of
// arg.gauge at (x, parity):
//   (1,0)->0, (2,0)->1, (2,1)->2, (3,0)->3, (3,1)->4, (3,2)->5.
// NOTE(review): several declarations and branch headers of this function are
// not visible in this view (A, diag, ctmp, jk, the mu/nu selection).  The
// comments below describe only the visible statements.
25 template <typename Float, typename Arg>
26 __device__ __host__ void cloverSigmaTraceCompute(Arg & arg, const int x, int parity)
// Pick the clover accessor by parity: clover1 for parity 0, clover2 otherwise.
30 if (parity==0) arg.clover1.load(A,x,parity);
31 else arg.clover2.load(A,x,parity);
33 // Loop over the six (mu, nu) planes with nu < mu; each yields one output matrix.
34 for (int mu=0; mu<4; mu++) {
35 for (int nu=0; nu<mu; nu++) {
37 Matrix<complex<Float>,3> mat;
// tri[ch] holds the 15 off-diagonal entries of 6x6 block ch.  Entry (r, c) with
// r > c sits at r*(r-1)/2 + c.  This matches the (j+3)*(j+2)/2 + k offsets
// used below for rows 3..5.
41 complex<Float> tri[2][15];
// idtab[i] is the tri[] slot of the i-th packed off-diagonal entry.  Its values
// correspond to a column-by-column packing (1,0),(2,0),(3,0),(4,0),(5,0),(2,1),...,(5,4).
42 const int idtab[15]={0,1,3,6,10,2,4,7,11,5,8,12,9,13,14};
// Unpack both 36-real blocks of A (ch = 0, 1).  Each block holds 6 real
// diagonal entries followed by 15 complex off-diagonal entries stored as
// (re, im) pairs.
// NOTE(review): 2.0 is a double literal, so for Float = float these products are
// evaluated in double precision; Float(2.0) would avoid the promotion.
45 for (int ch=0; ch<2; ++ch) {
46 // factor of two is inherent to QUDA clover storage
47 for (int i=0; i<6; i++) diag[ch][i] = 2.0*A[ch*36+i];
48 for (int i=0; i<15; i++) tri[ch][idtab[i]] = complex<Float>(2.0*A[ch*36+6+2*i], 2.0*A[ch*36+6+2*i+1]);
// Imaginary parts of the diagonal.  Each is the (3..5) diagonal entries of both
// blocks minus the (0..2) entries.
54 for (int j=0; j<3; ++j) {
55 mat(j,j).y = diag[0][j+3] + diag[1][j+3] - diag[0][j] - diag[1][j];
// Off-diagonal part.  It is the lower-right 3x3 sub-block (index jk2, rows and
// columns 3..5) minus the upper-left one (index jk).  Both (j,k) and (k,j) are
// filled: mat(j,k) = i*ctmp, mat(k,j) = i*conj(ctmp).
60 for (int j=1; j<3; ++j) {
61 int jk2 = (j+3)*(j+2)/2 + 3;
62 for (int k=0; k<j; ++k) {
63 ctmp = tri[0][jk2] + tri[1][jk2] - tri[0][jk] - tri[1][jk];
65 mat(j,k).x = -ctmp.imag();
66 mat(j,k).y = ctmp.real();
67 mat(k,j).x = ctmp.imag();
68 mat(k,j).y = ctmp.real();
// Full 3x3 from the lower-left sub-block (rows 3..5, columns 0..2).  kj indexes
// entry (k+3, j).  jk starts at entry (j+3, 0); its per-k increment is not
// visible here.
76 for (int j=0; j<3; ++j) {
77 int jk = (j+3)*(j+2)/2;
78 for (int k=0; k<3; ++k) {
79 int kj = (k+3)*(k+2)/2 + j;
80 mat(j,k) = conj(tri[0][kj]) - tri[0][jk] + conj(tri[1][kj]) - tri[1][jk];
// Same lower-left sub-block, with a different sign pattern between blocks.
// mat(j,k) = i*ctmp.
86 for (int j=0; j<3; ++j) {
87 int jk = (j+3)*(j+2)/2;
88 for (int k=0; k<3; ++k) {
89 int kj = (k+3)*(k+2)/2 + j;
90 ctmp = conj(tri[0][kj]) + tri[0][jk] - conj(tri[1][kj]) - tri[1][jk];
91 mat(j,k).x = -ctmp.imag();
92 mat(j,k).y = ctmp.real();
// Lower-left sub-block with all terms added.  mat(j,k) = -i*ctmp.
99 for (int j=0; j<3; ++j) {
100 int jk = (j+3)*(j+2)/2;
101 for (int k=0; k<3; ++k) {
102 int kj = (k+3)*(k+2)/2 + j;
103 ctmp = conj(tri[0][kj]) + tri[0][jk] + conj(tri[1][kj]) + tri[1][jk];
104 mat(j,k).x = ctmp.imag();
105 mat(j,k).y = -ctmp.real();
109 } else if (mu == 3){ // Y T
110 for (int j=0; j<3; ++j) {
111 int jk = (j+3)*(j+2)/2;
112 for (int k=0; k<3; ++k) {
113 int kj = (k+3)*(k+2)/2 + j;
114 mat(j,k) = conj(tri[0][kj]) - tri[0][jk] - conj(tri[1][kj]) + tri[1][jk];
// Diagonal and off-diagonal parts with the block-0 / block-1 signs flipped
// relative to the first case above.
122 for (int j=0; j<3; ++j) {
123 mat(j,j).y = diag[0][j] - diag[0][j+3] - diag[1][j] + diag[1][j+3];
126 for (int j=1; j<3; ++j) {
127 int jk2 = (j+3)*(j+2)/2 + 3;
128 for (int k=0; k<j; ++k) {
129 ctmp = tri[0][jk] - tri[0][jk2] - tri[1][jk] + tri[1][jk2];
130 mat(j,k).x = -ctmp.imag();
131 mat(j,k).y = ctmp.real();
133 mat(k,j).x = ctmp.imag();
134 mat(k,j).y = ctmp.real();
// Store into output slot (mu-1)*mu/2 + nu (0..5) at this site and parity.
// NOTE(review): arg.coeff is not applied anywhere in the visible lines; confirm
// that the elided lines scale mat by it before this store.
142 arg.gauge((mu-1)*mu/2 + nu, x, parity) = mat;
/**
   @brief Host fallback for the clover sigma trace.  Visits every
   checkerboard site 0 .. arg.clover1.volumeCB-1 serially and evaluates
   cloverSigmaTraceCompute at parity 1 only, the same parity the device
   kernel uses.
   @param arg Argument bundle holding the clover and output accessors
*/
template<typename Float, typename Arg>
void cloverSigmaTrace(Arg &arg)
{
  const int parity = 1; // only parity 1 is computed
  for (int site = 0; site < arg.clover1.volumeCB; ++site) {
    cloverSigmaTraceCompute<Float,Arg>(arg, site, parity);
  }
}
/**
   @brief Device entry point for the clover sigma trace.  Uses a 1-d launch
   with one thread per checkerboard site.  Threads whose flat index is at or
   beyond arg.clover1.volumeCB exit without work.  Only parity 1 is evaluated.
   @param arg Argument bundle, passed by value to the device
*/
template<typename Float, typename Arg>
__global__ void cloverSigmaTraceKernel(Arg arg)
{
  const int site = threadIdx.x + blockDim.x * blockIdx.x;
  if (site < arg.clover1.volumeCB) {
    cloverSigmaTraceCompute<Float,Arg>(arg, site, 1);
  }
}
/**
   @brief Tunable launcher for the clover sigma trace.
   apply() launches cloverSigmaTraceKernel through the autotuner when the
   output field lives on the GPU.  Otherwise it runs the serial host loop
   cloverSigmaTrace.  One thread is launched per checkerboard site
   (minThreads()); grid dimensions and shared memory are not tuned.
*/
template<typename Float, typename Arg>
class CloverSigmaTrace : Tunable {
  Arg &arg;               // argument bundle; must outlive this launcher
  const GaugeField &meta; // output field: supplies location and volume for the tune key

private:
  unsigned int sharedBytesPerThread() const { return 0; }
  // Parameter was corrupted to "¶m" (an &para; entity); restored to &param.
  unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }

  bool tuneSharedBytes() const { return false; } // Don't tune the shared memory
  bool tuneGridDim() const { return false; } // Don't tune the grid dimensions.
  unsigned int minThreads() const { return arg.clover1.volumeCB; }

public:
  CloverSigmaTrace(Arg &arg, const GaugeField &meta)
    : arg(arg), meta(meta) {
    writeAuxString("stride=%d", arg.clover1.stride);
  }
  virtual ~CloverSigmaTrace() {;}

  /**
     @brief Run the trace on the field's location.
     @param stream Stream used for the device launch (ignored on the host path)
  */
  void apply(const qudaStream_t &stream){
    if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      qudaLaunchKernel(cloverSigmaTraceKernel<Float,Arg>, tp, stream, arg);
    } else {
      // host fallback
      cloverSigmaTrace<Float,Arg>(arg);
    }
  }

  TuneKey tuneKey() const { return TuneKey(meta.VolString(), typeid(*this).name(), aux); }

  long long flops() const { return 0; } // Fix this
  long long bytes() const { return (arg.clover1.Bytes() + 6*arg.gauge.Bytes()) * arg.clover1.volumeCB; }

}; // CloverSigmaTrace
/**
   @brief Wrap the accessors in a CloverTraceArg and run the tunable launcher
   on stream 0.
   @param clover1 Clover accessor read for parity-0 sites
   @param clover2 Clover accessor read for parity-1 sites
   @param gauge   Output accessor receiving the six matrices per site
   @param meta    Output field, used for location and tuning metadata
   @param coeff   Scalar coefficient stored in the argument bundle
*/
template<typename Float, typename Clover1, typename Clover2, typename Gauge>
void computeCloverSigmaTrace(Clover1 clover1, Clover2 clover2, Gauge gauge,
                             const GaugeField &meta, Float coeff)
{
  CloverTraceArg<Float, Clover1, Clover2, Gauge> bundle(clover1, clover2, gauge, coeff);
  CloverSigmaTrace<Float, CloverTraceArg<Float, Clover1, Clover2, Gauge> > launcher(bundle, meta);
  launcher.apply(0);
}
/**
   @brief Select clover and gauge accessors matching the fields' storage,
   then compute the clover sigma trace into `gauge`.
   @param gauge  Output gauge field (native order, reconstruct NO or 12)
   @param clover Input clover field (native order)
   @param coeff  Scalar coefficient forwarded to the argument bundle
   Errors out (errorQuda) for non-native orders or other reconstruct types.
*/
template<typename Float>
void computeCloverSigmaTrace(GaugeField& gauge, const CloverField& clover, Float coeff){
  if (clover.isNative()) {
    typedef typename clover_mapper<Float>::type C;
    // NOTE(review): C(clover,0) / C(clover,1) feed the parity-0 / parity-1
    // reads.  The second constructor argument presumably selects the direct vs
    // inverse clover term -- confirm against clover_mapper.
    if (gauge.isNative()) {
      if (gauge.Reconstruct() == QUDA_RECONSTRUCT_NO) {
        typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_NO>::type G;
        computeCloverSigmaTrace<Float>( C(clover,0), C(clover,1), G(gauge), gauge, coeff);
      } else if (gauge.Reconstruct() == QUDA_RECONSTRUCT_12) {
        // The accessor must match the field's actual layout.  The
        // RECONSTRUCT_NO accessor previously used here assumes full 3x3
        // storage, which a reconstruct-12 field does not have.
        // NOTE(review): reconstruct-12 keeps only part of each matrix and
        // rebuilds the rest assuming an SU(3) link.  The matrices written here
        // are presumably not of that form, so consider rejecting this case.
        typedef typename gauge_mapper<Float,QUDA_RECONSTRUCT_12>::type G;
        computeCloverSigmaTrace<Float>( C(clover,0), C(clover,1), G(gauge), gauge, coeff);
      } else {
        errorQuda("Reconstruction type %d not supported", gauge.Reconstruct());
      }
    } else {
      errorQuda("Gauge order %d not supported", gauge.Order());
    }
  } else {
    errorQuda("clover order %d not supported", clover.Order());
  }
}
/**
   @brief Public entry point.  Dispatches on the clover field's precision
   (single or double) and computes the clover sigma trace into `output`.
   If the library was built without GPU_CLOVER_DIRAC, it only reports an error.
   @param output Output gauge field
   @param clover Input clover field
   @param coeff  Scalar coefficient (narrowed to float for single precision)
*/
void computeCloverSigmaTrace(GaugeField& output, const CloverField& clover, double coeff) {
#ifdef GPU_CLOVER_DIRAC
  switch (clover.Precision()) {
  case QUDA_DOUBLE_PRECISION:
    computeCloverSigmaTrace<double>(output, clover, coeff);
    break;
  case QUDA_SINGLE_PRECISION:
    computeCloverSigmaTrace<float>(output, clover, static_cast<float>(coeff));
    break;
  default:
    errorQuda("Precision %d not supported", clover.Precision());
    break;
  }
#else
  errorQuda("Clover has not been built");
#endif
}