QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
deflation.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <invert_quda.h>
4 #include <vector>
5 #include <complex_quda.h>
6 
7 namespace quda {
8 
13  struct DeflationParam {
14 
18 
23 
25  double *invRitzVals;
26 
29 
32 
34  int ld;
35 
37  int tot_dim;
38 
40  int cur_dim;
41 
44 
47 
49  char filename[100];
50 
51  DeflationParam(QudaEigParam &param, ColorSpinorField *RV, DiracMatrix &matDeflation, int cur_dim = 0) : eig_global(param), RV(RV), matDeflation(matDeflation),
52  cur_dim(cur_dim), use_inv_ritz(false), location(param.location) {
53 
54  if(param.nk == 0 || param.np == 0 || (param.np % param.nk != 0)) errorQuda("\nIncorrect deflation space parameters...\n");
55  //redesign: param.nk => param.nev, param.np => param.deflation_grid*param.nev;
56  tot_dim = param.np;
57  ld = ((tot_dim+15) / 16) * tot_dim;
58  //allocate deflation resources:
59  matProj = new Complex[ld*tot_dim];
60  invRitzVals = new double[tot_dim];
61 
62  //Check that RV is a composite field:
63  if(RV->IsComposite() == false) errorQuda("\nRitz vectors must be contained in a composite field.\n");
64 
65  cudaHostRegister(matProj,ld*tot_dim*sizeof(Complex),cudaHostRegisterDefault);
66  }
67 
// Destructor body of DeflationParam (its signature line is not present in this listing).
// Unpins matProj, then releases the two buffers the constructor allocated with new[].
// NOTE(review): cudaHostUnregister is called before the null check below and its
// return code is ignored; a null or never-registered matProj would fail silently.
// NOTE(review): no copy constructor/assignment is visible here. If a DeflationParam
// is ever copied, both copies would delete[] the same matProj/invRitzVals buffers.
69  cudaHostUnregister(matProj);
70  if(matProj) delete[] matProj;
71  if(invRitzVals) delete[] invRitzVals;
72  }
73  };
74 
/**
   @brief Deflation-space manager built on a DeflationParam. The method semantics
   below that are marked "presumably" are inferred from names only and should be
   confirmed against the implementation file.
*/
78  class Deflation {
79 
80  private:
// The private member declarations are stripped from this listing. The
// cross-reference index lists: param (DeflationParam &), profile (TimeProfile),
// r, Av, r_sloppy, Av_sloppy (ColorSpinorField *).
// NOTE(review): profile appears to be held by value, so the constructor's
// TimeProfile & argument is presumably copied — confirm.
83 
86 
89 
92 
95 
98 
99 
100  public:
/** @brief Construct from deflation parameters and a profiler. */
106  Deflation(DeflationParam &param, TimeProfile &profile);
107 
/** @brief Virtual destructor. */
112  virtual ~Deflation();
113 
/** @brief Presumably a consistency check of the current deflation space — confirm. */
119  void verify();
120 
/**
   @brief Presumably extends the deflation space using vectors from V — confirm.
   @param V   Input field (presumably the source of new vectors)
   @param nev Number of vectors (presumably to add) — confirm
*/
126  void increment(ColorSpinorField &V, int nev);
127 
/**
   @brief Presumably shrinks/restarts the deflation space — confirm.
   @param tol     Tolerance (role not visible here)
   @param max_nev Upper bound on the number of vectors (presumably retained)
*/
134  void reduce(double tol, int max_nev);
135 
/**
   @brief Presumably applies the deflation projection, writing into out a result
   computed from in — confirm.
*/
141  void operator()(ColorSpinorField &out, ColorSpinorField &in);
142 
/** @brief Load Ritz vectors into RV (source presumably DeflationParam::filename — confirm). */
147  void loadVectors(ColorSpinorField *RV);
148 
/** @brief Save Ritz vectors from RV (destination presumably DeflationParam::filename — confirm). */
153  void saveVectors(ColorSpinorField *RV);
154 
/**
   @brief Test whether the deflation space is complete and therefore cannot be
   further extended, i.e. whether param.cur_dim has reached param.tot_dim.
*/
159  bool is_complete() {return (param.cur_dim == param.tot_dim);}
160 
/** @brief Return the current deflation space size (param.cur_dim). */
164  int size() {return param.cur_dim;}
165 
166 
/** @brief Presumably the floating-point operation count for this object — confirm. */
170  double flops() const;
171 
172  };
173 
181 
184 
185  ColorSpinorField *RV;//Ritz vectors
186 
188 
191 
192  deflated_solver(QudaEigParam &eig_param, TimeProfile &profile);
193 
// Destructor body of deflated_solver (its signature line is not present in this listing).
// Releases the owned objects inside a QUDA_PROFILE_FREE profiling section.
195  {
196  profile.TPSTART(QUDA_PROFILE_FREE);
197 
// defl is deleted before deflParam: Deflation holds a reference to its
// DeflationParam (Deflation::param), so the parameters must outlive it.
198  if (defl) delete defl;
199  if (deflParam) delete deflParam;
200 
// RV (the Ritz-vector field) is deleted after deflParam, which stores a pointer to it.
// DeflationParam's destructor does not touch RV.
201  if (RV) delete RV;
202 
// NOTE(review): the null checks here are redundant (delete on a null pointer is a
// no-op) but harmless.
203  if (m) delete m;
204  if (d) delete d;
205 
206  profile.TPSTOP(QUDA_PROFILE_FREE);
207  }
208  };
209 
210 } // namespace quda
211 
212 
#define errorQuda(...)
Definition: util_quda.h:121
char filename[100]
Definition: deflation.h:49
int size()
Return the current deflation space size.
Definition: deflation.h:164
virtual ~deflated_solver()
Definition: deflation.h:194
DiracMatrix & matDeflation
Definition: deflation.h:28
__device__ void reduce(ReduceArg< T > arg, const T &in, const int idx=0)
Definition: cub_helper.cuh:137
Complex * matProj
Definition: deflation.h:31
DeflationParam * deflParam
Definition: deflation.h:187
QudaGaugeParam param
Definition: pack_test.cpp:17
bool is_complete()
Test whether the deflation space is complete and therefore cannot be further extended.
Definition: deflation.h:159
double tol
Definition: test_util.cpp:1656
TimeProfile & profile
Definition: deflation.h:190
DeflationParam & param
Definition: deflation.h:82
DiracMatrix * m
Definition: deflation.h:183
cpuColorSpinorField * in
std::complex< double > Complex
Definition: quda_internal.h:46
TimeProfile profile
Definition: deflation.h:85
QudaFieldLocation location
Definition: deflation.h:46
Deflation * defl
Definition: deflation.h:189
int V
Definition: test_util.cpp:27
ColorSpinorField * Av_sloppy
Definition: deflation.h:97
enum QudaFieldLocation_s QudaFieldLocation
cpuColorSpinorField * out
int nev
Definition: test_util.cpp:1707
ColorSpinorField * RV
Definition: deflation.h:22
ColorSpinorField * Av
Definition: deflation.h:91
unsigned long long flops
Definition: blas_quda.cu:22
DeflationParam(QudaEigParam &param, ColorSpinorField *RV, DiracMatrix &matDeflation, int cur_dim=0)
Definition: deflation.h:51
ColorSpinorField * r_sloppy
Definition: deflation.h:94
QudaEigParam & eig_global
Definition: deflation.h:17
ColorSpinorField * r
Definition: deflation.h:88
double * invRitzVals
Definition: deflation.h:25
ColorSpinorField * RV
Definition: deflation.h:185