QUDA  1.0.0
milc_interface.cpp
Go to the documentation of this file.
1 #include <cstdio>
2 #include <cstdlib>
3 #include <iostream>
4 #include <quda.h>
5 #include <quda_milc_interface.h>
6 #include <quda_internal.h>
7 #include <color_spinor_field.h>
8 #include <string.h>
9 #include <unitarization_links.h>
10 #include <ks_improved_force.h>
11 #include <dslash_quda.h>
12 
13 #define MAX(a,b) ((a)>(b)?(a):(b))
14 
15 #ifdef BUILD_MILC_INTERFACE
16 
17 // code for NVTX taken from Jiri Kraus' blog post:
18 // http://devblogs.nvidia.com/parallelforall/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx/
19 
20 #ifdef INTERFACE_NVTX
21 
22 #if QUDA_NVTX_VERSION == 3
23 #include "nvtx3/nvToolsExt.h"
24 #else
25 #include "nvToolsExt.h"
26 #endif
27 
28 static const uint32_t colors[] = { 0x0000ff00, 0x000000ff, 0x00ffff00, 0x00ff00ff, 0x0000ffff, 0x00ff0000, 0x00ffffff };
29 static const int num_colors = sizeof(colors)/sizeof(uint32_t);
30 
31 #define PUSH_RANGE(name,cid) { \
32  int color_id = cid; \
33  color_id = color_id%num_colors;\
34  nvtxEventAttributes_t eventAttrib = {0}; \
35  eventAttrib.version = NVTX_VERSION; \
36  eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; \
37  eventAttrib.colorType = NVTX_COLOR_ARGB; \
38  eventAttrib.color = colors[color_id]; \
39  eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; \
40  eventAttrib.message.ascii = name; \
41  nvtxRangePushEx(&eventAttrib); \
42 }
43 #define POP_RANGE nvtxRangePop();
44 #else
45 #define PUSH_RANGE(name,cid)
46 #define POP_RANGE
47 #endif
48 
49 
50 static bool initialized = false;
51 static int gridDim[4];
52 static int localDim[4];
53 
54 static bool invalidate_quda_gauge = true;
55 static bool create_quda_gauge = false;
56 
57 static bool invalidate_quda_mom = true;
58 
59 static void *df_preconditioner = nullptr;
60 
61 // set to 1 for GPU resident pipeline (not yet supported in mainline MILC)
62 #define MOM_PIPE 0
63 
64 using namespace quda;
65 using namespace quda::fermion_force;
66 
67 
68 #define QUDAMILC_VERBOSE 1
69 
70 template <bool start> void inline qudamilc_called(const char *func, QudaVerbosity verb)
71 {
72  // add NVTX markup if enabled
73  if (start) {
74  PUSH_RANGE(func, 1);
75  } else {
76  POP_RANGE;
77  }
78 
79  #ifdef QUDAMILC_VERBOSE
80  if (verb >= QUDA_VERBOSE) {
81  if (start) {
82  printfQuda("QUDA_MILC_INTERFACE: %s (called) \n", func);
83  } else {
84  printfQuda("QUDA_MILC_INTERFACE: %s (return) \n", func);
85  }
86  }
87 #endif
88 }
89 
90 template <bool start> void inline qudamilc_called(const char *func) { qudamilc_called<start>(func, getVerbosity()); }
91 
92 void qudaSetMPICommHandle(void *mycomm) { setMPICommHandleQuda(mycomm); }
93 
94 void qudaInit(QudaInitArgs_t input)
95 {
96  if (initialized) return;
97  setVerbosityQuda(input.verbosity, "", stdout);
98  qudamilc_called<true>(__func__);
99  qudaSetLayout(input.layout);
100  initialized = true;
101  qudamilc_called<false>(__func__);
102 }
103 
104 void qudaFinalize()
105 {
106  qudamilc_called<true>(__func__);
107  endQuda();
108  qudamilc_called<false>(__func__);
109 }
110 #if defined(MULTI_GPU) && !defined(QMP_COMMS)
111 
/**
 * Convert 4-d grid coordinates into a linear rank.
 *
 * The rank is built so that coords[0] varies fastest and coords[3] slowest:
 *   rank = coords[0] + dims[0]*(coords[1] + dims[1]*(coords[2] + dims[2]*coords[3]))
 *
 * @param coords four grid coordinates
 * @param fdata  pointer to an int array of grid extents; only the first
 *               three entries are read
 * @return the linear rank
 */
