13 template <
typename Float,
int n,
typename Arg>
14 class CalculateYhat :
public TunableVectorYZ {
18 const LatticeField &meta;
20 bool compute_max_only;
22 long long flops()
const {
return 2l * arg.coarseVolumeCB * 8 * n * n * (8*n-2); }
23 long long bytes()
const {
return 2l * (arg.Xinv.Bytes() + 8*arg.Y.Bytes() + 8*arg.Yhat.Bytes()) * n; }
25 unsigned int minThreads()
const {
return arg.coarseVolumeCB; }
27 bool tuneGridDim()
const {
return false; }
30 CalculateYhat(Arg &arg,
const LatticeField &meta) :
31 TunableVectorYZ(2 * n, 4 * n),
34 compute_max_only(false)
38 create_jitify_program(
"kernels/coarse_op_preconditioned.cuh");
46 virtual ~CalculateYhat() {
53 void apply(
const cudaStream_t &
stream) {
58 CalculateYhatCPU<Float, n, true, Arg>(
arg);
60 CalculateYhatCPU<Float, n, false, Arg>(
arg);
63 if (compute_max_only) {
66 cudaMemsetAsync(arg.max_d, 0,
sizeof(Float), stream);
70 using namespace jitify::reflection;
71 jitify_error = program->kernel(
"quda::CalculateYhatGPU")
72 .instantiate(Type<Float>(), n, compute_max_only, Type<Arg>())
73 .configure(tp.grid, tp.block, tp.shared_bytes, stream)
77 CalculateYhatGPU<Float, n, true, Arg><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(
arg);
79 CalculateYhatGPU<Float, n, false, Arg><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(
arg);
81 if (compute_max_only) {
83 qudaMemcpyAsync(arg.max_h, arg.max_d,
sizeof(Float), cudaMemcpyDeviceToHost, stream);
93 void setComputeMaxOnly(
bool compute_max_only_) { compute_max_only = compute_max_only_; }
96 bool advanceSharedBytes(TuneParam &
param)
const {
return false; }
98 bool advanceTuneParam(TuneParam &
param)
const {
103 TuneKey tuneKey()
const {
106 if (compute_max_only) strcat(Aux,
",compute_max_only");
113 return TuneKey(meta.VolString(),
typeid(*this).name(), Aux);
125 template<
typename storeFloat,
typename Float,
int N, QudaGaugeFieldOrder gOrder>
126 void calculateYhat(GaugeField &Yhat, GaugeField &Xinv,
const GaugeField &Y,
const GaugeField &
X)
129 const int n = X.Ncolor();
131 GaugeFieldParam
param(X);
135 cudaGaugeField X_(param);
136 cudaGaugeField Xinv_(param);
145 const cpuGaugeField *
X_h =
static_cast<const cpuGaugeField*
>(&
X);
146 cpuGaugeField *
Xinv_h =
static_cast<cpuGaugeField*
>(&Xinv);
149 errorQuda(
"Unsupported location=%d and order=%d", X.Location(), X.Order());
160 for (
int i=0; i<4; i++) xc_size[i] = X.X()[i];
164 typedef typename gauge::FieldOrder<Float,N,1,gOrder,true,storeFloat> gCoarse;
165 typedef typename gauge::FieldOrder<Float,N,1,gOrder,true,storeFloat> gPreconditionedCoarse;
166 gCoarse yAccessor(const_cast<GaugeField&>(Y));
167 gPreconditionedCoarse yHatAccessor(const_cast<GaugeField&>(Yhat));
168 gCoarse xInvAccessor(const_cast<GaugeField&>(Xinv));
173 typedef CalculateYhatArg<Float, gPreconditionedCoarse, gCoarse, N> yHatArg;
174 yHatArg
arg(yHatAccessor, yAccessor, xInvAccessor, xc_size, comm_dim, 1);
176 CalculateYhat<Float, N, yHatArg> yHat(arg, Y);
178 yHat.setComputeMaxOnly(
true);
181 double max_h_double = *arg.max_h;
183 *arg.max_h =
static_cast<Float
>(max_h_double);
187 Yhat.Scale(*arg.max_h);
188 arg.Yhat.resetScale(*arg.max_h);
190 yHat.setComputeMaxOnly(
false);
194 for (
int d = 0; d < 8; d++)
195 printfQuda(
"Yhat[%d] = %e (%e %e = %e x %e)\n", d, Yhat.norm2(d), Yhat.abs_max(d),
196 Y.abs_max(d) * Xinv.abs_max(0), Y.abs_max(d), Xinv.abs_max(0));
210 template <
typename storeFloat,
typename Float,
int N>
211 void calculateYhat(GaugeField &Yhat, GaugeField &Xinv,
const GaugeField &Y,
const GaugeField &X)
215 if (Y.FieldOrder() != gOrder)
errorQuda(
"Unsupported field order %d\n", Y.FieldOrder());
216 calculateYhat<storeFloat,Float,N,gOrder>(Yhat, Xinv, Y,
X);
219 if (Y.FieldOrder() != gOrder)
errorQuda(
"Unsupported field order %d\n", Y.FieldOrder());
220 calculateYhat<storeFloat,Float,N,gOrder>(Yhat, Xinv, Y,
X);
// Dispatch calculateYhat on the coarse color count of Y.
//
// Reads Y.Ncolor() and forwards all four fields unchanged to the
// calculateYhat<storeFloat, Float, N> overload whose N equals that count.
// Supported counts are 2, 4, 8, 12, 16, 20, 24, 32, 48 and 64. Any other
// value is reported through errorQuda("Unsupported number of coarse dof ...").
//
// Template parameters:
//   storeFloat, Float - forwarded unchanged to the N-specialized overload
//                       (presumably storage vs. compute precision, given
//                       the <short, float> instantiation elsewhere in this
//                       listing -- TODO confirm).
// Parameters:
//   Yhat, Xinv - non-const fields, passed through to the overload.
//   Y, X       - const fields, passed through to the overload. Only
//                Y.Ncolor() is inspected here.
//
// NOTE(review): this text appears to come from a rendered source listing.
// The leading integers (225, 226, ...) look like original line numbers, and
// the closing braces of the switch and of the function are not visible here.
225 template <
typename storeFloat,
typename Float>
226 void calculateYhat(GaugeField &Yhat, GaugeField &Xinv,
const GaugeField &Y,
const GaugeField &X) {
// N is a template argument of the callee, so each supported color count
// needs its own explicit case with a compile-time constant.
227 switch (Y.Ncolor()) {
228 case 2: calculateYhat<storeFloat,Float, 2>(Yhat, Xinv, Y,
X);
break;
229 case 4: calculateYhat<storeFloat,Float, 4>(Yhat, Xinv, Y,
X);
break;
230 case 8: calculateYhat<storeFloat,Float, 8>(Yhat, Xinv, Y,
X);
break;
231 case 12: calculateYhat<storeFloat,Float,12>(Yhat, Xinv, Y,
X);
break;
232 case 16: calculateYhat<storeFloat,Float,16>(Yhat, Xinv, Y,
X);
break;
233 case 20: calculateYhat<storeFloat,Float,20>(Yhat, Xinv, Y,
X);
break;
234 case 24: calculateYhat<storeFloat,Float,24>(Yhat, Xinv, Y,
X);
break;
235 case 32: calculateYhat<storeFloat,Float,32>(Yhat, Xinv, Y,
X);
break;
236 case 48: calculateYhat<storeFloat,Float,48>(Yhat, Xinv, Y,
X);
break;
237 case 64: calculateYhat<storeFloat,Float,64>(Yhat, Xinv, Y,
X);
break;
// Any color count without an explicit case above is rejected.
238 default:
errorQuda(
"Unsupported number of coarse dof %d\n", Y.Ncolor());
break;
252 #ifdef GPU_MULTIGRID_DOUBLE 254 calculateYhat<double,double>(Yhat, Xinv, Y,
X);
256 errorQuda(
"Double precision multigrid has not been enabled");
260 calculateYhat<float, float>(Yhat, Xinv, Y,
X);
262 errorQuda(
"Unsupported precision %d\n", precision);
266 calculateYhat<short, float>(Yhat, Xinv, Y,
X);
268 errorQuda(
"Unsupported precision %d\n", precision);
271 errorQuda(
"Unsupported precision %d\n", precision);
276 errorQuda(
"Multigrid has not been built");
#define pool_pinned_free(ptr)
enum QudaPrecision_s QudaPrecision
QudaVerbosity getVerbosity()
#define checkPrecision(...)
Helper file when using jitify run-time compilation. This file should be included in source code...
const char * comm_dim_partitioned_string(const int *comm_dim_override=0)
Return a string that defines the comm partitioning (used as a tuneKey)
const char * compile_type_str(const LatticeField &meta, QudaFieldLocation location_=QUDA_INVALID_FIELD_LOCATION)
Helper function for setting the auxiliary string.
cudaError_t qudaStreamSynchronize(cudaStream_t &stream)
Wrapper around cudaStreamSynchronize or cuStreamSynchronize.
long long BatchInvertMatrix(void *Ainv, void *A, const int n, const int batch, QudaPrecision precision, QudaFieldLocation location)
#define pool_device_malloc(size)
TuneParam & tuneLaunch(Tunable &tunable, QudaTune enabled, QudaVerbosity verbosity)
enum QudaGaugeFieldOrder_s QudaGaugeFieldOrder
char * getOmpThreadStr()
Returns a string of the form ",omp_threads=$OMP_NUM_THREADS", which can be used for storing the numbe...
bool activeTuning()
Query whether tuning is in progress.
void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X)
Calculate preconditioned coarse links and coarse clover inverse field.
#define pool_pinned_malloc(size)
#define qudaMemcpyAsync(dst, src, count, kind, stream)
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
#define pool_device_free(ptr)
void comm_allreduce_max(double *data)
QudaTune getTuning()
Query whether autotuning is enabled or not. Default is enabled but can be overridden by setting QUDA_...
QudaPrecision Precision() const
int comm_dim_partitioned(int dim)
virtual bool advanceTuneParam(TuneParam &param) const