11 using namespace gauge;
13 template<
typename Oprod,
typename Gauge,
typename Mom>
17 #ifndef BUILD_TIFR_INTERFACE 26 KSForceArg(Oprod& oprod, Gauge &gauge, Mom& mom,
int dim[4])
27 : oprod(oprod), gauge(gauge), mom(mom){
30 for(
int dir=0; dir<4; ++dir) threads *= dim[dir];
32 for(
int dir=0; dir<4; ++dir) X[dir] = dim[dir];
33 #ifndef BUILD_TIFR_INTERFACE 35 for(
int dir=0; dir<4; ++dir) border[dir] = 2;
42 template<
typename Float,
typename Oprod,
typename Gauge,
typename Mom>
52 for(
int dir=0; dir<4; ++dir) X[dir] = arg.
X[dir];
56 #ifndef BUILD_TIFR_INTERFACE 58 for(
int dir=0; dir<4; ++dir){
59 x[dir] += arg.border[dir];
60 X[dir] += 2*arg.border[dir];
67 int dx[4] = {0,0,0,0};
68 for(
int dir=0; dir<4; ++dir){
79 arg.
mom(dir, idx, parity) = M;
83 template<
typename Float,
typename Oprod,
typename Gauge,
typename Mom>
86 int idx = threadIdx.x + blockIdx.x*blockDim.x;
89 completeKSForceCore<Float,Oprod,Gauge,Mom>(
arg,idx);
92 template<
typename Float,
typename Oprod,
typename Gauge,
typename Mom>
95 for(
int idx=0; idx<arg.
threads; idx++){
96 completeKSForceCore<Float,Oprod,Gauge,Mom>(
arg,idx);
100 template<
typename Float,
typename Oprod,
typename Gauge,
typename Mom>
117 : arg(arg), meta(meta), location(location) {
118 writeAuxString(
"prec=%lu,stride=%d",
sizeof(Float),arg.
mom.stride);
126 dim3 blockDim(128, 1, 1);
127 dim3 gridDim((arg.
threads + blockDim.x - 1) / blockDim.x, 1, 1);
128 completeKSForceKernel<Float><<<gridDim,blockDim>>>(
arg);
130 completeKSForceCPU<Float>(
arg);
136 long long flops()
const {
return 792*arg.
X[0]*arg.
X[1]*arg.
X[2]*arg.
X[3]; }
137 long long bytes()
const {
return 0; }
140 template<
typename Float,
typename Oprod,
typename Gauge,
typename Mom>
145 completeForce.
apply(0);
146 if(flops) *flops = completeForce.
flops();
151 template<
typename Float>
156 errorQuda(
"Only QUDA_CUDA_FIELD_LOCATION currently supported");
159 errorQuda(
"Reconstruct type not supported");
164 const_cast<int*>(mom.
X()),
165 gauge, location, flops);
175 errorQuda(
"Half precision not supported");
179 completeKSForce<float>(mom, oprod, gauge, location,
flops);
181 completeKSForce<double>(mom, oprod, gauge, location,
flops);
191 template<
typename Result,
typename Oprod,
typename Gauge>
204 : coeff(1.0), res(res), oprod(oprod), gauge(gauge){
208 for(
int dir=0; dir<4; ++dir) threads *= (dim[dir]-2);
209 for(
int dir=0; dir<4; ++dir) X[dir] = dim[dir]-2;
210 for(
int dir=0; dir<4; ++dir) border[dir] = 2;
212 for(
int dir=0; dir<4; ++dir) threads *= dim[dir];
213 for(
int dir=0; dir<4; ++dir) X[dir] = dim[dir];
221 template<
typename Float,
typename Result,
typename Oprod,
typename Gauge>
285 template<
typename Float,
typename Result,
typename Oprod,
typename Gauge>
288 int idx = threadIdx.x + blockIdx.x*blockDim.x;
291 computeKSLongLinkForceCore<Float,Result,Oprod,Gauge>(
arg,idx);
297 template<
typename Float,
typename Result,
typename Oprod,
typename Gauge>
300 for(
int idx=0; idx<arg.
threads; idx++){
301 computeKSLongLinkForceCore<Float,Result,Oprod,Gauge>(
arg,idx);
308 template<
typename Float,
typename Result,
typename Oprod,
typename Gauge>
326 : arg(arg), meta(meta), location(location) {
327 writeAuxString(
"prec=%lu,stride=%d",
sizeof(Float),arg.
res.stride);
335 dim3 blockDim(128, 1, 1);
336 dim3 gridDim((arg.
threads + blockDim.x - 1) / blockDim.x, 1, 1);
337 computeKSLongLinkForceKernel<Float><<<gridDim,blockDim>>>(
arg);
339 computeKSLongLinkForceCPU<Float>(
arg);
345 long long flops()
const {
return 0; }
346 long long bytes()
const {
return 0; }
352 template<
typename Float,
typename Result,
typename Oprod,
typename Gauge>
357 computeLongLink.
apply(0);
361 template<
typename Float>
365 errorQuda(
"Only QUDA_CUDA_FIELD_LOCATION currently supported");
370 errorQuda(
"Reconstruct type not supported");
375 const_cast<int*>(result.
X()),
386 errorQuda(
"Half precision not supported");
390 computeKSLongLinkForce<float>(result, oprod, gauge, location);
392 computeKSLongLinkForce<double>(result, oprod, gauge, location);
KSForceArg(Oprod &oprod, Gauge &gauge, Mom &mom, int dim[4])
__global__ void completeKSForceKernel(KSForceArg< Oprod, Gauge, Mom > arg)
static __device__ __host__ int linkIndexShift(const I x[], const J dx[], const K X[4])
unsigned int sharedBytesPerThread() const
__global__ void computeKSLongLinkForceKernel(KSLongLinkArg< Result, Oprod, Gauge > arg)
KSLongLinkArg< Result, Oprod, Gauge > arg
KSForceComplete(KSForceArg< Oprod, Gauge, Mom > &arg, const GaugeField &meta, QudaFieldLocation location)
bool tuneSharedBytes() const
void completeKSForceCPU(KSForceArg< Oprod, Gauge, Mom > &arg)
const QudaFieldLocation location
void completeKSForce(GaugeField &mom, const GaugeField &oprod, const GaugeField &gauge, QudaFieldLocation location, long long *flops=NULL)
const char * VolString() const
unsigned int sharedBytesPerBlock(const TuneParam &param) const
void apply(const cudaStream_t &stream)
unsigned int minThreads() const
unsigned int minThreads() const
#define qudaDeviceSynchronize()
__host__ __device__ void completeKSForceCore(KSForceArg< Oprod, Gauge, Mom > &arg, int idx)
virtual ~KSForceComplete()
KSLongLinkForce(KSLongLinkArg< Result, Oprod, Gauge > &arg, const GaugeField &meta, QudaFieldLocation location)
Main header file for host and device accessors to GaugeFields.
const QudaFieldLocation location
enum QudaFieldLocation_s QudaFieldLocation
__host__ __device__ void computeKSLongLinkForceCore(KSLongLinkArg< Result, Oprod, Gauge > &arg, int idx)
virtual ~KSLongLinkForce()
void computeKSLongLinkForce(Result res, Oprod oprod, Gauge gauge, int dim[4], const GaugeField &meta, QudaFieldLocation location)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
Accessor routine for CloverFields in native field order.
__device__ __host__ void makeAntiHerm(Matrix< Complex, N > &m)
unsigned int sharedBytesPerBlock(const TuneParam &param) const
QudaReconstructType Reconstruct() const
void computeKSLongLinkForceCPU(KSLongLinkArg< Result, Oprod, Gauge > &arg)
unsigned int sharedBytesPerThread() const
void apply(const cudaStream_t &stream)
QudaPrecision Precision() const
KSLongLinkArg(Result &res, Oprod &oprod, Gauge &gauge, int dim[4])
bool tuneSharedBytes() const
KSForceArg< Oprod, Gauge, Mom > arg
__host__ __device__ int getCoords(int coord[], const Arg &arg, int &idx, int parity, int &dim)
Compute the space-time coordinates we are at.