static int rankFromCoords(const int *coords, void *fdata)
{
  const int *extent = static_cast<int *>(fdata);

  // Horner-style accumulation starting from the slowest-running coordinate
  int linear = coords[3];
  int axis = 3;
  while (axis-- > 0) {
    linear = extent[axis] * linear + coords[axis];
  }
  return linear;
}
125 #endif
126 
127 void qudaSetLayout(QudaLayout_t input)
128 {
129  int local_dim[4];
130  for(int dir=0; dir<4; ++dir){ local_dim[dir] = input.latsize[dir]; }
131 #ifdef MULTI_GPU
132  for(int dir=0; dir<4; ++dir){ local_dim[dir] /= input.machsize[dir]; }
133 #endif
134  for(int dir=0; dir<4; ++dir){
135  if(local_dim[dir]%2 != 0){
136  printf("Error: Odd lattice dimensions are not supported\n");
137  exit(1);
138  }
139  }
140 
141  for(int dir=0; dir<4; ++dir) localDim[dir] = local_dim[dir];
142 
143 #ifdef MULTI_GPU
144  for(int dir=0; dir<4; ++dir) gridDim[dir] = input.machsize[dir];
145 #ifdef QMP_COMMS
146  initCommsGridQuda(4, gridDim, nullptr, nullptr);
147 #else
148  initCommsGridQuda(4, gridDim, rankFromCoords, (void *)(gridDim));
149 #endif
150  static int device = -1;
151 #else
152  for(int dir=0; dir<4; ++dir) gridDim[dir] = 1;
153  static int device = input.device;
154 #endif
155 
156  initQuda(device);
157 }
158 
159 void* qudaAllocatePinned(size_t bytes) {
160  return pool_pinned_malloc(bytes);
161 }
162 
163 void qudaFreePinned(void *ptr) {
164  pool_pinned_free(ptr);
165 }
166 
168 {
169 
170  static bool initialized = false;
171 
172  if(initialized) return;
173  qudamilc_called<true>(__func__);
174 
175 #if defined(GPU_HISQ_FORCE) || defined(GPU_UNITARIZE)
176  const bool reunit_allow_svd = (params.reunit_allow_svd) ? true : false;
177  const bool reunit_svd_only = (params.reunit_svd_only) ? true : false;
178  const double unitarize_eps = 1e-14;
179  const double max_error = 1e-10;
180 #endif
181 
182 #ifdef GPU_HISQ_FORCE
184  params.force_filter,
185  max_error,
186  reunit_allow_svd,
187  reunit_svd_only,
188  params.reunit_svd_rel_error,
189  params.reunit_svd_abs_error);
190 #endif
191 
192 #ifdef GPU_UNITARIZE
193  setUnitarizeLinksConstants(unitarize_eps,
194  max_error,
195  reunit_allow_svd,
196  reunit_svd_only,
197  params.reunit_svd_rel_error,
198  params.reunit_svd_abs_error);
199 #endif // GPU_UNITARIZE
200 
201  initialized = true;
202  qudamilc_called<false>(__func__);
203  return;
204 }
205 
206 
207 
// Build a QudaGaugeParam for a MILC gauge field: the lattice extents are
// copied from dim, all three precisions (cpu, cuda, cuda sloppy) are set to
// prec, and the link type is set to link_type.  The remaining visible fields
// are set to fixed values.
// NOTE(review): some statements of this function are not visible in this
// listing (including the declaration of gParam); the comments here cover only
// the visible lines.
208 static QudaGaugeParam newMILCGaugeParam(const int* dim, QudaPrecision prec, QudaLinkType link_type)
209 {
211  for(int dir=0; dir<4; ++dir) gParam.X[dir] = dim[dir];
212  gParam.cuda_prec_sloppy = gParam.cpu_prec = gParam.cuda_prec = prec;
213  gParam.type = link_type;
214 
217  gParam.t_boundary = QUDA_PERIODIC_T;
// NOTE(review): this assignment is dead -- scale is overwritten with 0 three
// lines below.  Confirm which of the two values is intended.
219  gParam.scale = 1.0;
220  gParam.anisotropy = 1.0;
221  gParam.tadpole_coeff = 1.0;
222  gParam.scale = 0;
// all padding fields are zero
223  gParam.ga_pad = 0;
224  gParam.site_ga_pad = 0;
225  gParam.mom_ga_pad = 0;
226  gParam.llfat_ga_pad = 0;
227  return gParam;
228 }
229 
230 static void invalidateGaugeQuda() {
231  qudamilc_called<true>(__func__);
232  freeGaugeQuda();
233  invalidate_quda_gauge = true;
234  qudamilc_called<false>(__func__);
235 }
236 
237 void qudaLoadKSLink(int prec, QudaFatLinkArgs_t fatlink_args,
238  const double act_path_coeff[6], void* inlink, void* fatlink, void* longlink)
239 {
240  qudamilc_called<true>(__func__);
241 
242  QudaGaugeParam param = newMILCGaugeParam(localDim,
245 
246  param.staggered_phase_applied = 1;
247  param.staggered_phase_type = QUDA_STAGGERED_PHASE_MILC;
248 
249  computeKSLinkQuda(fatlink, longlink, nullptr, inlink, const_cast<double*>(act_path_coeff), &param);
250 
251  // requires loadGaugeQuda to be called in subsequent solver
252  invalidateGaugeQuda();
253 
254  // this flags that we are using QUDA to create the HISQ links
255  create_quda_gauge = true;
256  qudamilc_called<false>(__func__);
257 }
258 
259 
260 
// Compute links through computeKSLinkQuda() with fatlink and ulink as the
// output buffers and inlink as the input; the second argument is passed as
// nullptr.  Afterwards the gauge field cached in QUDA is invalidated and
// create_quda_gauge is set.
// NOTE(review): the arguments of the newMILCGaugeParam() call are not visible
// in this listing, so how prec and fatlink_args are used cannot be stated here.
261 void qudaLoadUnitarizedLink(int prec, QudaFatLinkArgs_t fatlink_args,
262  const double act_path_coeff[6], void* inlink, void* fatlink, void* ulink)
263 {
264  qudamilc_called<true>(__func__);
265 
266  QudaGaugeParam param = newMILCGaugeParam(localDim,
269 
270  computeKSLinkQuda(fatlink, nullptr, ulink, inlink, const_cast<double*>(act_path_coeff), &param);
// NOTE(review): qudamilc_called<false> is invoked here AND again at the end of
// the function.  With INTERFACE_NVTX defined this pops the NVTX range twice
// for a single push, and in verbose mode the "(return)" line is printed twice.
// One of the two calls (most likely this one) looks like it should be removed.
271  qudamilc_called<false>(__func__);
272 
273  // requires loadGaugeQuda to be called in subsequent solver
274  invalidateGaugeQuda();
275 
276  // this flags that we are using QUDA to create the HISQ links
277  create_quda_gauge = true;
278  qudamilc_called<false>(__func__);
279 }
280 
281 
282 void qudaHisqForce(int prec, int num_terms, int num_naik_terms, double dt, double** coeff, void** quark_field,
283  const double level2_coeff[6], const double fat7_coeff[6],
284  const void* const w_link, const void* const v_link, const void* const u_link,
285  void* const milc_momentum)
286 {
287  qudamilc_called<true>(__func__);
288 
290 
291  if (!invalidate_quda_mom) {
292  gParam.use_resident_mom = true;
293  gParam.make_resident_mom = true;
294  gParam.return_result_mom = false;
295  } else {
296  gParam.use_resident_mom = false;
297  gParam.make_resident_mom = false;
298  gParam.return_result_mom = true;
299  }
300 
301  computeHISQForceQuda(milc_momentum, dt, level2_coeff, fat7_coeff,
302  w_link, v_link, u_link,
303  quark_field, num_terms, num_naik_terms, coeff,
304  &gParam);
305  qudamilc_called<false>(__func__);
306  return;
307 }
308 
309 
310 void qudaAsqtadForce(int prec, const double act_path_coeff[6],
311  const void* const one_link_src[4], const void* const naik_src[4],
312  const void* const link, void* const milc_momentum)
313 {
314  errorQuda("This interface has been removed and is no longer supported");
315 }
316 
317 
318 
319 void qudaComputeOprod(int prec, int num_terms, int num_naik_terms, double** coeff, double scale,
320  void** quark_field, void* oprod[3])
321 {
322  errorQuda("This interface has been removed and is no longer supported");
323 }
324 
325 
326 void qudaUpdateU(int prec, double eps, QudaMILCSiteArg_t *arg)
327 {
328  qudamilc_called<true>(__func__);
329  QudaGaugeParam gaugeParam = newMILCGaugeParam(localDim,
332  void *gauge = arg->site ? arg->site : arg->link;
333  void *mom = arg->site ? arg->site : arg->mom;
334 
335  gaugeParam.gauge_offset = arg->link_offset;
336  gaugeParam.mom_offset = arg->mom_offset;
337  gaugeParam.site_size = arg->size;
338  gaugeParam.gauge_order = arg->site ? QUDA_MILC_SITE_GAUGE_ORDER : QUDA_MILC_GAUGE_ORDER;
339 
340  if (!invalidate_quda_mom) {
341  gaugeParam.use_resident_mom = true;
342  gaugeParam.make_resident_mom = true;
343  } else {
344  gaugeParam.use_resident_mom = false;
345  gaugeParam.make_resident_mom = false;
346  }
347 
348  updateGaugeFieldQuda(gauge, mom, eps, 0, 0, &gaugeParam);
349  qudamilc_called<false>(__func__);
350  return;
351 }
352 
353 void qudaRephase(int prec, void *gauge, int flag, double i_mu)
354 {
355  qudamilc_called<true>(__func__);
356  QudaGaugeParam gaugeParam = newMILCGaugeParam(localDim,
359 
360  gaugeParam.staggered_phase_applied = 1-flag;
361  gaugeParam.staggered_phase_type = QUDA_STAGGERED_PHASE_MILC;
362  gaugeParam.i_mu = i_mu;
363  gaugeParam.t_boundary = QUDA_ANTI_PERIODIC_T;
364 
365  staggeredPhaseQuda(gauge, &gaugeParam);
366  qudamilc_called<false>(__func__);
367  return;
368 }
369 
370 void qudaUnitarizeSU3(int prec, double tol, QudaMILCSiteArg_t *arg)
371 {
372  qudamilc_called<true>(__func__);
373  QudaGaugeParam gaugeParam = newMILCGaugeParam(localDim,
376 
377  void *gauge = arg->site ? arg->site : arg->link;
378  gaugeParam.gauge_offset = arg->link_offset;
379  gaugeParam.site_size = arg->size;
380  gaugeParam.gauge_order = arg->site ? QUDA_MILC_SITE_GAUGE_ORDER : QUDA_MILC_GAUGE_ORDER;
381 
382  projectSU3Quda(gauge, tol, &gaugeParam);
383  qudamilc_called<false>(__func__);
384  return;
385 }
386 
387 double qudaMomAction(int prec, void *momentum)
388 {
389  qudamilc_called<true>(__func__);
390 
391  QudaGaugeParam momParam = newMILCGaugeParam(localDim,
394 
395  if (MOM_PIPE) {
396  if (invalidate_quda_mom) {
397  // beginning of trajectory so download the momentum and make
398  // resident
399  momParam.use_resident_mom = false;
400  momParam.make_resident_mom = true;
401  invalidate_quda_mom = false;
402  } else {
403  // end of trajectory so use resident and then invalidate
404  momParam.use_resident_mom = true;
405  momParam.make_resident_mom = false;
406  invalidate_quda_mom = true;
407  }
408  } else { // no momentum residency
409  momParam.use_resident_mom = false;
410  momParam.make_resident_mom = false;
411  invalidate_quda_mom = true;
412  }
413 
414  double action = momActionQuda(momentum, &momParam);
415 
416  qudamilc_called<false>(__func__);
417 
418  return action;
419 }
420 
/**
 * Return the direction code paired with dir, defined as 7 - dir
 * (so 0<->7, 1<->6, 2<->5, 3<->4); applying it twice gives dir back.
 * Presumably codes 0-3 are forward and 4-7 backward hops -- confirm.
 */
static inline int opp(int dir)
{
  const int highest_code = 7;
  return highest_code - dir;
}
424 
425 
426 static void createGaugeForcePaths(int **paths, int dir, int num_loop_types){
427 
428  int index=0;
429  // Plaquette paths
430  if (num_loop_types >= 1)
431  for(int i=0; i<4; ++i){
432  if(i==dir) continue;
433  paths[index][0] = i; paths[index][1] = opp(dir); paths[index++][2] = opp(i);
434  paths[index][0] = opp(i); paths[index][1] = opp(dir); paths[index++][2] = i;
435  }
436 
437  // Rectangle Paths
438  if (num_loop_types >= 2)
439  for(int i=0; i<4; ++i){
440  if(i==dir) continue;
441  paths[index][0] = paths[index][1] = i; paths[index][2] = opp(dir); paths[index][3] = paths[index][4] = opp(i);
442  index++;
443  paths[index][0] = paths[index][1] = opp(i); paths[index][2] = opp(dir); paths[index][3] = paths[index][4] = i;
444  index++;
445  paths[index][0] = dir; paths[index][1] = i; paths[index][2] = paths[index][3] = opp(dir); paths[index][4] = opp(i);
446  index++;
447  paths[index][0] = dir; paths[index][1] = opp(i); paths[index][2] = paths[index][3] = opp(dir); paths[index][4] = i;
448  index++;
449  paths[index][0] = i; paths[index][1] = paths[index][2] = opp(dir); paths[index][3] = opp(i); paths[index][4] = dir;
450  index++;
451  paths[index][0] = opp(i); paths[index][1] = paths[index][2] = opp(dir); paths[index][3] = i; paths[index][4] = dir;
452  index++;
453  }
454 
455  if (num_loop_types >= 3) {
456  // Staple paths
457  for(int i=0; i<4; ++i){
458  for(int j=0; j<4; ++j){
459  if(i==dir || j==dir || i==j) continue;
460  paths[index][0] = i; paths[index][1] = j; paths[index][2] = opp(dir); paths[index][3] = opp(i), paths[index][4] = opp(j);
461  index++;
462  paths[index][0] = i; paths[index][1] = opp(j); paths[index][2] = opp(dir); paths[index][3] = opp(i), paths[index][4] = j;
463  index++;
464  paths[index][0] = opp(i); paths[index][1] = j; paths[index][2] = opp(dir); paths[index][3] = i, paths[index][4] = opp(j);
465  index++;
466  paths[index][0] = opp(i); paths[index][1] = opp(j); paths[index][2] = opp(dir); paths[index][3] = i, paths[index][4] = j;
467  index++;
468  }
469  }
470  }
471 
472 }
473 
474 
475 void qudaGaugeForce( int precision,
476  int num_loop_types,
477  double milc_loop_coeff[3],
478  double eb3,
479  QudaMILCSiteArg_t *arg)
480 {
481  qudamilc_called<true>(__func__);
482 
483  int numPaths = 0;
484  switch (num_loop_types) {
485  case 1:
486  numPaths = 6;
487  break;
488  case 2:
489  numPaths = 24;
490  break;
491  case 3:
492  numPaths = 48;
493  break;
494  default:
495  errorQuda("Invalid num_loop_types = %d\n", num_loop_types);
496  }
497 
498  QudaGaugeParam qudaGaugeParam = newMILCGaugeParam(localDim,
499  (precision==1) ? QUDA_SINGLE_PRECISION : QUDA_DOUBLE_PRECISION,
501  void *gauge = arg->site ? arg->site : arg->link;
502  void *mom = arg->site ? arg->site : arg->mom;
503 
504  qudaGaugeParam.gauge_offset = arg->link_offset;
505  qudaGaugeParam.mom_offset = arg->mom_offset;
506  qudaGaugeParam.site_size = arg->size;
508 
509  double *loop_coeff = static_cast<double*>(safe_malloc(numPaths*sizeof(double)));
510  int *length = static_cast<int*>(safe_malloc(numPaths*sizeof(int)));
511 
512  if (num_loop_types >= 1) for(int i= 0; i< 6; ++i) {
513  loop_coeff[i] = milc_loop_coeff[0];
514  length[i] = 3;
515  }
516  if (num_loop_types >= 2) for(int i= 6; i<24; ++i) {
517  loop_coeff[i] = milc_loop_coeff[1];
518  length[i] = 5;
519  }
520  if (num_loop_types >= 3) for(int i=24; i<48; ++i) {
521  loop_coeff[i] = milc_loop_coeff[2];
522  length[i] = 5;
523  }
524 
525  int** input_path_buf[4];
526  for(int dir=0; dir<4; ++dir){
527  input_path_buf[dir] = static_cast<int**>(safe_malloc(numPaths*sizeof(int*)));
528  for(int i=0; i<numPaths; ++i){
529  input_path_buf[dir][i] = static_cast<int*>(safe_malloc(length[i]*sizeof(int)));
530  }
531  createGaugeForcePaths(input_path_buf[dir], dir, num_loop_types);
532  }
533 
534  if (!invalidate_quda_mom) {
535  qudaGaugeParam.use_resident_mom = true;
536  qudaGaugeParam.make_resident_mom = true;
537  qudaGaugeParam.return_result_mom = false;
538 
539  // this means when we compute the momentum, we accumulate to the
540  // preexisting resident momentum instead of overwriting it
541  qudaGaugeParam.overwrite_mom = false;
542  } else {
543  qudaGaugeParam.use_resident_mom = false;
544  qudaGaugeParam.make_resident_mom = false;
545  qudaGaugeParam.return_result_mom = true;
546 
547  // this means we compute momentum into a fresh field, copy it back
548  // and sum to current momentum in MILC. This saves an initial
549  // CPU->GPU download of the current momentum.
550  qudaGaugeParam.overwrite_mom = false;
551  }
552 
553  int max_length = 6;
554 
555  computeGaugeForceQuda(mom, gauge, input_path_buf, length,
556  loop_coeff, numPaths, max_length, eb3, &qudaGaugeParam);
557 
558  for(int dir=0; dir<4; ++dir){
559  for(int i=0; i<numPaths; ++i) host_free(input_path_buf[dir][i]);
560  host_free(input_path_buf[dir]);
561  }
562 
563  host_free(length);
564  host_free(loop_coeff);
565 
566  qudamilc_called<false>(__func__);
567  return;
568 }
569 
570 
/**
 * Return the largest of the four values obtained by multiplying three of the
 * four extents in dim and halving the product (integer division).
 *
 * @param dim four lattice extents
 * @return max over mu of (product of dim[nu], nu != mu) / 2
 */
static int getLinkPadding(const int dim[4])
{
  // start from the face that omits direction 0 ...
  int widest = dim[1] * dim[2] * dim[3] / 2;

  // ... and compare against the faces that omit directions 1, 2 and 3
  for (int mu = 1; mu < 4; ++mu) {
    int face = 1;
    for (int nu = 0; nu < 4; ++nu) {
      if (nu != mu) face *= dim[nu];
    }
    face /= 2;
    if (face > widest) widest = face;
  }
  return widest;
}
578 
579 // set the params for the single mass solver
580 static void setInvertParams(const int dim[4], QudaPrecision cpu_prec, QudaPrecision cuda_prec,
581  QudaPrecision cuda_prec_sloppy, double mass, double target_residual,
582  double target_residual_hq, int maxiter, double reliable_delta, QudaParity parity,
584 {
585  invertParam->verbosity = verbosity;
586  invertParam->mass = mass;
587  invertParam->tol = target_residual;
588  invertParam->tol_hq = target_residual_hq;
589 
590  invertParam->residual_type = static_cast<QudaResidualType_s>(0);
591  invertParam->residual_type = (target_residual != 0) ?
592  static_cast<QudaResidualType_s>(invertParam->residual_type | QUDA_L2_RELATIVE_RESIDUAL) :
593  invertParam->residual_type;
594  invertParam->residual_type = (target_residual_hq != 0) ?
595  static_cast<QudaResidualType_s>(invertParam->residual_type | QUDA_HEAVY_QUARK_RESIDUAL) :
596  invertParam->residual_type;
597 
598  invertParam->heavy_quark_check = (invertParam->residual_type & QUDA_HEAVY_QUARK_RESIDUAL ? 1 : 0);
599  if (invertParam->heavy_quark_check) {
600  invertParam->max_hq_res_increase = 5; // this caps the number of consecutive hq residual increases
601  invertParam->max_hq_res_restart_total = 10; // this caps the number of hq restarts in case of solver stalling
602  }
603 
604  invertParam->use_sloppy_partial_accumulator = 0;
605  invertParam->num_offset = 0;
606 
607  invertParam->inv_type = inverter;
608  invertParam->maxiter = maxiter;
609  invertParam->reliable_delta = reliable_delta;
610 
612  invertParam->cpu_prec = cpu_prec;
613  invertParam->cuda_prec = cuda_prec;
614  invertParam->cuda_prec_sloppy = invertParam->heavy_quark_check ? cuda_prec : cuda_prec_sloppy;
616 
617  invertParam->solution_type = QUDA_MATPC_SOLUTION;
618  invertParam->solve_type = QUDA_DIRECT_PC_SOLVE;
620  invertParam->gamma_basis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // not used, but required by the code.
621  invertParam->dirac_order = QUDA_DIRAC_ORDER;
622 
623  invertParam->dslash_type = QUDA_ASQTAD_DSLASH;
624  invertParam->Ls = 1;
625  invertParam->gflops = 0.0;
626 
629 
630  if (parity == QUDA_EVEN_PARITY) { // even parity
631  invertParam->matpc_type = QUDA_MATPC_EVEN_EVEN;
632  } else if (parity == QUDA_ODD_PARITY) {
633  invertParam->matpc_type = QUDA_MATPC_ODD_ODD;
634  } else {
635  errorQuda("Invalid parity\n");
636  }
637 
638  invertParam->dagger = QUDA_DAG_NO;
639  invertParam->sp_pad = 0;
641 
642  // for the preconditioner
644  invertParam->tol_precondition = 1e-1;
645  invertParam->maxiter_precondition = 2;
646  invertParam->verbosity_precondition = QUDA_SILENT;
647 
648  invertParam->compute_action = 0;
649 }
650 
651 
652 // Set params for the multi-mass solver.
653 static void setInvertParams(const int dim[4], QudaPrecision cpu_prec, QudaPrecision cuda_prec,
654  QudaPrecision cuda_prec_sloppy, int num_offset, const double offset[],
655  const double target_residual_offset[], const double target_residual_hq_offset[],
657  QudaInverterType inverter, QudaInvertParam *invertParam)
658 {
659  const double null_mass = -1;
660 
661  setInvertParams(dim, cpu_prec, cuda_prec, cuda_prec_sloppy, null_mass, target_residual_offset[0],
662  target_residual_hq_offset[0], maxiter, reliable_delta, parity, verbosity, inverter, invertParam);
663 
664  invertParam->num_offset = num_offset;
665  for (int i = 0; i < num_offset; ++i) {
666  invertParam->offset[i] = offset[i];
667  invertParam->tol_offset[i] = target_residual_offset[i];
668  invertParam->tol_hq_offset[i] = target_residual_hq_offset[i];
669  }
670 }
671 
672 static void getReconstruct(QudaReconstructType &reconstruct, QudaReconstructType &reconstruct_sloppy)
673 {
674  {
675  char *reconstruct_env = getenv("QUDA_MILC_HISQ_RECONSTRUCT");
676  if (!reconstruct_env || strcmp(reconstruct_env, "18") == 0) {
677  reconstruct = QUDA_RECONSTRUCT_NO;
678  } else if (strcmp(reconstruct_env, "13") == 0) {
679  reconstruct = QUDA_RECONSTRUCT_13;
680  } else if (strcmp(reconstruct_env, "9") == 0) {
681  reconstruct = QUDA_RECONSTRUCT_9;
682  } else {
683  errorQuda("QUDA_MILC_HISQ_RECONSTRUCT=%s not supported", reconstruct_env);
684  }
685  }
686 
687  {
688  char *reconstruct_sloppy_env = getenv("QUDA_MILC_HISQ_RECONSTRUCT_SLOPPY");
689  if (!reconstruct_sloppy_env) { // if env is not set, default to using outer reconstruct type
690  reconstruct_sloppy = reconstruct;
691  } else if (strcmp(reconstruct_sloppy_env, "18") == 0) {
692  reconstruct_sloppy = QUDA_RECONSTRUCT_NO;
693  } else if (strcmp(reconstruct_sloppy_env, "13") == 0) {
694  reconstruct_sloppy = QUDA_RECONSTRUCT_13;
695  } else if (strcmp(reconstruct_sloppy_env, "9") == 0) {
696  reconstruct_sloppy = QUDA_RECONSTRUCT_9;
697  } else {
698  errorQuda("QUDA_MILC_HISQ_RECONSTRUCT_SLOPPY=%s not supported", reconstruct_sloppy_env);
699  }
700  }
701 }
702 
703 static void setGaugeParams(QudaGaugeParam &fat_param, QudaGaugeParam &long_param, const void *const fatlink,
704  const void *const longlink, const int dim[4], QudaPrecision cpu_prec,
705  QudaPrecision cuda_prec, QudaPrecision cuda_prec_sloppy, double tadpole, double naik_epsilon)
706 {
707  for (int dir = 0; dir < 4; ++dir) fat_param.X[dir] = dim[dir];
708 
709  fat_param.cpu_prec = cpu_prec;
710  fat_param.cuda_prec = cuda_prec;
713  fat_param.reconstruct = QUDA_RECONSTRUCT_NO;
716  fat_param.gauge_fix = QUDA_GAUGE_FIXED_NO;
717  fat_param.anisotropy = 1.0;
718  fat_param.t_boundary = QUDA_PERIODIC_T; // anti-periodic boundary conditions are built into the gauge field
720  fat_param.ga_pad = getLinkPadding(dim);
721 
722  if (longlink != nullptr) {
723  // improved staggered parameters
724  fat_param.type = QUDA_ASQTAD_FAT_LINKS;
725 
726  // now set the long link parameters needed
727  long_param = fat_param;
728  long_param.tadpole_coeff = tadpole;
729  long_param.scale = -(1.0 + naik_epsilon) / (24.0 * long_param.tadpole_coeff * long_param.tadpole_coeff);
730  long_param.type = QUDA_THREE_LINKS;
731  long_param.ga_pad = 3*fat_param.ga_pad;
732  getReconstruct(long_param.reconstruct, long_param.reconstruct_sloppy);
733  long_param.reconstruct_precondition = long_param.reconstruct_sloppy;
734  } else {
735  // naive staggered parameters
736  fat_param.type = QUDA_SU3_LINKS;
738  }
739 
740 }
741 
742 static void setColorSpinorParams(const int dim[4], QudaPrecision precision, ColorSpinorParam *param)
743 {
744  param->nColor = 3;
745  param->nSpin = 1;
746  param->nDim = 4;
747 
748  for (int dir = 0; dir < 4; ++dir) param->x[dir] = dim[dir];
749  param->x[0] /= 2;
750 
751  param->setPrecision(precision);
752  param->pad = 0;
756  param->gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // meaningless, but required by the code.
758 }
759 
761  QudaExtLibType deflation_ext_lib, char vec_infile[], char vec_outfile[], QudaEigParam *df_param)
762 {
763 
764  df_param->import_vectors = strcmp(vec_infile,"") ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
765 
766  df_param->cuda_prec_ritz = ritz_prec;
767  df_param->location = location_ritz;
768  df_param->mem_type_ritz = mem_type_ritz;
769 
770 
771  df_param->run_verify = QUDA_BOOLEAN_FALSE;
772 
773  df_param->nk = df_param->invert_param->nev;
774  df_param->np = df_param->invert_param->nev*df_param->invert_param->deflation_grid;
775 
776  // set file i/o parameters
777  strcpy(df_param->vec_infile, vec_infile);
778  strcpy(df_param->vec_outfile, vec_outfile);
779 }
780 
781 static size_t getColorVectorOffset(QudaParity local_parity, bool even_odd_exchange, const int dim[4])
782 {
783  size_t offset;
784  int volume = dim[0]*dim[1]*dim[2]*dim[3];
785 
786  if(local_parity == QUDA_EVEN_PARITY){
787  offset = even_odd_exchange ? volume*6/2 : 0;
788  }else{
789  offset = even_odd_exchange ? 0 : volume*6/2;
790  }
791  return offset;
792 }
793 
794 void qudaMultishiftInvert(int external_precision, int quda_precision, int num_offsets, double *const offset,
795  QudaInvertArgs_t inv_args, const double target_residual[],
796  const double target_fermilab_residual[], const void *const fatlink,
797  const void *const longlink, void *source, void **solutionArray, double *const final_residual,
798  double *const final_fermilab_residual, int *num_iters)
799 {
800  static const QudaVerbosity verbosity = getVerbosity();
801  qudamilc_called<true>(__func__, verbosity);
802 
803  if (target_residual[0] == 0) errorQuda("qudaMultishiftInvert: zeroth target residual cannot be zero\n");
804 
805  QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
806  QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
807  const bool use_mixed_precision = (((quda_precision==2) && inv_args.mixed_precision) ||
808  ((quda_precision==1) && (inv_args.mixed_precision==2)) ) ? true : false;
809  QudaPrecision device_precision_sloppy;
810  switch(inv_args.mixed_precision) {
811  case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break;
812  case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break;
813  default: device_precision_sloppy = device_precision;
814  }
815 
816  QudaGaugeParam fat_param = newQudaGaugeParam();
817  QudaGaugeParam long_param = newQudaGaugeParam();
818  setGaugeParams(fat_param, long_param, fatlink, longlink, localDim, host_precision, device_precision,
819  device_precision_sloppy, inv_args.tadpole, inv_args.naik_epsilon);
820 
821  QudaInvertParam invertParam = newQudaInvertParam();
822 
823  QudaParity local_parity = inv_args.evenodd;
824  const double reliable_delta = (use_mixed_precision ? 1e-1 : 0.0);
825  setInvertParams(localDim, host_precision, device_precision, device_precision_sloppy, num_offsets, offset,
826  target_residual, target_fermilab_residual, inv_args.max_iter, reliable_delta, local_parity, verbosity,
827  QUDA_CG_INVERTER, &invertParam);
828 
829  if (inv_args.mixed_precision == 1) {
832  long_param.reconstruct_refinement_sloppy = long_param.reconstruct_sloppy;
834  invertParam.reliable_delta_refinement = 0.1;
835  }
836 
838  setColorSpinorParams(localDim, host_precision, &csParam);
839 
840  // dirty hack to invalidate the cached gauge field without breaking interface compatibility
841  if (*num_iters == -1) invalidateGaugeQuda();
842 
843  // set the solver
844  if (invalidate_quda_gauge || !create_quda_gauge) {
845  loadGaugeQuda(const_cast<void *>(fatlink), &fat_param);
846  if (longlink != nullptr) loadGaugeQuda(const_cast<void *>(longlink), &long_param);
847  invalidate_quda_gauge = false;
848  }
849 
850  if (longlink == nullptr) invertParam.dslash_type = QUDA_STAGGERED_DSLASH;
851 
852  void** sln_pointer = (void**)malloc(num_offsets*sizeof(void*));
853  int quark_offset = getColorVectorOffset(local_parity, false, localDim) * host_precision;
854  void* src_pointer = static_cast<char*>(source) + quark_offset;
855 
856  for (int i = 0; i < num_offsets; ++i) sln_pointer[i] = static_cast<char *>(solutionArray[i]) + quark_offset;
857 
858  invertMultiShiftQuda(sln_pointer, src_pointer, &invertParam);
859  free(sln_pointer);
860 
861  // return the number of iterations taken by the inverter
862  *num_iters = invertParam.iter;
863  for (int i = 0; i < num_offsets; ++i) {
864  final_residual[i] = invertParam.true_res_offset[i];
865  final_fermilab_residual[i] = invertParam.true_res_hq_offset[i];
866  } // end loop over number of offsets
867 
868  if (!create_quda_gauge) invalidateGaugeQuda();
869 
870  qudamilc_called<false>(__func__, verbosity);
871 } // qudaMultiShiftInvert
872 
873 void qudaInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args,
874  double target_residual, double target_fermilab_residual, const void *const fatlink,
875  const void *const longlink, void *source, void *solution, double *const final_residual,
876  double *const final_fermilab_residual, int *num_iters)
877 {
878  static const QudaVerbosity verbosity = getVerbosity();
879  qudamilc_called<true>(__func__, verbosity);
880 
881  if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaInvert: requesting zero residual\n");
882 
883  // static const QudaVerbosity verbosity = getVerbosity();
884  QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
885  QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
886  QudaPrecision device_precision_sloppy;
887 
888  switch(inv_args.mixed_precision) {
889  case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break;
890  case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break;
891  default: device_precision_sloppy = device_precision;
892  }
893 
894  QudaGaugeParam fat_param = newQudaGaugeParam();
895  QudaGaugeParam long_param = newQudaGaugeParam();
896  setGaugeParams(fat_param, long_param, fatlink, longlink, localDim, host_precision, device_precision,
897  device_precision_sloppy, inv_args.tadpole, inv_args.naik_epsilon);
898 
899  QudaInvertParam invertParam = newQudaInvertParam();
900 
901  QudaParity local_parity = inv_args.evenodd;
902  const double reliable_delta = 1e-1;
903 
904  setInvertParams(localDim, host_precision, device_precision, device_precision_sloppy, mass, target_residual,
905  target_fermilab_residual, inv_args.max_iter, reliable_delta, local_parity, verbosity,
906  QUDA_CG_INVERTER, &invertParam);
907 
909  setColorSpinorParams(localDim, host_precision, &csParam);
910 
911  // dirty hack to invalidate the cached gauge field without breaking interface compatibility
912  if (*num_iters == -1 || !canReuseResidentGauge(&invertParam)) invalidateGaugeQuda();
913 
914  if (invalidate_quda_gauge || !create_quda_gauge) {
915  loadGaugeQuda(const_cast<void *>(fatlink), &fat_param);
916  if (longlink != nullptr) loadGaugeQuda(const_cast<void *>(longlink), &long_param);
917  invalidate_quda_gauge = false;
918  }
919 
920  if (longlink == nullptr) invertParam.dslash_type = QUDA_STAGGERED_DSLASH;
921 
922  int quark_offset = getColorVectorOffset(local_parity, false, localDim) * host_precision;
923 
924  invertQuda(static_cast<char *>(solution) + quark_offset, static_cast<char *>(source) + quark_offset, &invertParam);
925 
926  // return the number of iterations taken by the inverter
927  *num_iters = invertParam.iter;
928  *final_residual = invertParam.true_res;
929  *final_fermilab_residual = invertParam.true_res_hq;
930 
931  if (!create_quda_gauge) invalidateGaugeQuda();
932 
933  qudamilc_called<false>(__func__, verbosity);
934 } // qudaInvert
935 
936 
/**
 * Apply the staggered dslash through dslashQuda using the supplied fat links
 * and, optionally, long links.
 *
 * @param external_precision Host precision selector: 2 selects double, any other value selects single
 * @param quda_precision Device precision selector: 2 selects double, any other value selects single
 * @param inv_args Only the evenodd, tadpole and naik_epsilon members are read here
 * @param fatlink Host fat-link field passed to loadGaugeQuda
 * @param longlink Host long-link field; when nullptr it is not loaded and the
 *                 dslash type is switched to QUDA_STAGGERED_DSLASH
 * @param src Host input field, read starting at the offset of the parity opposite to inv_args.evenodd
 * @param dst Host output field, written starting at the offset of parity inv_args.evenodd
 * @param num_iters Only read by this function: the value -1 forces invalidateGaugeQuda()
 */
937 void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv_args, const void *const fatlink,
938  const void *const longlink, void* src, void* dst, int* num_iters)
939 {
940  static const QudaVerbosity verbosity = getVerbosity();
941  qudamilc_called<true>(__func__, verbosity);
942 
943  // static const QudaVerbosity verbosity = getVerbosity();
944  QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
945  QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
 // no mixed precision for a plain dslash application: sloppy precision equals the device precision
946  QudaPrecision device_precision_sloppy = device_precision;
947 
948  QudaGaugeParam fat_param = newQudaGaugeParam();
949  QudaGaugeParam long_param = newQudaGaugeParam();
950  setGaugeParams(fat_param, long_param, fatlink, longlink, localDim, host_precision, device_precision,
951  device_precision_sloppy, inv_args.tadpole, inv_args.naik_epsilon);
952 
953  QudaInvertParam invertParam = newQudaInvertParam();
954 
 // the output lives on local_parity, the input on the opposite parity
955  QudaParity local_parity = inv_args.evenodd;
956  QudaParity other_parity = local_parity == QUDA_EVEN_PARITY ? QUDA_ODD_PARITY : QUDA_EVEN_PARITY;
957 
 // mass, both target residuals, max iterations and reliable delta are all passed as zero
958  setInvertParams(localDim, host_precision, device_precision, device_precision_sloppy, 0.0, 0, 0, 0, 0.0, local_parity,
959  verbosity, QUDA_CG_INVERTER, &invertParam);
960 
 // NOTE(review): the declaration of csParam is not among the visible lines
 // (the embedded listing numbers jump from 960 to 962) -- confirm against the real file
962  setColorSpinorParams(localDim, host_precision, &csParam);
963 
964  // dirty hack to invalidate the cached gauge field without breaking interface compatability
965  if (*num_iters == -1 || !canReuseResidentGauge(&invertParam)) invalidateGaugeQuda();
966 
 // (re)load the links when the cache is flagged invalid or no persistent gauge field was requested
967  if (invalidate_quda_gauge || !create_quda_gauge) {
968  loadGaugeQuda(const_cast<void *>(fatlink), &fat_param);
969  if (longlink != nullptr) loadGaugeQuda(const_cast<void *>(longlink), &long_param);
970  invalidate_quda_gauge = false;
971  }
972 
973  if (longlink == nullptr) invertParam.dslash_type = QUDA_STAGGERED_DSLASH;
974 
975  int src_offset = getColorVectorOffset(other_parity, false, localDim);
976  int dst_offset = getColorVectorOffset(local_parity, false, localDim);
977 
 // offsets are scaled by host_precision and applied to char pointers
 // (presumably the QudaPrecision value is the byte size of one real -- TODO confirm)
978  dslashQuda(static_cast<char*>(dst) + dst_offset*host_precision,
979  static_cast<char*>(src) + src_offset*host_precision,
980  &invertParam, local_parity);
981 
 // without a persistent gauge field the cache is flagged invalid again after every call
982  if (!create_quda_gauge) invalidateGaugeQuda();
983 
984  qudamilc_called<false>(__func__, verbosity);
985 } // qudaDslash
986 
/**
 * Multi-source solve: forwards num_src source/solution pairs to
 * invertMultiSrcQuda using parameters set up for QUDA_CG_INVERTER.
 *
 * @param external_precision Host precision selector: 2 selects double, any other value selects single
 * @param quda_precision Device precision selector: 2 selects double, any other value selects single
 * @param mass Mass forwarded to setInvertParams
 * @param inv_args Reads mixed_precision, tadpole, naik_epsilon, evenodd and max_iter
 * @param target_residual Requested residual; must not be zero together with target_fermilab_residual
 * @param target_fermilab_residual Second requested residual, forwarded to setInvertParams
 * @param fatlink Host fat-link field passed to loadGaugeQuda
 * @param longlink Host long-link field; when nullptr the dslash type becomes QUDA_STAGGERED_DSLASH
 * @param sourceArray Array of num_src host source pointers
 * @param solutionArray Array of num_src host solution pointers
 * @param final_residual Receives invertParam.true_res
 * @param final_fermilab_residual Receives invertParam.true_res_hq
 * @param num_iters In: -1 forces invalidateGaugeQuda(); out: invertParam.iter
 * @param num_src Number of source/solution pairs
 */
987 void qudaInvertMsrc(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args,
988  double target_residual, double target_fermilab_residual, const void *const fatlink,
989  const void *const longlink, void **sourceArray, void **solutionArray, double *const final_residual,
990  double *const final_fermilab_residual, int *num_iters, int num_src)
991 {
992  static const QudaVerbosity verbosity = getVerbosity();
993  qudamilc_called<true>(__func__, verbosity);
994 
 // NOTE(review): the message below names qudaInvert although it is raised from qudaInvertMsrc
995  if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaInvert: requesting zero residual\n");
996 
997  // static const QudaVerbosity verbosity = getVerbosity();
998  QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
999  QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1000  QudaPrecision device_precision_sloppy;
1001 
 // mixed_precision: 2 -> half sloppy, 1 -> single sloppy, anything else -> same as device precision
1002  switch(inv_args.mixed_precision) {
1003  case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break;
1004  case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break;
1005  default: device_precision_sloppy = device_precision;
1006  }
1007 
1008  QudaGaugeParam fat_param = newQudaGaugeParam();
1009  QudaGaugeParam long_param = newQudaGaugeParam();
1010  setGaugeParams(fat_param, long_param, fatlink, longlink, localDim, host_precision, device_precision,
1011  device_precision_sloppy, inv_args.tadpole, inv_args.naik_epsilon);
1012 
1013  QudaInvertParam invertParam = newQudaInvertParam();
1014 
1015  QudaParity local_parity = inv_args.evenodd;
1016  const double reliable_delta = 1e-1;
1017 
1018  setInvertParams(localDim, host_precision, device_precision, device_precision_sloppy, mass, target_residual,
1019  target_fermilab_residual, inv_args.max_iter, reliable_delta, local_parity, verbosity,
1020  QUDA_CG_INVERTER, &invertParam);
1021  invertParam.num_src = num_src;
1022 
 // NOTE(review): the declaration of csParam is not among the visible lines
 // (the embedded listing numbers jump from 1022 to 1024) -- confirm against the real file
1024  setColorSpinorParams(localDim, host_precision, &csParam);
1025 
1026  // dirty hack to invalidate the cached gauge field without breaking interface compatability
1027  if (*num_iters == -1 || !canReuseResidentGauge(&invertParam)) invalidateGaugeQuda();
1028 
 // (re)load the links when the cache is flagged invalid or no persistent gauge field was requested
1029  if (invalidate_quda_gauge || !create_quda_gauge) {
1030  loadGaugeQuda(const_cast<void *>(fatlink), &fat_param);
1031  if (longlink != nullptr) loadGaugeQuda(const_cast<void *>(longlink), &long_param);
1032  invalidate_quda_gauge = false;
1033  }
1034 
1035  if (longlink == nullptr) invertParam.dslash_type = QUDA_STAGGERED_DSLASH;
1036 
 // the same offset (already scaled by host_precision) is applied to every source and solution pointer
1037  int quark_offset = getColorVectorOffset(local_parity, false, localDim) * host_precision;
 // NOTE(review): neither malloc result is checked before the arrays are filled
1038  void** sln_pointer = (void**)malloc(num_src*sizeof(void*));
1039  void** src_pointer = (void**)malloc(num_src*sizeof(void*));
1040 
1041  for (int i = 0; i < num_src; ++i) sln_pointer[i] = static_cast<char *>(solutionArray[i]) + quark_offset;
1042  for (int i = 0; i < num_src; ++i) src_pointer[i] = static_cast<char *>(sourceArray[i]) + quark_offset;
1043 
1044  invertMultiSrcQuda(sln_pointer, src_pointer, &invertParam);
1045 
 // only the temporary pointer tables are released; the caller's fields are untouched
1046  free(sln_pointer);
1047  free(src_pointer);
1048 
1049  // return the number of iterations taken by the inverter
1050  *num_iters = invertParam.iter;
1051  *final_residual = invertParam.true_res;
1052  *final_fermilab_residual = invertParam.true_res_hq;
1053 
1054  if (!create_quda_gauge) invalidateGaugeQuda();
1055 
1056  qudamilc_called<false>(__func__, verbosity);
1057 } // qudaInvertMsrc
1058 
/**
 * Solve one right-hand side of a sequence with an (incremental) eigCG solver.
 *
 * The deflation object is created when rhs_idx == 0, kept in the file-level
 * df_preconditioner between calls, and destroyed when last_rhs_flag is set.
 * The gauge links are likewise only (re)loaded when rhs_idx == 0.
 *
 * @param external_precision Host precision selector: 2 selects double, any other value selects single
 * @param quda_precision Device precision selector: 2 selects double, any other value selects single
 * @param mass Mass forwarded to setInvertParams
 * @param inv_args Reads mixed_precision, tadpole, naik_epsilon, evenodd, max_iter and solver_type;
 *                 solver_type must be QUDA_INC_EIGCG_INVERTER or QUDA_EIGCG_INVERTER
 * @param target_residual Requested residual; must not be zero together with target_fermilab_residual
 * @param target_fermilab_residual Second requested residual, forwarded to setInvertParams
 * @param fatlink Host fat-link field passed to loadGaugeQuda
 * @param longlink Host long-link field; when nullptr the dslash type becomes QUDA_STAGGERED_DSLASH
 * @param eig_args eigCG / deflation settings copied into invertParam and passed to setDeflationParam
 * @param final_residual Receives invertParam.true_res
 * @param final_fermilab_residual Receives invertParam.true_res_hq
 * @param num_iters In: -1 forces invalidateGaugeQuda(); out: invertParam.iter
 */
1059 void qudaEigCGInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args,
1060  double target_residual, double target_fermilab_residual, const void *const fatlink,
1061  const void *const longlink,
1062  void *source, // array of source vectors -> overwritten on exit
1063  void *solution, // temporary
1064  QudaEigArgs_t eig_args,
1065  const int rhs_idx, // current rhs
1066  const int last_rhs_flag, // is this the last rhs to solve
1067  double *const final_residual, double *const final_fermilab_residual, int *num_iters)
1068 {
1069  static const QudaVerbosity verbosity = getVerbosity();
1070  qudamilc_called<true>(__func__, verbosity);
1071 
 // NOTE(review): the message below names qudaInvert although it is raised from qudaEigCGInvert
1072  if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaInvert: requesting zero residual\n");
1073 
1074  QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1075  QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1076  QudaPrecision device_precision_sloppy;
1077 
 // mixed_precision: 2 -> half sloppy, 1 -> single sloppy, anything else -> same as device precision
1078  switch(inv_args.mixed_precision) {
1079  case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break;
1080  case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break;
1081  default: device_precision_sloppy = device_precision;
1082  }
1083 
1084  QudaGaugeParam fat_param = newQudaGaugeParam();
1085  QudaGaugeParam long_param = newQudaGaugeParam();
1086  setGaugeParams(fat_param, long_param, fatlink, longlink, localDim, host_precision, device_precision,
1087  device_precision_sloppy, inv_args.tadpole, inv_args.naik_epsilon);
1088 
1089  QudaInvertParam invertParam = newQudaInvertParam();
1090 
1091  QudaParity local_parity = inv_args.evenodd;
 // plain aliases of the two by-value residual arguments
1092  double& target_res = target_residual;
1093  double& target_res_hq = target_fermilab_residual;
1094  const double reliable_delta = 1e-1;
1095 
 // set up as a CG solve first; inv_type is overwritten with inv_args.solver_type further down
1096  setInvertParams(localDim, host_precision, device_precision, device_precision_sloppy, mass, target_res, target_res_hq,
1097  inv_args.max_iter, reliable_delta, local_parity, verbosity, QUDA_CG_INVERTER, &invertParam);
1098 
 // the deflation parameters keep a pointer to the local invertParam
1099  QudaEigParam df_param = newQudaEigParam();
1100  df_param.invert_param = &invertParam;
1101 
 // copy the eigCG settings supplied by the caller
1102  invertParam.nev = eig_args.nev;
1103  invertParam.max_search_dim = eig_args.max_search_dim;
1104  invertParam.deflation_grid = eig_args.deflation_grid;
1105  invertParam.cuda_prec_ritz = eig_args.prec_ritz;
1106  invertParam.tol_restart = eig_args.tol_restart;
1107  invertParam.eigcg_max_restarts = eig_args.eigcg_max_restarts;
1108  invertParam.max_restart_num = eig_args.max_restart_num;
1109  invertParam.inc_tol = eig_args.inc_tol;
1110  invertParam.eigenval_tol = eig_args.eigenval_tol;
1111  invertParam.rhs_idx = rhs_idx;
1112 
 // only the two eigCG variants are accepted as solver type
1113  if ((inv_args.solver_type != QUDA_INC_EIGCG_INVERTER) && (inv_args.solver_type != QUDA_EIGCG_INVERTER))
1114  errorQuda("Incorrect inverter type.\n");
1115  invertParam.inv_type = inv_args.solver_type;
1116 
 // NOTE(review): listing line 1117 is not visible here -- confirm against the real file
1118 
1119  setDeflationParam(eig_args.prec_ritz, eig_args.location_ritz, eig_args.mem_type_ritz, eig_args.deflation_ext_lib, eig_args.vec_infile, eig_args.vec_outfile, &df_param);
1120 
 // NOTE(review): the declaration of csParam is not among the visible lines
 // (the embedded listing numbers jump from 1120 to 1122) -- confirm against the real file
1122  setColorSpinorParams(localDim, host_precision, &csParam);
1123 
1124  // dirty hack to invalidate the cached gauge field without breaking interface compatability
1125  if (*num_iters == -1 || !canReuseResidentGauge(&invertParam)) invalidateGaugeQuda();
1126 
1127  if ((invalidate_quda_gauge || !create_quda_gauge) && (rhs_idx == 0)) { // do this for the first RHS
1128  loadGaugeQuda(const_cast<void *>(fatlink), &fat_param);
1129  if (longlink != nullptr) loadGaugeQuda(const_cast<void *>(longlink), &long_param);
1130  invalidate_quda_gauge = false;
1131  }
1132 
1133  if (longlink == nullptr) invertParam.dslash_type = QUDA_STAGGERED_DSLASH;
1134 
1135  int quark_offset = getColorVectorOffset(local_parity, false, localDim) * host_precision;
1136 
 // the deflation object is created once, for the first right-hand side, and reused afterwards
1137  if(rhs_idx == 0) df_preconditioner = newDeflationQuda(&df_param);
1138 
1139  invertParam.deflation_op = df_preconditioner;
1140 
1141  invertQuda(static_cast<char *>(solution) + quark_offset, static_cast<char *>(source) + quark_offset, &invertParam);
1142 
 // NOTE(review): df_preconditioner keeps its old value after being destroyed here
1143  if (last_rhs_flag) destroyDeflationQuda(df_preconditioner);
1144 
1145  // return the number of iterations taken by the inverter
1146  *num_iters = invertParam.iter;
1147  *final_residual = invertParam.true_res;
1148  *final_fermilab_residual = invertParam.true_res_hq;
1149 
 // the gauge cache is only invalidated after the last right-hand side
1150  if (!create_quda_gauge && last_rhs_flag) invalidateGaugeQuda();
1151 
1152  qudamilc_called<false>(__func__, verbosity);
1153 } // qudaEigCGInvert
1154 
1155 
// Set to 1 by qudaLoadCloverField after loadCloverQuda has been called and back
// to 0 by qudaFreeCloverField; both functions raise an error on a mismatched state.
1156 static int clover_alloc = 0;
1157 
1158 void* qudaCreateGaugeField(void* gauge, int geometry, int precision)
1159 {
1160  qudamilc_called<true>(__func__);
1161  QudaPrecision qudaPrecision = (precision==2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1162  QudaGaugeParam gaugeParam = newMILCGaugeParam(localDim, qudaPrecision,
1163  (geometry==1) ? QUDA_GENERAL_LINKS : QUDA_SU3_LINKS);
1164  qudamilc_called<false>(__func__);
1165  return createGaugeFieldQuda(gauge, geometry, &gaugeParam);
1166 }
1167 
1168 
1169 void qudaSaveGaugeField(void* gauge, void* inGauge)
1170 {
1171  qudamilc_called<true>(__func__);
1172  cudaGaugeField* cudaGauge = reinterpret_cast<cudaGaugeField*>(inGauge);
1173  QudaGaugeParam gaugeParam = newMILCGaugeParam(localDim, cudaGauge->Precision(), QUDA_GENERAL_LINKS);
1174  saveGaugeFieldQuda(gauge, inGauge, &gaugeParam);
1175  qudamilc_called<false>(__func__);
1176 }
1177 
1178 
1179 void qudaDestroyGaugeField(void* gauge)
1180 {
1181  qudamilc_called<true>(__func__);
1182  destroyGaugeFieldQuda(gauge);
1183  qudamilc_called<false>(__func__);
1184 }
1185 
1186 
// Forward declaration: qudaCloverForce below calls setInvertParam before its definition appears.
1187 void setInvertParam(QudaInvertParam &invertParam, QudaInvertArgs_t &inv_args,
1188  int external_precision, int quda_precision, double kappa, double reliable_delta);
1189 
/**
 * Set up gauge and inverter parameters and forward to computeCloverForceQuda.
 *
 * @param mom, dt, x, p, coeff, ck, multiplicity, gauge Forwarded unchanged to computeCloverForceQuda
 * @param kappa Used for the inverter parameters; passed on as -kappa*kappa
 * @param nvec Number of vectors; also stored in invertParam.num_offset
 * @param precision 1 selects single precision for the gauge parameters, any other value double;
 *                  also passed as both precision selectors of setInvertParam
 * @param inv_args Forwarded to setInvertParam; use_resident_solution is copied into invertParam
 */
1190 void qudaCloverForce(void *mom, double dt, void **x, void **p, double *coeff, double kappa, double ck,
1191  int nvec, double multiplicity, void *gauge, int precision, QudaInvertArgs_t inv_args)
1192 {
1193  qudamilc_called<true>(__func__);
 // NOTE(review): the final argument and closing parenthesis of this call are not
 // visible (listing line 1196 is missing) -- confirm against the real file
1194  QudaGaugeParam gaugeParam = newMILCGaugeParam(localDim,
1195  (precision==1) ? QUDA_SINGLE_PRECISION : QUDA_DOUBLE_PRECISION,
1197  gaugeParam.gauge_order = QUDA_MILC_GAUGE_ORDER; // refers to momentum gauge order
1198 
1199  QudaInvertParam invertParam = newQudaInvertParam();
1200  setInvertParam(invertParam, inv_args, precision, precision, kappa, 0);
1201  invertParam.num_offset = nvec;
 // NOTE(review): nvec is not checked against the capacity of invertParam.offset
1202  for (int i=0; i<nvec; ++i) invertParam.offset[i] = 0.0; // not needed
1203  invertParam.clover_coeff = 0.0; // not needed
1204 
1205  // solution types
 // NOTE(review): listing lines 1206 and 1209 are not visible here -- confirm against the real file
1207  invertParam.solve_type = QUDA_NORMOP_PC_SOLVE;
1208  invertParam.inv_type = QUDA_CG_INVERTER;
1210 
 // overrides the verbosity settings made by setInvertParam
1211  invertParam.verbosity = getVerbosity();
1212  invertParam.verbosity_precondition = QUDA_SILENT;
1213  invertParam.use_resident_solution = inv_args.use_resident_solution;
1214 
1215  computeCloverForceQuda(mom, dt, x, p, coeff, -kappa*kappa, ck, nvec, multiplicity,
1216  gauge, &gaugeParam, &invertParam);
1217  qudamilc_called<false>(__func__);
1218 }
1219 
1220 
/**
 * Fill a single QudaGaugeParam for the clover interface (link type QUDA_WILSON_LINKS,
 * MILC gauge order).
 *
 * @param gaugeParam Structure to fill
 * @param dim Local lattice dimensions copied into gaugeParam.X
 * @param inv_args Reads mixed_precision and boundary_phase[0..3]
 * @param external_precision Host precision selector: 2 selects double, any other value selects single
 * @param quda_precision Device precision selector: 2 selects double, any other value selects single
 */
1221 void setGaugeParams(QudaGaugeParam &gaugeParam, const int dim[4], QudaInvertArgs_t &inv_args,
1222  int external_precision, int quda_precision) {
1223 
1224  const QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1225  const QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1226  QudaPrecision device_precision_sloppy;
1227 
 // mixed_precision: 2 -> half sloppy, 1 -> single sloppy, anything else -> same as device precision
1228  switch(inv_args.mixed_precision) {
1229  case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break;
1230  case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break;
1231  default: device_precision_sloppy = device_precision;
1232  }
1233 
1234  for(int dir=0; dir<4; ++dir) gaugeParam.X[dir] = dim[dir];
1235 
1236  gaugeParam.anisotropy = 1.0;
1237  gaugeParam.type = QUDA_WILSON_LINKS;
1238  gaugeParam.gauge_order = QUDA_MILC_GAUGE_ORDER;
1239 
1240  // Check the boundary conditions
1241  // Can't have twisted or anti-periodic boundary conditions in the spatial
1242  // directions with 12 reconstruct at the moment.
 // "trivial" means: all three spatial phases are zero and the temporal phase is 0 or 1
1243  bool trivial_phase = true;
1244  for(int dir=0; dir<3; ++dir){
1245  if(inv_args.boundary_phase[dir] != 0) trivial_phase = false;
1246  }
1247  if(inv_args.boundary_phase[3] != 0 && inv_args.boundary_phase[3] != 1) trivial_phase = false;
1248 
 // NOTE(review): listing lines 1252 and 1256 (one at the end of each branch) are not
 // visible here -- confirm against the real file
1249  if(trivial_phase){
 // a non-zero temporal phase selects anti-periodic, zero selects periodic
1250  gaugeParam.t_boundary = (inv_args.boundary_phase[3]) ? QUDA_ANTI_PERIODIC_T : QUDA_PERIODIC_T;
1251  gaugeParam.reconstruct = QUDA_RECONSTRUCT_12;
1253  }else{
 // non-trivial phases: periodic temporal boundary and no link reconstruction
1254  gaugeParam.t_boundary = QUDA_PERIODIC_T;
1255  gaugeParam.reconstruct = QUDA_RECONSTRUCT_NO;
1257  }
1258 
1259  gaugeParam.cpu_prec = host_precision;
1260  gaugeParam.cuda_prec = device_precision;
1261  gaugeParam.cuda_prec_sloppy = device_precision_sloppy;
 // the preconditioner precision follows the sloppy precision
1262  gaugeParam.cuda_prec_precondition = device_precision_sloppy;
1263  gaugeParam.gauge_fix = QUDA_GAUGE_FIXED_NO;
1264  gaugeParam.ga_pad = getLinkPadding(dim);
1265 }
1266 
1267 
1268 
/**
 * Fill the parts of a QudaInvertParam shared by the clover interface functions
 * (dslash type QUDA_CLOVER_WILSON_DSLASH). Tolerances, solution/solve type and
 * inverter type are not set in the visible lines; the callers in this file
 * assign them afterwards.
 *
 * @param invertParam Structure to fill
 * @param inv_args Reads mixed_precision and max_iter
 * @param external_precision Host precision selector: 2 selects double, any other value selects single
 * @param quda_precision Device precision selector: 2 selects double, any other value selects single
 * @param kappa Stored in invertParam.kappa
 * @param reliable_delta Stored in invertParam.reliable_delta
 */
1269 void setInvertParam(QudaInvertParam &invertParam, QudaInvertArgs_t &inv_args,
1270  int external_precision, int quda_precision, double kappa, double reliable_delta) {
1271 
1272  const QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1273  const QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
1274  QudaPrecision device_precision_sloppy;
 // mixed_precision: 2 -> half sloppy, 1 -> single sloppy, anything else -> same as device precision
1275  switch(inv_args.mixed_precision) {
1276  case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break;
1277  case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break;
1278  default: device_precision_sloppy = device_precision;
1279  }
1280 
 // queried once; later calls reuse the value from the first call
1281  static const QudaVerbosity verbosity = getVerbosity();
1282 
1283  invertParam.dslash_type = QUDA_CLOVER_WILSON_DSLASH;
1284  invertParam.kappa = kappa;
1285  invertParam.dagger = QUDA_DAG_NO;
 // NOTE(review): listing line 1286 is not visible here -- confirm against the real file
1287  invertParam.gcrNkrylov = 30;
1288  invertParam.reliable_delta = reliable_delta;
1289  invertParam.maxiter = inv_args.max_iter;
1290 
1291  invertParam.cuda_prec_precondition = device_precision_sloppy;
1292  invertParam.verbosity_precondition = verbosity;
1293  invertParam.verbosity = verbosity;
1294  invertParam.cpu_prec = host_precision;
1295  invertParam.cuda_prec = device_precision;
1296  invertParam.cuda_prec_sloppy = device_precision_sloppy;
 // NOTE(review): listing lines 1297 and 1298 are not visible here -- confirm against the real file
1299  invertParam.dirac_order = QUDA_DIRAC_ORDER;
1300  invertParam.sp_pad = 0;
1301  invertParam.cl_pad = 0;
 // clover precisions mirror the spinor precisions chosen above
1302  invertParam.clover_cpu_prec = host_precision;
1303  invertParam.clover_cuda_prec = device_precision;
1304  invertParam.clover_cuda_prec_sloppy = device_precision_sloppy;
1305  invertParam.clover_cuda_prec_precondition = device_precision_sloppy;
1306  invertParam.clover_order = QUDA_PACKED_CLOVER_ORDER;
1307 
1308  invertParam.compute_action = 0;
1309 }
1310 
1311 
1312 void qudaLoadGaugeField(int external_precision,
1313  int quda_precision,
1314  QudaInvertArgs_t inv_args,
1315  const void* milc_link) {
1316  qudamilc_called<true>(__func__);
1317  QudaGaugeParam gaugeParam = newQudaGaugeParam();
1318  setGaugeParams(gaugeParam, localDim, inv_args, external_precision, quda_precision);
1319 
1320  loadGaugeQuda(const_cast<void*>(milc_link), &gaugeParam);
1321  qudamilc_called<false>(__func__);
1322 } // qudaLoadGaugeField
1323 
1324 
1325 void qudaFreeGaugeField() {
1326  qudamilc_called<true>(__func__);
1327  freeGaugeQuda();
1328  qudamilc_called<false>(__func__);
1329 } // qudaFreeGaugeField
1330 
1331 void qudaLoadCloverField(int external_precision, int quda_precision, QudaInvertArgs_t inv_args, void *milc_clover,
1332  void *milc_clover_inv, QudaSolutionType solution_type, QudaSolveType solve_type, QudaInverterType inverter,
1333  double clover_coeff, int compute_trlog, double *trlog)
1334 {
1335  qudamilc_called<true>(__func__);
1336  QudaInvertParam invertParam = newQudaInvertParam();
1337  setInvertParam(invertParam, inv_args, external_precision, quda_precision, 0.0, 0.0);
1338  invertParam.solution_type = solution_type;
1339  invertParam.solve_type = solve_type;
1340  invertParam.inv_type = inverter;
1341  invertParam.matpc_type = QUDA_MATPC_EVEN_EVEN_ASYMMETRIC;
1342  invertParam.compute_clover_trlog = compute_trlog;
1343  invertParam.clover_coeff = clover_coeff;
1344 
1345  // Hacks to mollify checkInvertParams which is called from
1346  // loadCloverQuda. These "required" parameters are irrelevant here.
1347  // Better procedure: invertParam should be defined in
1348  // qudaCloverInvert and qudaEigCGCloverInvert and passed here
1349  // instead of redefining a partial version here
1350  invertParam.tol = 0.;
1351  invertParam.tol_hq = 0.;
1352  invertParam.residual_type = static_cast<QudaResidualType_s>(0);
1353 
1354  if(invertParam.dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
1355  if (clover_alloc == 0) {
1356  loadCloverQuda(milc_clover, milc_clover_inv, &invertParam);
1357  clover_alloc = 1;
1358  } else {
1359  errorQuda("Clover term already allocated");
1360  }
1361  }
1362 
1363  if (compute_trlog) {
1364  trlog[0] = invertParam.trlogA[0];
1365  trlog[1] = invertParam.trlogA[1];
1366  }
1367  qudamilc_called<false>(__func__);
1368 } // qudaLoadCoverField
1369 
1370 void qudaFreeCloverField() {
1371  qudamilc_called<true>(__func__);
1372  if (clover_alloc==1) {
1373  freeCloverQuda();
1374  clover_alloc = 0;
1375  } else {
1376  errorQuda("Trying to free non-allocated clover term");
1377  }
1378  qudamilc_called<false>(__func__);
1379 } // qudaFreeCloverField
1380 
1381 
/**
 * Clover solve through invertQuda. Gauge and clover fields are loaded before
 * the solve and freed afterwards whenever the corresponding pointers are non-null.
 *
 * @param external_precision Host precision selector, forwarded to the parameter setup helpers
 * @param quda_precision Device precision selector, forwarded to the parameter setup helpers
 * @param kappa Forwarded to setInvertParam
 * @param clover_coeff Stored in invertParam.clover_coeff
 * @param inv_args Forwarded to the load helpers and setInvertParam
 * @param target_residual Stored in invertParam.tol; non-zero enables QUDA_L2_RELATIVE_RESIDUAL
 * @param target_fermilab_residual Stored in invertParam.tol_hq; non-zero enables QUDA_HEAVY_QUARK_RESIDUAL
 * @param link Host gauge links; when null no gauge field is loaded or freed here
 * @param clover, cloverInverse Host clover fields; when both are null nothing is loaded or freed here
 * @param source Host source passed to invertQuda
 * @param solution Host solution passed to invertQuda
 * @param final_residual Receives invertParam.true_res
 * @param final_fermilab_residual Receives invertParam.true_res_hq
 * @param num_iters Receives invertParam.iter
 */
1382 void qudaCloverInvert(int external_precision,
1383  int quda_precision,
1384  double kappa,
1385  double clover_coeff,
1386  QudaInvertArgs_t inv_args,
1387  double target_residual,
1388  double target_fermilab_residual,
1389  const void* link,
1390  void* clover, // could be stored in Milc format
1391  void* cloverInverse,
1392  void* source,
1393  void* solution,
1394  double* const final_residual,
1395  double* const final_fermilab_residual,
1396  int* num_iters)
1397 {
1398  qudamilc_called<true>(__func__);
1399  if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaCloverInvert: requesting zero residual\n");
1400 
1401  if (link) qudaLoadGaugeField(external_precision, quda_precision, inv_args, link);
1402 
1403  if (clover || cloverInverse) {
1404  qudaLoadCloverField(external_precision, quda_precision, inv_args, clover, cloverInverse, QUDA_MAT_SOLUTION,
 // NOTE(review): the remaining arguments of the call above are not visible
 // (listing line 1405 is missing) -- confirm against the real file
1406  }
1407 
1408  double reliable_delta = 1e-1;
1409 
1410  QudaInvertParam invertParam = newQudaInvertParam();
1411  setInvertParam(invertParam, inv_args, external_precision, quda_precision, kappa, reliable_delta);
 // build the residual-type bit mask: each non-zero target switches on its residual flag
1412  invertParam.residual_type = static_cast<QudaResidualType_s>(0);
1413  invertParam.residual_type = (target_residual != 0) ? static_cast<QudaResidualType_s> ( invertParam.residual_type | QUDA_L2_RELATIVE_RESIDUAL) : invertParam.residual_type;
1414  invertParam.residual_type = (target_fermilab_residual != 0) ? static_cast<QudaResidualType_s> (invertParam.residual_type | QUDA_HEAVY_QUARK_RESIDUAL) : invertParam.residual_type;
1415 
1416  invertParam.tol = target_residual;
1417  invertParam.tol_hq = target_fermilab_residual;
 // heavy_quark_check is 1 exactly when the heavy-quark residual flag is set
1418  invertParam.heavy_quark_check = (invertParam.residual_type & QUDA_HEAVY_QUARK_RESIDUAL ? 1 : 0);
1419  invertParam.clover_coeff = clover_coeff;
1420 
1421  // solution types
1422  invertParam.solution_type = QUDA_MAT_SOLUTION;
 // NOTE(review): listing lines 1423 and 1424 are not visible here -- confirm against the real file
1425  invertParam.matpc_type = QUDA_MATPC_ODD_ODD;
1426 
1427  invertQuda(solution, source, &invertParam);
1428 
1429  *num_iters = invertParam.iter;
1430  *final_residual = invertParam.true_res;
1431  *final_fermilab_residual = invertParam.true_res_hq;
1432 
 // release only what this call loaded
1433  if (clover || cloverInverse) qudaFreeCloverField();
1434  if (link) qudaFreeGaugeField();
1435  qudamilc_called<false>(__func__);
1436 } // qudaCloverInvert
1437 
/**
 * Clover solve of one right-hand side of a sequence with an (incremental) eigCG solver.
 *
 * Gauge field, clover field and deflation object are set up when rhs_idx == 0
 * and released when last_rhs_flag is set; the deflation object is kept in the
 * file-level df_preconditioner between calls.
 *
 * @param external_precision Host precision selector, forwarded to the parameter setup helpers
 * @param quda_precision Device precision selector, forwarded to the parameter setup helpers
 * @param kappa Forwarded to setInvertParam
 * @param clover_coeff Stored in invertParam.clover_coeff
 * @param inv_args Forwarded to the helpers; solver_type must be QUDA_INC_EIGCG_INVERTER
 *                 or QUDA_EIGCG_INVERTER
 * @param target_residual Stored in invertParam.tol; non-zero enables QUDA_L2_RELATIVE_RESIDUAL
 * @param target_fermilab_residual Stored in invertParam.tol_hq; non-zero enables QUDA_HEAVY_QUARK_RESIDUAL
 * @param link Host gauge links; when null no gauge field is loaded or freed here
 * @param eig_args eigCG / deflation settings copied into invertParam and passed to setDeflationParam
 * @param final_residual Receives invertParam.true_res
 * @param final_fermilab_residual Receives invertParam.true_res_hq
 * @param num_iters Receives invertParam.iter
 */
1438 void qudaEigCGCloverInvert(int external_precision, int quda_precision, double kappa, double clover_coeff,
1439  QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual,
1440  const void *link,
1441  void *clover, // could be stored in Milc format
1442  void *cloverInverse,
1443  void *source, // array of source vectors -> overwritten on exit!
1444  void *solution, // temporary
1445  QudaEigArgs_t eig_args,
1446  const int rhs_idx, // current rhs
1447  const int last_rhs_flag, // is this the last rhs to solve?
1448  double *const final_residual, double *const final_fermilab_residual, int *num_iters)
1449 {
1450  qudamilc_called<true>(__func__);
 // NOTE(review): the message below names qudaCloverInvert although it is raised from qudaEigCGCloverInvert
1451  if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaCloverInvert: requesting zero residual\n");
1452 
 // gauge and clover fields are only loaded for the first right-hand side
1453  if (link && (rhs_idx == 0)) qudaLoadGaugeField(external_precision, quda_precision, inv_args, link);
1454 
1455  if ( (clover || cloverInverse) && (rhs_idx == 0)) {
1456  qudaLoadCloverField(external_precision, quda_precision, inv_args, clover, cloverInverse, QUDA_MAT_SOLUTION,
 // NOTE(review): the remaining arguments of the call above are not visible
 // (listing line 1457 is missing) -- confirm against the real file
1458  }
1459 
1460  double reliable_delta = 1e-1;
1461 
1462  QudaInvertParam invertParam = newQudaInvertParam();
1463  setInvertParam(invertParam, inv_args, external_precision, quda_precision, kappa, reliable_delta);
 // build the residual-type bit mask: each non-zero target switches on its residual flag
1464  invertParam.residual_type = static_cast<QudaResidualType_s>(0);
1465  invertParam.residual_type = (target_residual != 0) ? static_cast<QudaResidualType_s> ( invertParam.residual_type | QUDA_L2_RELATIVE_RESIDUAL) : invertParam.residual_type;
1466  invertParam.residual_type = (target_fermilab_residual != 0) ? static_cast<QudaResidualType_s> (invertParam.residual_type | QUDA_HEAVY_QUARK_RESIDUAL) : invertParam.residual_type;
1467 
1468  invertParam.tol = target_residual;
1469  invertParam.tol_hq = target_fermilab_residual;
1470  invertParam.heavy_quark_check = (invertParam.residual_type & QUDA_HEAVY_QUARK_RESIDUAL ? 1 : 0);
1471  invertParam.clover_coeff = clover_coeff;
1472 
1473  // solution types
1474  invertParam.solution_type = QUDA_MAT_SOLUTION;
1475  invertParam.matpc_type = QUDA_MATPC_ODD_ODD;
1476 
 // NOTE(review): listing line 1477 is not visible here -- confirm against the real file
 // the deflation parameters keep a pointer to the local invertParam
1478  QudaEigParam df_param = newQudaEigParam();
1479  df_param.invert_param = &invertParam;
1480 
 // copy the eigCG settings supplied by the caller
1481  invertParam.solve_type = QUDA_NORMOP_PC_SOLVE;
1482  invertParam.nev = eig_args.nev;
1483  invertParam.max_search_dim = eig_args.max_search_dim;
1484  invertParam.deflation_grid = eig_args.deflation_grid;
1485  invertParam.cuda_prec_ritz = eig_args.prec_ritz;
1486  invertParam.tol_restart = eig_args.tol_restart;
1487  invertParam.eigcg_max_restarts = eig_args.eigcg_max_restarts;
1488  invertParam.max_restart_num = eig_args.max_restart_num;
1489  invertParam.inc_tol = eig_args.inc_tol;
1490  invertParam.eigenval_tol = eig_args.eigenval_tol;
1491  invertParam.rhs_idx = rhs_idx;
1492 
1493 
 // only the two eigCG variants are accepted as solver type
1494  if((inv_args.solver_type != QUDA_INC_EIGCG_INVERTER) && (inv_args.solver_type != QUDA_EIGCG_INVERTER)) errorQuda("Incorrect inverter type.\n");
1495  invertParam.inv_type = inv_args.solver_type;
1496 
 // NOTE(review): listing line 1497 is not visible here -- confirm against the real file
1498 
1499  setDeflationParam(eig_args.prec_ritz, eig_args.location_ritz, eig_args.mem_type_ritz, eig_args.deflation_ext_lib, eig_args.vec_infile, eig_args.vec_outfile, &df_param);
1500 
 // the deflation object is created once, for the first right-hand side, and reused afterwards
1501  if(rhs_idx == 0) df_preconditioner = newDeflationQuda(&df_param);
1502  invertParam.deflation_op = df_preconditioner;
1503 
1504  invertQuda(solution, source, &invertParam);
1505 
 // NOTE(review): df_preconditioner keeps its old value after being destroyed here
1506  if (last_rhs_flag) destroyDeflationQuda(df_preconditioner);
1507 
1508  *num_iters = invertParam.iter;
1509  *final_residual = invertParam.true_res;
1510  *final_fermilab_residual = invertParam.true_res_hq;
1511 
 // fields are only released after the last right-hand side
1512  if ( (clover || cloverInverse) && last_rhs_flag) qudaFreeCloverField();
1513  if (link && last_rhs_flag) qudaFreeGaugeField();
1514  qudamilc_called<false>(__func__);
1515 } // qudaEigCGCloverInvert
1516 
1517 
/**
 * Clover multi-shift solve. When exactly one shift is requested and it is zero,
 * a single invertQuda solve is done instead (solver selectable through the
 * QUDA_MILC_CLOVER_SOLVER environment variable); otherwise invertMultiShiftQuda is used.
 *
 * milc_link, milc_clover and milc_clover_inv are not referenced in the visible lines.
 *
 * @param external_precision Host precision selector, forwarded to setInvertParam
 * @param quda_precision Device precision selector; also used to decide whether reliable updates are enabled
 * @param num_offsets Number of shifts
 * @param offset Array of num_offsets shifts
 * @param kappa Forwarded to setInvertParam
 * @param clover_coeff Stored in invertParam.clover_coeff
 * @param inv_args Reads mixed_precision and make_resident_solution; forwarded to setInvertParam
 * @param target_residual_offset Array of num_offsets target residuals; none may be zero
 * @param source Host source field
 * @param solutionArray Array of host solution pointers, one per shift
 * @param final_residual Single-solve path: *final_residual receives invertParam.true_res;
 *                       multi-shift path: entry i receives invertParam.true_res_offset[i]
 * @param num_iters Receives invertParam.iter
 */
1518 void qudaCloverMultishiftInvert(int external_precision,
1519  int quda_precision,
1520  int num_offsets,
1521  double* const offset,
1522  double kappa,
1523  double clover_coeff,
1524  QudaInvertArgs_t inv_args,
1525  const double* target_residual_offset,
1526  const void* milc_link,
1527  void* milc_clover,
1528  void* milc_clover_inv,
1529  void* source,
1530  void** solutionArray,
1531  double* const final_residual,
1532  int* num_iters)
1533 {
1534  static const QudaVerbosity verbosity = getVerbosity();
1535  qudamilc_called<true>(__func__, verbosity);
1536 
1537  for (int i = 0; i < num_offsets; ++i) {
1538  if (target_residual_offset[i] == 0) errorQuda("qudaCloverMultishiftInvert: target residual cannot be zero\n");
1539  }
1540 
1541  // if doing a pure double-precision multi-shift solve don't use reliable updates
 // "mixed" here means: double device precision with any mixed_precision setting,
 // or single device precision with mixed_precision == 2
1542  const bool use_mixed_precision = (((quda_precision==2) && inv_args.mixed_precision) ||
1543  ((quda_precision==1) && (inv_args.mixed_precision==2)) ) ? true : false;
1544  double reliable_delta = (use_mixed_precision) ? 1e-2 : 0.0;
1545  QudaInvertParam invertParam = newQudaInvertParam();
1546  setInvertParam(invertParam, inv_args, external_precision, quda_precision, kappa, reliable_delta);
 // NOTE(review): listing line 1547 is not visible here -- confirm against the real file
1548  invertParam.num_offset = num_offsets;
 // NOTE(review): num_offsets is not checked against the capacity of offset / tol_offset
1549  for(int i=0; i<num_offsets; ++i){
1550  invertParam.offset[i] = offset[i];
1551  invertParam.tol_offset[i] = target_residual_offset[i];
1552  }
 // the overall tolerance is taken from the first shift
1553  invertParam.tol = target_residual_offset[0];
1554  invertParam.clover_coeff = clover_coeff;
1555 
1556  // solution types
 // NOTE(review): listing lines 1557 and 1560 are not visible here -- confirm against the real file
1558  invertParam.solve_type = QUDA_NORMOP_PC_SOLVE;
1559  invertParam.inv_type = QUDA_CG_INVERTER;
1561 
 // overrides the verbosity settings made by setInvertParam
1562  invertParam.verbosity = verbosity;
1563  invertParam.verbosity_precondition = QUDA_SILENT;
1564 
1565  invertParam.make_resident_solution = inv_args.make_resident_solution;
1566  invertParam.compute_true_res = 0;
1567 
 // a single zero shift is handled as an ordinary solve
1568  if (num_offsets==1 && offset[0] == 0) {
1569  // set the solver
1570  char *quda_solver = getenv("QUDA_MILC_CLOVER_SOLVER");
1571 
 // any other value of the variable matches no branch and leaves the parameters unchanged
1572  // default is chronological CG
1573  if (!quda_solver || strcmp(quda_solver,"CHRONO_CG_SOLVER")==0) {
1574  // use CG with chronological forecasting
1575  invertParam.chrono_use_resident = 1;
1576  invertParam.chrono_make_resident = 1;
1577  invertParam.chrono_max_dim = 10;
1578  } else if (strcmp(quda_solver,"BICGSTAB_SOLVER")==0){
1579  // use two-step BiCGStab
1580  invertParam.inv_type = QUDA_BICGSTAB_INVERTER;
1581  invertParam.solve_type = QUDA_DIRECT_PC_SOLVE;
1582  } else if (strcmp(quda_solver,"CG_SOLVER")==0){
1583  // regular CG
1584  invertParam.chrono_use_resident = 0;
1585  invertParam.chrono_make_resident = 0;
1586  }
1587 
1588  invertQuda(solutionArray[0], source, &invertParam);
1589  *final_residual = invertParam.true_res;
1590  } else {
1591  invertMultiShiftQuda(solutionArray, source, &invertParam);
1592  for (int i=0; i<num_offsets; ++i) final_residual[i] = invertParam.true_res_offset[i];
1593  }
1594 
1595  // return the number of iterations taken by the inverter
1596  *num_iters = invertParam.iter;
1597 
1598  qudamilc_called<false>(__func__, verbosity);
1599 } // qudaCloverMultishiftInvert
1600 
1601 void qudaGaugeFixingOVR(int precision, unsigned int gauge_dir, int Nsteps, int verbose_interval, double relax_boost,
1602  double tolerance, unsigned int reunit_interval, unsigned int stopWtheta, void *milc_sitelink)
1603 {
1604  QudaGaugeParam qudaGaugeParam = newMILCGaugeParam(localDim,
1605  (precision==1) ? QUDA_SINGLE_PRECISION : QUDA_DOUBLE_PRECISION,
1606  QUDA_SU3_LINKS);
1607  qudaGaugeParam.reconstruct = QUDA_RECONSTRUCT_NO;
1608  //qudaGaugeParam.reconstruct = QUDA_RECONSTRUCT_12;
1609 
1610  double timeinfo[3];
1611  computeGaugeFixingOVRQuda(milc_sitelink, gauge_dir, Nsteps, verbose_interval, relax_boost, tolerance, reunit_interval, stopWtheta, \
1612  &qudaGaugeParam, timeinfo);
1613 
1614  printfQuda("Time H2D: %lf\n", timeinfo[0]);
1615  printfQuda("Time to Compute: %lf\n", timeinfo[1]);
1616  printfQuda("Time D2H: %lf\n", timeinfo[2]);
1617  printfQuda("Time all: %lf\n", timeinfo[0]+timeinfo[1]+timeinfo[2]);
1618 }
1619 
/**
 * Gauge fixing through computeGaugeFixingFFTQuda, followed by a printout of
 * the three timings the call returns and their sum.
 *
 * @param precision 1 selects single precision, any other value double precision
 * @param gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance,
 *        stopWtheta Forwarded unchanged to computeGaugeFixingFFTQuda
 * @param milc_sitelink Host gauge field, forwarded to computeGaugeFixingFFTQuda
 */
1620 void qudaGaugeFixingFFT( int precision,
1621  unsigned int gauge_dir,
1622  int Nsteps,
1623  int verbose_interval,
1624  double alpha,
1625  unsigned int autotune,
1626  double tolerance,
1627  unsigned int stopWtheta,
1628  void* milc_sitelink
1629  )
1630 {
 // NOTE(review): the final argument and closing parenthesis of this call are not
 // visible (listing line 1633 is missing) -- confirm against the real file
1631  QudaGaugeParam qudaGaugeParam = newMILCGaugeParam(localDim,
1632  (precision==1) ? QUDA_SINGLE_PRECISION : QUDA_DOUBLE_PRECISION,
1634  qudaGaugeParam.reconstruct = QUDA_RECONSTRUCT_NO;
1635  //qudaGaugeParam.reconstruct = QUDA_RECONSTRUCT_12;
1636 
1637 
 // filled by the gauge-fixing call; printed below as H2D, compute and D2H times
1638  double timeinfo[3];
1639  computeGaugeFixingFFTQuda(milc_sitelink, gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta, \
1640  &qudaGaugeParam, timeinfo);
1641 
1642  printfQuda("Time H2D: %lf\n", timeinfo[0]);
1643  printfQuda("Time to Compute: %lf\n", timeinfo[1]);
1644  printfQuda("Time D2H: %lf\n", timeinfo[2]);
1645  printfQuda("Time all: %lf\n", timeinfo[0]+timeinfo[1]+timeinfo[2]);
1646 }
1647 
1648 #endif // BUILD_MILC_INTERFACE
void computeCloverForceQuda(void *mom, double dt, void **x, void **p, double *coeff, double kappa2, double ck, int nvector, double multiplicity, void *gauge, QudaGaugeParam *gauge_param, QudaInvertParam *inv_param)
int maxiter_precondition
Definition: quda.h:292
static QudaGaugeParam qudaGaugeParam
static bool reunit_allow_svd
QudaDiracFieldOrder dirac_order
Definition: quda.h:219
QudaMassNormalization mass_normalization
Definition: quda.h:208
double tol_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: quda.h:182
QudaReconstructType reconstruct_sloppy
Definition: quda.h:53
double anisotropy
Definition: quda.h:38
void freeCloverQuda(void)
QudaGaugeParam gaugeParam
Definition: covdev_test.cpp:36
void setPrecision(QudaPrecision precision, QudaPrecision ghost_precision=QUDA_INVALID_PRECISION, bool force_native=false)
void invertMultiShiftQuda(void **_hp_x, void *_hp_b, QudaInvertParam *param)
void setVerbosityQuda(QudaVerbosity verbosity, const char prefix[], FILE *outfile)
void endQuda(void)
int max_hq_res_increase
Definition: quda.h:157
#define pool_pinned_free(ptr)
Definition: malloc_quda.h:128
void qudaHisqParamsInit(QudaHisqParams_t hisq_params)
QudaSolveType solve_type
Definition: quda.h:205
QudaVerbosity verbosity_precondition
Definition: quda.h:286
enum QudaPrecision_s QudaPrecision
void qudaUnitarizeSU3(int prec, double tol, QudaMILCSiteArg_t *arg)
int ga_pad
Definition: quda.h:63
void destroyDeflationQuda(void *df_instance)
int make_resident_mom
Definition: quda.h:83
void qudaGaugeFixingFFT(int precision, unsigned int gauge_dir, int Nsteps, int verbose_interval, double alpha, unsigned int autotune, double tolerance, unsigned int stopWtheta, void *milc_sitelink)
Gauge fixing using the steepest-descent method with FFTs; supports a single GPU only.
size_t gauge_offset
Definition: quda.h:87
void setMPICommHandleQuda(void *mycomm)
QudaGaugeFixed gauge_fix
Definition: quda.h:61
QudaExtLibType deflation_ext_lib
void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv_args, const void *const milc_fatlink, const void *const milc_longlink, void *source, void *solution, int *num_iters)
void setUnitarizeForceConstants(double unitarize_eps, double hisq_force_filter, double max_det_error, bool allow_svd, bool svd_only, double svd_rel_error, double svd_abs_error)
Set the constant parameters for the force unitarization.
QudaInverterType inv_type_precondition
Definition: quda.h:270
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21
double kappa
Definition: test_util.cpp:1647
QudaLinkType type
Definition: quda.h:42
double kappa
Definition: quda.h:106
QudaPrecision cuda_prec_ritz
Definition: quda.h:324
void invertQuda(void *h_x, void *h_b, QudaInvertParam *param)
#define errorQuda(...)
Definition: util_quda.h:121
double tol
Definition: quda.h:121
void setUnitarizeLinksConstants(double unitarize_eps, double max_error, bool allow_svd, bool svd_only, double svd_rel_error, double svd_abs_error)
QudaDslashType dslash_type
Definition: quda.h:102
QudaReconstructType reconstruct_precondition
Definition: quda.h:59
QudaInverterType inv_type
Definition: quda.h:103
QudaPrecision prec_ritz
QudaPrecision cuda_prec
Definition: quda.h:214
#define host_free(ptr)
Definition: malloc_quda.h:71
enum QudaSolveType_s QudaSolveType
void loadGaugeQuda(void *h_gauge, QudaGaugeParam *param)
QudaExtLibType deflation_ext_lib
Definition: test_util.cpp:1718
void qudaInit(QudaInitArgs_t input)
QudaPrecision cpu_prec
Definition: quda.h:213
QudaMemoryType mem_type_ritz
QudaPrecision & cuda_prec
void setDeflationParam(QudaEigParam &df_param)
static int rank
Definition: comm_mpi.cpp:44
double momActionQuda(void *momentum, QudaGaugeParam *param)
QudaStaggeredPhase staggered_phase_type
Definition: quda.h:71
void qudaLoadGaugeField(int external_precision, int quda_precision, QudaInvertArgs_t inv_args, const void *milc_link)
void setInvertParam(QudaInvertParam &inv_param)
QudaDagType dagger
Definition: quda.h:207
#define MAX(a, b)
void qudaInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *const milc_fatlink, const void *const milc_longlink, void *source, void *solution, double *const final_resid, double *const final_rel_resid, int *num_iters)
void qudaEigCGInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *const fatlink, const void *const longlink, void *source, void *solution, QudaEigArgs_t eig_args, const int rhs_idx, const int last_rhs_flag, double *const final_residual, double *const final_fermilab_residual, int *num_iters)
double reliable_delta
Definition: test_util.cpp:1658
QudaPrecision cuda_prec_refinement_sloppy
Definition: quda.h:216
void qudaCloverMultishiftInvert(int external_precision, int quda_precision, int num_offsets, double *const offset, double kappa, double clover_coeff, QudaInvertArgs_t inv_args, const double *target_residual, const void *milc_link, void *milc_clover, void *milc_clover_inv, void *source, void **solutionArray, double *const final_residual, int *num_iters)
void qudaEigCGCloverInvert(int external_precision, int quda_precision, double kappa, double clover_coeff, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *milc_link, void *milc_clover, void *milc_clover_inv, void *source, void *solution, QudaEigArgs_t eig_args, const int rhs_idx, const int last_rhs_flag, double *const final_residual, double *const final_fermilab_residual, int *num_iters)
QudaGaugeFieldOrder gauge_order
Definition: quda.h:43
double true_res
Definition: quda.h:126
size_t mom_offset
Definition: quda.h:88
void computeKSLinkQuda(void *fatlink, void *longlink, void *ulink, void *inlink, double *path_coeff, QudaGaugeParam *param)
void qudaGaugeFixingOVR(const int precision, const unsigned int gauge_dir, const int Nsteps, const int verbose_interval, const double relax_boost, const double tolerance, const unsigned int reunit_interval, const unsigned int stopWtheta, void *milc_sitelink)
Gauge fixing using overrelaxation; supports both single-GPU and multi-GPU runs.
void qudaSaveGaugeField(void *gauge, void *inGauge)
void loadCloverQuda(void *h_clover, void *h_clovinv, QudaInvertParam *inv_param)
int length[]
int make_resident_solution
Definition: quda.h:347
int overwrite_mom
Definition: quda.h:78
double qudaMomAction(int precision, void *momentum)
QudaSiteSubset siteSubset
Definition: lattice_field.h:71
void qudaSetLayout(QudaLayout_t layout)
QudaPrecision clover_cuda_prec_sloppy
Definition: quda.h:226
int compute_action
Definition: quda.h:197
void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity parity)
QudaFieldLocation input_location
Definition: quda.h:99
void freeGaugeQuda(void)
void initCommsGridQuda(int nDim, const int *dims, QudaCommsMap func, void *fdata)
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]
Definition: quda.h:191
double reliable_delta
Definition: quda.h:129
size_t site_size
Definition: quda.h:89
QudaUseInitGuess use_init_guess
Definition: quda.h:231
void qudaHisqForce(int precision, int num_terms, int num_naik_terms, double dt, double **coeff, void **quark_field, const double level2_coeff[6], const double fat7_coeff[6], const void *const w_link, const void *const v_link, const void *const u_link, void *const milc_momentum)
int computeGaugeFixingOVRQuda(void *gauge, const unsigned int gauge_dir, const unsigned int Nsteps, const unsigned int verbose_interval, const double relax_boost, const double tolerance, const unsigned int reunit_interval, const unsigned int stopWtheta, QudaGaugeParam *param, double *timeinfo)
Gauge fixing using overrelaxation; supports both single-GPU and multi-GPU runs.
QudaGaugeParam param
Definition: pack_test.cpp:17
int llfat_ga_pad
Definition: quda.h:68
QudaSolutionType solution_type
Definition: quda.h:204
void projectSU3Quda(void *gauge_h, double tol, QudaGaugeParam *param)
QudaMemoryType mem_type_ritz
Definition: quda.h:450
int x[QUDA_MAX_DIM]
Definition: lattice_field.h:67
QudaPrecision clover_cuda_prec
Definition: quda.h:225
QudaPrecision & cuda_prec_sloppy
int chrono_use_resident
Definition: quda.h:359
int computeGaugeForceQuda(void *mom, void *sitelink, int ***input_path_buf, int *path_length, double *loop_coeff, int num_paths, int max_length, double dt, QudaGaugeParam *qudaGaugeParam)
QudaSolutionType solution_type
Definition: test_util.cpp:1664
QudaInvertParam * invert_param
Definition: quda.h:381
double scale
Definition: quda.h:40
void initQuda(int device)
double tol
Definition: test_util.cpp:1656
void qudaFreePinned(void *ptr)
void qudaUpdateU(int precision, double eps, QudaMILCSiteArg_t *arg)
QudaFieldLocation output_location
Definition: quda.h:100
QudaPrecision clover_cuda_prec_precondition
Definition: quda.h:228
int site_ga_pad
Definition: quda.h:65
#define POP_RANGE
Definition: timer.h:168
bool canReuseResidentGauge(QudaInvertParam *inv_param)
QudaBoolean run_verify
Definition: quda.h:456
void qudaFreeCloverField()
void * newDeflationQuda(QudaEigParam *param)
QudaPrecision cuda_prec_sloppy
Definition: quda.h:215
void qudaInvertMsrc(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *const fatlink, const void *const longlink, void **sourceArray, void **solutionArray, double *const final_residual, double *const final_fermilab_residual, int *num_iters, int num_src)
static bool initialized
Profiler for initQuda.
QudaVerbosity verbosity
Definition: quda.h:244
static bool reunit_svd_only
ColorSpinorParam csParam
Definition: pack_test.cpp:24
double tol_offset[QUDA_MAX_MULTI_SHIFT]
Definition: quda.h:179
double true_res_offset[QUDA_MAX_MULTI_SHIFT]
Definition: quda.h:185
int nvec[QUDA_MAX_MG_LEVEL]
Definition: test_util.cpp:1637
QudaInvertParam newQudaInvertParam(void)
double gflops
Definition: quda.h:250
void * qudaCreateGaugeField(void *gauge, int geometry, int precision)
int eigcg_max_restarts
Definition: quda.h:340
QudaPrecision cuda_prec_precondition
Definition: quda.h:58
QudaCloverFieldOrder clover_order
Definition: quda.h:230
void computeHISQForceQuda(void *momentum, double dt, const double level2_coeff[6], const double fat7_coeff[6], const void *const w_link, const void *const v_link, const void *const u_link, void **quark, int num, int num_naik, double **coeff, QudaGaugeParam *param)
void saveGaugeFieldQuda(void *outGauge, void *inGauge, QudaGaugeParam *param)
double tol_hq
Definition: quda.h:123
QudaInverterType solver_type
void qudaRephase(int prec, void *gauge, int flag, double i_mu)
double true_res_hq
Definition: quda.h:127
enum QudaSolutionType_s QudaSolutionType
void qudaComputeOprod(int precision, int num_terms, int num_naik_terms, double **coeff, double scale, void **quark_field, void *oprod[3])
QudaGammaBasis gamma_basis
Definition: quda.h:221
const int * machsize
QudaPrecision cuda_prec_sloppy
Definition: quda.h:52
int max_search_dim
Definition: quda.h:332
int chrono_make_resident
Definition: quda.h:353
double tol_precondition
Definition: quda.h:289
double offset[QUDA_MAX_MULTI_SHIFT]
Definition: quda.h:176
void qudaFreeGaugeField()
int use_sloppy_partial_accumulator
Definition: quda.h:132
int heavy_quark_check
Definition: quda.h:165
enum QudaParity_s QudaParity
QudaReconstructType reconstruct
Definition: quda.h:50
enum QudaLinkType_s QudaLinkType
QudaPrecision cuda_prec
Definition: quda.h:49
int X[4]
Definition: quda.h:36
void qudaLoadCloverField(int external_precision, int quda_precision, QudaInvertArgs_t inv_args, void *milc_clover, void *milc_clover_inv, QudaSolutionType solution_type, QudaSolveType solve_type, double clover_coeff, int compute_trlog, double *trlog)
double mass
Definition: quda.h:105
QudaBoolean import_vectors
Definition: quda.h:444
void qudaFinalize()
QudaFieldLocation location
Definition: quda.h:453
int gcrNkrylov
Definition: quda.h:259
QudaFieldLocation location_ritz
#define safe_malloc(size)
Definition: malloc_quda.h:66
void qudaCloverInvert(int external_precision, int quda_precision, double kappa, double clover_coeff, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *milc_link, void *milc_clover, void *milc_clover_inv, void *source, void *solution, double *const final_residual, double *const final_fermilab_residual, int *num_iters)
int max_hq_res_restart_total
Definition: quda.h:162
QudaSolveType solve_type
Definition: test_util.cpp:1663
QudaPrecision cuda_prec_refinement_sloppy
Definition: quda.h:55
static int dims[4]
Definition: face_gauge.cpp:41
void staggeredPhaseQuda(void *gauge_h, QudaGaugeParam *param)
static int index(int ndim, const int *dims, const int *x)
Definition: comm_common.cpp:32
#define pool_pinned_malloc(size)
Definition: malloc_quda.h:127
int computeGaugeFixingFFTQuda(void *gauge, const unsigned int gauge_dir, const unsigned int Nsteps, const unsigned int verbose_interval, const double alpha, const unsigned int autotune, const double tolerance, const unsigned int stopWtheta, QudaGaugeParam *param, double *timeinfo)
Gauge fixing using the steepest-descent method with FFTs; supports a single GPU only.
void qudaLoadUnitarizedLink(int precision, QudaFatLinkArgs_t fatlink_args, const double path_coeff[6], void *inlink, void *fatlink, void *ulink)
char vec_outfile[256]
Definition: quda.h:462
void destroyGaugeFieldQuda(void *gauge)
QudaResidualType_s
Definition: enum_quda.h:186
enum QudaFieldLocation_s QudaFieldLocation
double tadpole_coeff
Definition: quda.h:39
QudaPrecision cuda_prec_precondition
Definition: quda.h:217
int deflation_grid
Definition: quda.h:336
GaugeFieldParam gParam
double tol_restart
Definition: quda.h:122
void qudaSetMPICommHandle(void *mycomm)
void updateGaugeFieldQuda(void *gauge, void *momentum, double dt, int conj_mom, int exact, QudaGaugeParam *param)
double clover_coeff
Definition: test_util.cpp:1653
enum QudaReconstructType_s QudaReconstructType
Main header file for the QUDA library.
void invertMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param)
static double unitarize_eps
const int * latsize
int mom_ga_pad
Definition: quda.h:69
void qudaCloverForce(void *mom, double dt, void **x, void **p, double *coeff, double kappa, double ck, int nvec, double multiplicity, void *gauge, int precision, QudaInvertArgs_t inv_args)
void * qudaAllocatePinned(size_t bytes)
QudaMemoryType mem_type_ritz
Definition: test_util.cpp:1720
#define printfQuda(...)
Definition: util_quda.h:115
QudaTboundary t_boundary
Definition: quda.h:45
cudaGaugeField * cudaGauge
int chrono_max_dim
Definition: quda.h:362
int max_restart_num
Definition: quda.h:342
__host__ __device__ ValueType arg(const complex< ValueType > &z)
Returns the phase angle of z.
void * createGaugeFieldQuda(void *gauge, int geometry, QudaGaugeParam *param)
int use_resident_mom
Definition: quda.h:81
int device
Definition: test_util.cpp:1602
QudaFieldLocation location_ritz
Definition: test_util.cpp:1719
void * longlink
int compute_true_res
Definition: quda.h:125
QudaResidualType residual_type
Definition: quda.h:320
double inc_tol
Definition: quda.h:344
int num_offset
Definition: quda.h:169
enum QudaVerbosity_s QudaVerbosity
QudaPrecision & cpu_prec
void * fatlink
int return_result_mom
Definition: quda.h:85
#define PUSH_RANGE(name, cid)
Definition: timer.h:167
QudaVerbosity verbosity
int use_resident_solution
Definition: quda.h:350
void qudaLoadKSLink(int precision, QudaFatLinkArgs_t fatlink_args, const double act_path_coeff[6], void *inlink, void *fatlink, void *longlink)
void * deflation_op
Definition: quda.h:276
double eigenval_tol
Definition: quda.h:338
QudaPrecision clover_cpu_prec
Definition: quda.h:224
QudaPrecision cuda_prec_ritz
Definition: quda.h:447
QudaParity parity
Definition: covdev_test.cpp:54
void qudaGaugeForce(int precision, int num_loop_types, double milc_loop_coeff[3], double eb3, QudaMILCSiteArg_t *arg)
static int opp(int dir)
QudaPrecision prec
Definition: test_util.cpp:1608
QudaMatPCType matpc_type
Definition: quda.h:206
QudaEigParam newQudaEigParam(void)
char vec_infile[256]
Definition: quda.h:459
enum QudaInverterType_s QudaInverterType
void qudaMultishiftInvert(int external_precision, int precision, int num_offsets, double *const offset, QudaInvertArgs_t inv_args, const double *target_residual, const double *target_fermilab_residual, const void *const milc_fatlink, const void *const milc_longlink, void *source, void **solutionArray, double *const final_residual, double *const final_fermilab_residual, int *num_iters)
enum QudaMemoryType_s QudaMemoryType
void qudaAsqtadForce(int precision, const double act_path_coeff[6], const void *const one_link_src[4], const void *const naik_src[4], const void *const link, void *const milc_momentum)
static void createGaugeForcePaths(int **paths, int dir, int num_loop_types)
QudaReconstructType reconstruct_refinement_sloppy
Definition: quda.h:56
unsigned long long bytes
Definition: blas_quda.cu:23
QudaPrecision cpu_prec
Definition: quda.h:47
enum QudaExtLibType_s QudaExtLibType
QudaLayout_t layout
void qudaDestroyGaugeField(void *gauge)
QudaGaugeParam newQudaGaugeParam(void)
QudaPreserveSource preserve_source
Definition: quda.h:211
double reliable_delta_refinement
Definition: quda.h:130
double clover_coeff
Definition: quda.h:233
QudaVerbosity verbosity
Definition: test_util.cpp:1614