QUDA v1.1.0
A library for QCD on GPUs
deflation.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <invert_quda.h>
4 #include <vector>
5 #include <complex_quda.h>
6 
7 namespace quda {
8 
13  struct DeflationParam { // Deflation-space parameters plus host buffers (invRitzVals, matProj); RV must be a composite field.
14 // NOTE(review): no copy constructor/assignment is visible in this listing; if none exists, copying this struct double-frees matProj/invRitzVals.
18 
23 
25  double *invRitzVals; // one entry per deflation vector: allocated with tot_dim elements in the constructor
26 
29 
32 
34  int ld; // row stride used to size matProj (ld*tot_dim elements)
35 
37  int tot_dim; // maximum deflation-space dimension, set from param.np
38 
40  int cur_dim; // current deflation-space size; is_complete() compares it with tot_dim
41 
44 
47 
49  char filename[100]; // presumably the load/save path for the deflation vectors -- confirm
50 
53 // constructor body (signature elided from this listing): validates nk/np, sizes and allocates buffers, pins matProj
54  if(param.nk == 0 || param.np == 0 || (param.np % param.nk != 0)) errorQuda("\nIncorrect deflation space parameters...\n"); // np must be a nonzero multiple of nk
55  // redesign: param.nk => param.n_ev, param.np => param.deflation_grid*param.n_ev;
56  tot_dim = param.np;
57  ld = ((tot_dim+15) / 16) * tot_dim; // NOTE(review): padding to a multiple of 16 would be *16; as written ld = ceil(tot_dim/16)*tot_dim >= tot_dim (safe) but over-allocates -- confirm intent
58  //allocate deflation resources:
59  matProj = new Complex[ld*tot_dim];
60  invRitzVals = new double[tot_dim];
61 
62  //Check that RV is a composite field:
63  if(RV->IsComposite() == false) errorQuda("\nRitz vectors must be contained in a composite field.\n"); // NOTE(review): runs after the allocations above, so they leak if errorQuda returns
64 
65  cudaHostRegister(matProj,ld*tot_dim*sizeof(Complex),cudaHostRegisterDefault); // page-lock matProj; NOTE(review): cudaError_t result is ignored
66  }
67 // destructor body (signature elided from this listing): unpins and frees the host buffers
69  cudaHostUnregister(matProj); // NOTE(review): called before the null check below; cudaError_t result is ignored
70  if(matProj) delete[] matProj; // null check is redundant: delete[] on a null pointer is a no-op
71  if(invRitzVals) delete[] invRitzVals;
72  }
73  };
74 
78  class Deflation { // Manages the deflation space described by a DeflationParam (increment / reduce / verify / load / save).
79 
80  private:
82  DeflationParam &param; // referenced, not owned: the DeflationParam must outlive this object
83 
85  TimeProfile profile; // NOTE(review): held by value although the constructor takes TimeProfile&; timings may not reach the caller's profile -- confirm
86 
89 
91  ColorSpinorField *Av; // work field (name suggests A*v) -- confirm in deflation.cpp
92 
94  ColorSpinorField *r_sloppy; // work field, presumably a sloppy-precision residual -- confirm
95 
97  ColorSpinorField *Av_sloppy; // work field, presumably sloppy-precision A*v -- confirm
98 
99 
100  public:
107 // constructor Deflation(DeflationParam &, TimeProfile &) is elided from this listing
112  virtual ~Deflation();
113 
119  void verify();
120 
126  void increment(ColorSpinorField &V, int n_ev); // extend the space using n_ev vectors from V (defined in deflation.cpp)
127 
134  void reduce(double tol, int max_n_ev); // defined in deflation.cpp
135 
142 // operator()(ColorSpinorField &out, ColorSpinorField &in) is elided from this listing
147  void loadVectors(ColorSpinorField *RV); // load the eigenspace vectors from a file into RV
148 
153  void saveVectors(ColorSpinorField *RV); // save the eigenspace vectors in RV to a file
154 
159  bool is_complete() {return (param.cur_dim == param.tot_dim);} // true once cur_dim reaches tot_dim; NOTE(review): could be const
160 
164  int size() {return param.cur_dim;} // current deflation-space size; NOTE(review): could be const
165 
166 
170  double flops() const; // flop count, defined in deflation.cpp
171 
172  };
173 
181 // tail of struct deflated_solver (header and several members elided from this listing); the destructor below deletes the objects it owns
184 
185  ColorSpinorField *RV;//Ritz vectors (owned: deleted in the destructor)
186 
188 
191 
193 
195  { // destructor body: frees owned objects while timing QUDA_PROFILE_FREE
196  profile.TPSTART(QUDA_PROFILE_FREE);
197 
198  if (defl) delete defl; // delete first: Deflation holds a reference to its DeflationParam (presumably *deflParam)
199  if (deflParam) delete deflParam; // then the parameters, which hold RV (and presumably reference *m)
200 
201  if (RV) delete RV; // null checks are redundant: delete on a null pointer is a no-op
202 
203  if (m) delete m;
204  if (d) delete d;
205 
206  profile.TPSTOP(QUDA_PROFILE_FREE);
207  }
208  };
209 
210 } // namespace quda
211 
212 
bool is_complete()
Test whether the deflation space is complete and therefore cannot be extended further.
Definition: deflation.h:159
void operator()(ColorSpinorField &out, ColorSpinorField &in)
Definition: deflation.cpp:126
Deflation(DeflationParam &param, TimeProfile &profile)
Definition: deflation.cpp:22
int size()
Return the deflation space size.
Definition: deflation.h:164
void increment(ColorSpinorField &V, int n_ev)
Definition: deflation.cpp:180
virtual ~Deflation()
Definition: deflation.cpp:61
void saveVectors(ColorSpinorField *RV)
Save the eigenspace vectors to a file.
Definition: deflation.cpp:378
double flops() const
Return the total flops done on this and all coarser levels.
Definition: deflation.cpp:74
void reduce(double tol, int max_n_ev)
Definition: deflation.cpp:258
void loadVectors(ColorSpinorField *RV)
Load the eigenspace vectors from a file.
Definition: deflation.cpp:344
double tol
int V
Definition: host_utils.cpp:37
enum QudaFieldLocation_s QudaFieldLocation
std::complex< double > Complex
Definition: quda_internal.h:86
@ QUDA_PROFILE_FREE
Definition: timer.h:111
QudaGaugeParam param
Definition: pack_test.cpp:18
char filename[100]
Definition: deflation.h:49
QudaEigParam & eig_global
Definition: deflation.h:17
Complex * matProj
Definition: deflation.h:31
DiracMatrix & matDeflation
Definition: deflation.h:28
ColorSpinorField * RV
Definition: deflation.h:22
DeflationParam(QudaEigParam &param, ColorSpinorField *RV, DiracMatrix &matDeflation, int cur_dim=0)
Definition: deflation.h:51
double * invRitzVals
Definition: deflation.h:25
QudaFieldLocation location
Definition: deflation.h:46
virtual ~deflated_solver()
Definition: deflation.h:194
TimeProfile & profile
Definition: deflation.h:190
Deflation * defl
Definition: deflation.h:189
deflated_solver(QudaEigParam &eig_param, TimeProfile &profile)
ColorSpinorField * RV
Definition: deflation.h:185
DiracMatrix * m
Definition: deflation.h:183
DeflationParam * deflParam
Definition: deflation.h:187
#define errorQuda(...)
Definition: util_quda.h:120