coarse_op_preconditioned.cu
#include <gauge_field.h>
#include <blas_cublas.h>
#include <blas_quda.h>
#include <tune_quda.h>

#include <jitify_helper.cuh>
#include <kernels/coarse_op_preconditioned.cuh>

namespace quda {

#ifdef GPU_MULTIGRID

  template <typename Float, int n, typename Arg>
  class CalculateYhat : public TunableVectorYZ {

  protected:
    Arg &arg;
    const LatticeField &meta;

    bool compute_max_only;

    long long flops() const { return 2l * arg.coarseVolumeCB * 8 * n * n * (8*n-2); } // 2 from parity, 8 from dir; each of the n*n output elements is a length-n complex dot product costing 8*n-2 flops (6 per complex multiply, 2 per complex add)
    long long bytes() const { return 2l * (arg.Xinv.Bytes() + 8*arg.Y.Bytes() + 8*arg.Yhat.Bytes()) * n; }

    unsigned int minThreads() const { return arg.coarseVolumeCB; }

    bool tuneGridDim() const { return false; } // don't tune the grid dimension

  public:
    CalculateYhat(Arg &arg, const LatticeField &meta) :
      TunableVectorYZ(2 * n, 4 * n),
      arg(arg),
      meta(meta),
      compute_max_only(false)
    {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
#ifdef JITIFY
        create_jitify_program("kernels/coarse_op_preconditioned.cuh");
#endif
        arg.max_d = static_cast<Float*>(pool_device_malloc(sizeof(Float)));
      }
      arg.max_h = static_cast<Float*>(pool_pinned_malloc(sizeof(Float)));
      strcpy(aux, compile_type_str(meta));
      strcat(aux, comm_dim_partitioned_string());
    }

    virtual ~CalculateYhat() {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
        pool_device_free(arg.max_d);
      }
      pool_pinned_free(arg.max_h);
    }

    void apply(const cudaStream_t &stream) {
      TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
      if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {

        if (compute_max_only)
          CalculateYhatCPU<Float, n, true, Arg>(arg);
        else
          CalculateYhatCPU<Float, n, false, Arg>(arg);

      } else {
        if (compute_max_only) {
          if (!activeTuning()) {
            cudaMemsetAsync(arg.max_d, 0, sizeof(Float), stream);
          }
        }
#ifdef JITIFY
        using namespace jitify::reflection;
        jitify_error = program->kernel("quda::CalculateYhatGPU")
                         .instantiate(Type<Float>(), n, compute_max_only, Type<Arg>())
                         .configure(tp.grid, tp.block, tp.shared_bytes, stream)
                         .launch(arg);
#else
        if (compute_max_only)
          CalculateYhatGPU<Float, n, true, Arg><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
        else
          CalculateYhatGPU<Float, n, false, Arg><<<tp.grid, tp.block, tp.shared_bytes, stream>>>(arg);
#endif
        if (compute_max_only) {
          if (!activeTuning()) { // only do copy once tuning is done
            qudaMemcpyAsync(arg.max_h, arg.max_d, sizeof(Float), cudaMemcpyDeviceToHost, stream);
            qudaStreamSynchronize(const_cast<cudaStream_t&>(stream));
          }
        }
      }
    }

    void setComputeMaxOnly(bool compute_max_only_) { compute_max_only = compute_max_only_; }

    // no locality in this kernel so no point in shared-memory tuning
    bool advanceSharedBytes(TuneParam &param) const { return false; }

    bool advanceTuneParam(TuneParam &param) const {
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION && meta.MemType() == QUDA_MEMORY_DEVICE) return Tunable::advanceTuneParam(param);
      else return false;
    }

    TuneKey tuneKey() const {
      char Aux[TuneKey::aux_n];
      strcpy(Aux, aux);
      if (compute_max_only) strcat(Aux, ",compute_max_only");
      if (meta.Location() == QUDA_CUDA_FIELD_LOCATION) {
        strcat(Aux, meta.MemType() == QUDA_MEMORY_MAPPED ? ",GPU-mapped" : ",GPU-device");
      } else if (meta.Location() == QUDA_CPU_FIELD_LOCATION) {
        strcat(Aux, ",CPU");
        strcat(Aux, getOmpThreadStr());
      }
      return TuneKey(meta.VolString(), typeid(*this).name(), Aux);
    }
  };
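
  // --------------------------------------------------------------------------
  // For orientation only: the kernels invoked above, CalculateYhatGPU and
  // CalculateYhatCPU, are defined in kernels/coarse_op_preconditioned.cuh.
  // The sketch below is a hypothetical serial reference (not the QUDA kernel,
  // and not its data layout) for the dense complex multiply performed per site
  // and direction, and shows where the (8*n-2) factor in flops() above comes
  // from. It assumes <complex> is included and row-major n x n blocks.
  template <typename Float, int n>
  void yhat_block_sketch(std::complex<Float> *Yhat, const std::complex<Float> *Xinv,
                         const std::complex<Float> *Y)
  {
    for (int i = 0; i < n; i++) {            // row of Xinv
      for (int j = 0; j < n; j++) {          // column of Y
        std::complex<Float> sum = 0;         // length-n complex dot product:
        for (int k = 0; k < n; k++)          // n complex multiplies at 6 flops each,
          sum += Xinv[i*n + k] * Y[k*n + j]; // n-1 complex adds at 2 flops each
        Yhat[i*n + j] = sum;                 // = 8*n - 2 flops per output element
      }
    }
  }
  // --------------------------------------------------------------------------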

  /**
     @brief Calculate the preconditioned coarse links and the coarse clover
     inverse field.
     @param[out] Yhat Preconditioned coarse link field
     @param[out] Xinv Coarse clover inverse field
     @param[in] Y Coarse link field
     @param[in] X Coarse clover field
   */
  template <typename storeFloat, typename Float, int N, QudaGaugeFieldOrder gOrder>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X)
  {
    // invert the clover matrix field
    const int n = X.Ncolor();
    if (X.Location() == QUDA_CUDA_FIELD_LOCATION && X.Order() == QUDA_FLOAT2_GAUGE_ORDER) {
      GaugeFieldParam param(X);
      // need to copy into AoS format for CUBLAS
      param.order = QUDA_MILC_GAUGE_ORDER;
      param.setPrecision( X.Precision() < QUDA_SINGLE_PRECISION ? QUDA_SINGLE_PRECISION : X.Precision() );
      cudaGaugeField X_(param);
      cudaGaugeField Xinv_(param);
      X_.copy(X);
      blas::flops += cublas::BatchInvertMatrix((void*)Xinv_.Gauge_p(), (void*)X_.Gauge_p(), n, X_.Volume(), X_.Precision(), X.Location());

      if (Xinv.Precision() < QUDA_SINGLE_PRECISION) Xinv.Scale( Xinv_.abs_max() );

      Xinv.copy(Xinv_);

    } else if (X.Location() == QUDA_CPU_FIELD_LOCATION && X.Order() == QUDA_QDP_GAUGE_ORDER) {
      const cpuGaugeField *X_h = static_cast<const cpuGaugeField*>(&X);
      cpuGaugeField *Xinv_h = static_cast<cpuGaugeField*>(&Xinv);
      blas::flops += cublas::BatchInvertMatrix(((void**)Xinv_h->Gauge_p())[0], ((void**)X_h->Gauge_p())[0], n, X_h->Volume(), X.Precision(), QUDA_CPU_FIELD_LOCATION);
    } else {
      errorQuda("Unsupported location=%d and order=%d", X.Location(), X.Order());
    }
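
    // ------------------------------------------------------------------------
    // For orientation only: cublas::BatchInvertMatrix (blas_cublas.cu) hides
    // the batched inversion used above. A minimal standalone sketch of batched
    // LU inversion with cuBLAS follows, kept in a comment since it cannot live
    // mid-function. It assumes single-precision complex matrices packed
    // contiguously, that <cublas_v2.h>, <cuComplex.h>, <cuda_runtime.h> and
    // <vector> are included at file scope, and it omits all error checking;
    // names are illustrative, not the QUDA implementation.
    //
    //   void batch_invert_sketch(cuComplex *d_Ainv, cuComplex *d_A, int n, int batch)
    //   {
    //     cublasHandle_t handle;
    //     cublasCreate(&handle);
    //
    //     // batched cuBLAS routines take device arrays of per-matrix pointers
    //     std::vector<cuComplex*> hA(batch), hAinv(batch);
    //     for (int i = 0; i < batch; i++) { hA[i] = d_A + i*n*n; hAinv[i] = d_Ainv + i*n*n; }
    //     cuComplex **dA, **dAinv;
    //     cudaMalloc(&dA, batch*sizeof(cuComplex*));
    //     cudaMalloc(&dAinv, batch*sizeof(cuComplex*));
    //     cudaMemcpy(dA, hA.data(), batch*sizeof(cuComplex*), cudaMemcpyHostToDevice);
    //     cudaMemcpy(dAinv, hAinv.data(), batch*sizeof(cuComplex*), cudaMemcpyHostToDevice);
    //
    //     int *d_pivot, *d_info;
    //     cudaMalloc(&d_pivot, batch*n*sizeof(int));
    //     cudaMalloc(&d_info, batch*sizeof(int));
    //
    //     // in-place LU factorization of every n x n matrix in the batch,
    //     // followed by out-of-place inversion from the LU factors
    //     cublasCgetrfBatched(handle, n, dA, n, d_pivot, d_info, batch);
    //     cublasCgetriBatched(handle, n, dA, n, d_pivot, dAinv, n, d_info, batch);
    //
    //     cudaFree(d_pivot); cudaFree(d_info); cudaFree(dA); cudaFree(dAinv);
    //     cublasDestroy(handle);
    //   }
    // ------------------------------------------------------------------------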

    // now exchange the Y halos, for both the forward and backward links, for the multi-process dslash
    const_cast<GaugeField&>(Y).exchangeGhost(QUDA_LINK_BIDIRECTIONAL);

    // compute the preconditioned links
    // Yhat_back(x-\mu) = Y_back(x-\mu) * Xinv^dagger(x) (positive projector)
    // Yhat_fwd(x) = Xinv(x) * Y_fwd(x) (negative projector)
    {
      int xc_size[5];
      for (int i=0; i<4; i++) xc_size[i] = X.X()[i];
      xc_size[4] = 1;

      // use a spin-ignorant accessor to make the multiplication simpler
      typedef typename gauge::FieldOrder<Float,N,1,gOrder,true,storeFloat> gCoarse;
      typedef typename gauge::FieldOrder<Float,N,1,gOrder,true,storeFloat> gPreconditionedCoarse;
      gCoarse yAccessor(const_cast<GaugeField&>(Y));
      gPreconditionedCoarse yHatAccessor(const_cast<GaugeField&>(Yhat));
      gCoarse xInvAccessor(const_cast<GaugeField&>(Xinv));
      if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Xinv = %e\n", Xinv.norm2(0));

      int comm_dim[4];
      for (int i=0; i<4; i++) comm_dim[i] = comm_dim_partitioned(i);
      typedef CalculateYhatArg<Float, gPreconditionedCoarse, gCoarse, N> yHatArg;
      yHatArg arg(yHatAccessor, yAccessor, xInvAccessor, xc_size, comm_dim, 1);

      CalculateYhat<Float, N, yHatArg> yHat(arg, Y);

      // for fixed-point storage, first find the global maximum so we can set the scale
      if (Yhat.Precision() == QUDA_HALF_PRECISION || Yhat.Precision() == QUDA_QUARTER_PRECISION) {
        yHat.setComputeMaxOnly(true);
        yHat.apply(0);

        double max_h_double = *arg.max_h;
        comm_allreduce_max(&max_h_double);
        *arg.max_h = static_cast<Float>(max_h_double);

        if (getVerbosity() >= QUDA_VERBOSE) printfQuda("Yhat Max = %e\n", *arg.max_h);

        Yhat.Scale(*arg.max_h);
        arg.Yhat.resetScale(*arg.max_h);
      }
      yHat.setComputeMaxOnly(false);
      yHat.apply(0);

      if (getVerbosity() >= QUDA_VERBOSE)
        for (int d = 0; d < 8; d++)
          printfQuda("Yhat[%d] = %e (%e %e = %e x %e)\n", d, Yhat.norm2(d), Yhat.abs_max(d),
                     Y.abs_max(d) * Xinv.abs_max(0), Y.abs_max(d), Xinv.abs_max(0));
    }

    // fill back in the bulk of Yhat so that the backward links are updated on the previous node:
    // we need to write into the bulk of the previous node, sending only the backward links
    // backwards so as not to overwrite the forward bulk
    Yhat.injectGhost(QUDA_LINK_BACKWARDS);

    // exchange the forward links for the multi-process dslash dagger:
    // we need to write into the ghost zone of the next node, sending only the forward links
    // forwards so as not to overwrite the backward ghost
    Yhat.exchangeGhost(QUDA_LINK_FORWARDS);
  }
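
  // --------------------------------------------------------------------------
  // For orientation only: when Yhat is stored in a fixed-point format (half or
  // quarter precision), the function above makes two passes: a compute_max_only
  // pass to find the local maximum, a global reduction of that maximum, a
  // rescale of the destination field, and then the actual computation. The
  // hypothetical host-side sketch below shows the same two-pass pattern in
  // miniature, with int16_t storage standing in for QUDA's fixed-point format;
  // it assumes <cmath>, <cstdint> and <vector> are included at file scope.
  inline void store_fixed_point_sketch(std::vector<int16_t> &out, const std::vector<float> &in)
  {
    // pass 1: find the maximum absolute value (the analogue of compute_max_only)
    float max = 0.0f;
    for (float v : in) max = std::fmax(max, std::fabs(v));

    // set the scale so that max maps to the largest representable int16_t
    const float scale = (max > 0.0f) ? 32767.0f / max : 1.0f;

    // pass 2: recompute and store the scaled values (the second apply() above)
    out.resize(in.size());
    for (size_t i = 0; i < in.size(); i++) out[i] = static_cast<int16_t>(std::lrintf(in[i] * scale));
  }
  // --------------------------------------------------------------------------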

  template <typename storeFloat, typename Float, int N>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X)
  {
    if (Y.Location() == QUDA_CPU_FIELD_LOCATION) {
      constexpr QudaGaugeFieldOrder gOrder = QUDA_QDP_GAUGE_ORDER;
      if (Y.FieldOrder() != gOrder) errorQuda("Unsupported field order %d\n", Y.FieldOrder());
      calculateYhat<storeFloat,Float,N,gOrder>(Yhat, Xinv, Y, X);
    } else {
      constexpr QudaGaugeFieldOrder gOrder = QUDA_FLOAT2_GAUGE_ORDER;
      if (Y.FieldOrder() != gOrder) errorQuda("Unsupported field order %d\n", Y.FieldOrder());
      calculateYhat<storeFloat,Float,N,gOrder>(Yhat, Xinv, Y, X);
    }
  }

  // template on the number of coarse degrees of freedom
  template <typename storeFloat, typename Float>
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X) {
    switch (Y.Ncolor()) {
    case  2: calculateYhat<storeFloat,Float, 2>(Yhat, Xinv, Y, X); break;
    case  4: calculateYhat<storeFloat,Float, 4>(Yhat, Xinv, Y, X); break;
    case  8: calculateYhat<storeFloat,Float, 8>(Yhat, Xinv, Y, X); break;
    case 12: calculateYhat<storeFloat,Float,12>(Yhat, Xinv, Y, X); break;
    case 16: calculateYhat<storeFloat,Float,16>(Yhat, Xinv, Y, X); break;
    case 20: calculateYhat<storeFloat,Float,20>(Yhat, Xinv, Y, X); break;
    case 24: calculateYhat<storeFloat,Float,24>(Yhat, Xinv, Y, X); break;
    case 32: calculateYhat<storeFloat,Float,32>(Yhat, Xinv, Y, X); break;
    case 48: calculateYhat<storeFloat,Float,48>(Yhat, Xinv, Y, X); break;
    case 64: calculateYhat<storeFloat,Float,64>(Yhat, Xinv, Y, X); break;
    default: errorQuda("Unsupported number of coarse dof %d\n", Y.Ncolor()); break;
    }
  }

#endif

  // Compute the preconditioned coarse links (Yhat) and the coarse clover inverse field (Xinv)
  void calculateYhat(GaugeField &Yhat, GaugeField &Xinv, const GaugeField &Y, const GaugeField &X) {

#ifdef GPU_MULTIGRID
    QudaPrecision precision = checkPrecision(Xinv, Y, X);
    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("Computing Yhat field......\n");

    if (precision == QUDA_DOUBLE_PRECISION) {
#ifdef GPU_MULTIGRID_DOUBLE
      if (Yhat.Precision() != QUDA_DOUBLE_PRECISION) errorQuda("Unsupported precision %d\n", Yhat.Precision());
      calculateYhat<double,double>(Yhat, Xinv, Y, X);
#else
      errorQuda("Double precision multigrid has not been enabled");
#endif
    } else if (precision == QUDA_SINGLE_PRECISION) {
      if (Yhat.Precision() == QUDA_SINGLE_PRECISION) {
        calculateYhat<float, float>(Yhat, Xinv, Y, X);
      } else {
        errorQuda("Unsupported precision %d\n", precision);
      }
    } else if (precision == QUDA_HALF_PRECISION) {
      if (Yhat.Precision() == QUDA_HALF_PRECISION) {
        calculateYhat<short, float>(Yhat, Xinv, Y, X);
      } else {
        errorQuda("Unsupported precision %d\n", precision);
      }
    } else {
      errorQuda("Unsupported precision %d\n", precision);
    }

    if (getVerbosity() >= QUDA_SUMMARIZE) printfQuda("....done computing Yhat field\n");
#else
    errorQuda("Multigrid has not been built");
#endif
  }
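
  // --------------------------------------------------------------------------
  // Hypothetical call site, for orientation only (not the QUDA multigrid setup
  // code): given the coarse links Y and the coarse clover field X, a caller
  // would allocate Yhat and Xinv with matching geometry and invoke the entry
  // point above. The minimal param handling here is an assumption; a real
  // caller may adjust precision and field order before construction.
  inline void build_preconditioned_links_sketch(const GaugeField &Y, const GaugeField &X)
  {
    GaugeFieldParam yParam(Y);   // clone Y's geometry, order and precision
    cudaGaugeField Yhat(yParam); // destination for the preconditioned links

    GaugeFieldParam xParam(X);   // clone X's geometry, order and precision
    cudaGaugeField Xinv(xParam); // destination for the coarse clover inverse

    calculateYhat(Yhat, Xinv, Y, X); // fills Yhat and Xinv
  }
  // --------------------------------------------------------------------------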

} // namespace quda