1 #include <color_spinor_field.h>
3 #include <uint_to_char.h>
8 #include <launch_kernel.cuh>
9 #include <jitify_helper.cuh>
10 #include <kernels/block_orthogonalize.cuh>
14 using namespace quda::colorspinor;
16 // B fields in general use float2 ordering except for fine-grid Wilson
// BOrder<store_t, nSpin, nColor>::order is the field order that BlockOrtho::apply() requires
// of B[0] on its non-CPU path; it is also the order given to the B-vector accessor there.
// Primary template: FLOAT2 ordering for every combination not specialized below.
17 template <typename store_t, int nSpin, int nColor> struct BOrder { static constexpr QudaFieldOrder order = QUDA_FLOAT2_FIELD_ORDER; };
// Single-precision B with nSpin = 4, nColor = 3: FLOAT4 ordering.
18 template<> struct BOrder<float, 4, 3> { static constexpr QudaFieldOrder order = QUDA_FLOAT4_FIELD_ORDER; };
// Half-precision (short) B with nSpin = 4, nColor = 3: FLOAT8 or FLOAT4 ordering.
// NOTE(review): the next two lines are both explicit specializations of BOrder<short, 4, 3>.
// Defining both would be a redefinition, so they are presumably mutually exclusive under a
// preprocessor conditional that is not shown here - confirm.
20 template<> struct BOrder<short, 4, 3> { static constexpr QudaFieldOrder order = QUDA_FLOAT8_FIELD_ORDER; };
22 template<> struct BOrder<short, 4, 3> { static constexpr QudaFieldOrder order = QUDA_FLOAT4_FIELD_ORDER; };
// BlockOrtho: Tunable wrapper that runs the block-orthogonalization kernels on V using the
// vectors in B. It calls blockOrthoCPU for CPU-resident V and blockOrthoGPU otherwise (see apply()).
//   sumType        first template argument forwarded to blockOrthoCPU / blockOrthoGPU;
//                  the BlockOrthogonalize driver instantiates it with double (presumably the
//                  reduction type - confirm against kernels/block_orthogonalize.cuh)
//   vFloat         storage type of V (used by the Rotator accessor in apply())
//   bFloat         storage type of the B vectors (used by the Vector accessor in apply())
//   nSpin, spinBlockSize, coarseSpin, nVec
//                  forwarded unchanged to BlockOrthoArg and the kernels
//   nColor_        requested color count (see nColor below)
25 template <typename sumType, typename vFloat, typename bFloat, int nSpin, int spinBlockSize, int nColor_, int coarseSpin, int nVec>
26 class BlockOrtho : public Tunable {
28 // we only support block-format on fine grid where Ncolor=3
// nColor is forced to 3 when bFloat is a fixed-point type; the constructor reports an
// error via errorQuda if that differs from the requested nColor_.
29 static constexpr int nColor = isFixed<bFloat>::value ? 3 : nColor_;
// mapper<vFloat>::type: the first (compute) type argument of both accessors built in apply().
31 typedef typename mapper<vFloat>::type RegType;
// Input vectors; CPU()/GPU() pass B[0] .. B[nVec-1] individually to BlockOrthoArg.
33 const std::vector<ColorSpinorField*> B;
// Fine-to-coarse / coarse-to-fine index maps (per the names); passed through to BlockOrthoArg.
34 const int *fine_to_coarse;
35 const int *coarse_to_fine;
// Number of times the orthogonalization is repeated (cf. the verbose message in BlockOrthogonalize).
37 const int n_block_ortho;
// No per-thread shared memory is reported (always 0).
41 unsigned int sharedBytesPerThread() const { return 0; }
// No per-block shared memory is reported (always 0); param is unused.
42 unsigned int sharedBytesPerBlock(const TuneParam &param) const { return 0; }
// Threads needed along x: V.VolumeCB() (presumably the single-parity volume of V).
// The parity index goes in block.y, which defaultTuneParam() sets to V.SiteSubset().
43 unsigned int minThreads() const { return V.VolumeCB(); } // fine parity is the block y dimension
// Constructor: initializes the members, checks the color count, builds the tune-key aux
// string and computes nBlock.
//   V               field the kernels operate on (backed up / restored around tuning)
//   B               the nVec input vectors
//   fine_to_coarse, coarse_to_fine   index maps forwarded to BlockOrthoArg
//   geo_bs          per-dimension geometric block sizes; V.Ndim() entries are read
//   n_block_ortho   number of orthogonalization passes
46 BlockOrtho(ColorSpinorField &V, const std::vector<ColorSpinorField *> B, const int *fine_to_coarse,
47 const int *coarse_to_fine, const int *geo_bs, const int n_block_ortho) :
50 fine_to_coarse(fine_to_coarse),
51 coarse_to_fine(coarse_to_fine),
53 n_block_ortho(n_block_ortho)
// nColor differs from nColor_ only when bFloat is fixed-point and nColor_ != 3.
55 if (nColor_ != nColor)
// sizeof yields std::size_t, so %zu (not %lu) is the matching conversion specifier.
56 errorQuda("Number of colors %d not supported with this precision %zu\n", nColor_, sizeof(bFloat));
58 if (V.Location() == QUDA_CUDA_FIELD_LOCATION) {
60 create_jitify_program("kernels/block_orthogonalize.cuh");
// aux (used by tuneKey()) collects the compile type, V's aux string, the geometric block
// sizes (geo_bs entries separated by "x") and n_block_ortho.
63 strcat(aux, compile_type_str(V));
64 strcat(aux, V.AuxString());
65 strcat(aux,",block_size=");
// geoBlockSize accumulates the product of geo_bs[d] over V.Ndim() dimensions
// (assumes it starts at 1; its initialization is not shown here).
69 for (int d = 0; d < V.Ndim(); d++) {
70 geoBlockSize *= geo_bs[d];
71 i32toa(geo_str, geo_bs[d]);
73 if (d < V.Ndim() - 1) strcat(aux, "x");
// An aggregate consisting of a single site is rejected.
75 if (geoBlockSize == 1) errorQuda("Invalid MG aggregate size %d", geoBlockSize);
77 strcat(aux, ",n_block_ortho=");
79 i32toa(n_ortho_str, n_block_ortho);
80 strcat(aux, n_ortho_str);
// CPU runs also append getOmpThreadStr() (presumably the OpenMP thread setting) to the key.
82 if (V.Location() == QUDA_CPU_FIELD_LOCATION) strcat(aux, getOmpThreadStr());
// nBlock = (number of geometric aggregates) * (chiral blocks per aggregate); used by flops().
84 int chiralBlocks = (nSpin==1) ? 2 : V.Nspin() / spinBlockSize; //always 2 for staggered.
85 nBlock = (V.Volume()/geoBlockSize) * chiralBlocks;
89 @brief Helper function for expanding the std::vector into a
90 parameter pack that we can use to instantiate the const arrays
91 in BlockOrthoArg and then call the CPU variant of the block
94 template <typename Rotator, typename Vector, std::size_t... S>
// Rotator / Vector: accessor types for V and for each B vector, chosen in apply().
// S... is 0 .. nVec-1 (apply() passes std::make_index_sequence<nVec>()), so B[S]...
// expands to B[0], B[1], ..., B[nVec-1] as separate constructor arguments.
// The B parameter shadows the member B; apply() passes the member.
95 void CPU(const std::vector<ColorSpinorField*> &B, std::index_sequence<S...>) {
96 typedef BlockOrthoArg<Rotator,Vector,nSpin,spinBlockSize,coarseSpin,nVec> Arg;
// V is passed twice; the role of each occurrence is defined by BlockOrthoArg (not shown here).
97 Arg arg(V, fine_to_coarse, coarse_to_fine, QUDA_INVALID_PARITY, geo_bs, n_block_ortho, V, B[S]...);
98 blockOrthoCPU<sumType,RegType,nSpin,spinBlockSize,nColor,coarseSpin,nVec,Arg>(arg);
102 @brief Helper function for expanding the std::vector into a
103 parameter pack that we can use to instantiate the const arrays
104 in BlockOrthoArg and then call the GPU variant of the block
107 template <typename Rotator, typename Vector, std::size_t... S>
// Same argument expansion as CPU(). tp supplies the launch configuration and the swizzle
// factor; every copy and launch below is issued on `stream`.
108 void GPU(const TuneParam &tp, const qudaStream_t &stream, const std::vector<ColorSpinorField*> &B, std::index_sequence<S...>) {
109 typedef typename mapper<vFloat>::type RegType; // need to redeclare typedef (WAR for CUDA 7 and 8)
110 typedef BlockOrthoArg<Rotator,Vector,nSpin,spinBlockSize,coarseSpin,nVec> Arg;
111 Arg arg(V, fine_to_coarse, coarse_to_fine, QUDA_INVALID_PARITY, geo_bs, n_block_ortho, V, B[S]...);
// Swizzle factor chosen by the tuner (initialized in defaultTuneParam(), advanced by advanceAux()).
112 arg.swizzle = tp.aux.x;
// NOTE(review): the jitify launch below and the cudaMemcpyToSymbolAsync +
// LAUNCH_KERNEL_MG_BLOCK_SIZE launch after it are presumably alternative build paths selected
// by preprocessor lines not shown here - confirm. Both copy MAX_MATRIX_SIZE bytes from
// B_array_h into the device symbol B_array_d before launching blockOrthoGPU.
// NOTE(review): the status returned by cuMemcpyHtoDAsync / cudaMemcpyToSymbolAsync is not checked.
114 using namespace jitify::reflection;
115 auto instance = program->kernel("quda::blockOrthoGPU")
116 .instantiate((int)tp.block.x,Type<sumType>(),Type<RegType>(),nSpin,spinBlockSize,nColor,coarseSpin,nVec,Type<Arg>());
117 cuMemcpyHtoDAsync(instance.get_constant_ptr("quda::B_array_d"), B_array_h, MAX_MATRIX_SIZE, stream);
118 jitify_error = instance.configure(tp.grid,tp.block,tp.shared_bytes,stream).launch(arg);
120 cudaMemcpyToSymbolAsync(B_array_d, B_array_h, MAX_MATRIX_SIZE, 0, cudaMemcpyHostToDevice, stream);
121 LAUNCH_KERNEL_MG_BLOCK_SIZE(blockOrthoGPU,tp,stream,arg,sumType,RegType,nSpin,spinBlockSize,nColor,coarseSpin,nVec,Arg);
// Obtains launch parameters via tuneLaunch, then dispatches on where V lives:
//   CPU V: V and B[0] must both be QUDA_SPACE_SPIN_COLOR_FIELD_ORDER -> CPU()
//   other: V must be QUDA_FLOAT2_FIELD_ORDER and B[0] must match BOrder<bFloat,nSpin,nColor>::order -> GPU()
// Any other layout is reported via errorQuda. Only B[0]'s order is checked; the remaining
// B vectors are presumably stored in the same order.
125 void apply(const qudaStream_t &stream) {
126 TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
127 if (V.Location() == QUDA_CPU_FIELD_LOCATION) {
128 if (V.FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER && B[0]->FieldOrder() == QUDA_SPACE_SPIN_COLOR_FIELD_ORDER) {
// Rotator views V with nVec in the fourth FieldOrderCB slot; Vector views a single B vector (1).
// Presumably that slot is the vector count - confirm against FieldOrderCB.
129 typedef FieldOrderCB<RegType,nSpin,nColor,nVec,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,vFloat,vFloat,DISABLE_GHOST> Rotator;
130 typedef FieldOrderCB<RegType,nSpin,nColor,1,QUDA_SPACE_SPIN_COLOR_FIELD_ORDER,bFloat,bFloat,DISABLE_GHOST> Vector;
131 CPU<Rotator,Vector>(B, std::make_index_sequence<nVec>());
// NOTE(review): this message reports only V's order, although B[0]'s order is checked as well.
133 errorQuda("Unsupported field order %d\n", V.FieldOrder());
136 if (V.FieldOrder() == QUDA_FLOAT2_FIELD_ORDER && B[0]->FieldOrder() == BOrder<bFloat,nSpin,nColor>::order) {
137 typedef FieldOrderCB<RegType,nSpin,nColor,nVec,QUDA_FLOAT2_FIELD_ORDER,vFloat,vFloat,DISABLE_GHOST> Rotator;
// The B accessor additionally receives isFixed<bFloat>::value (fixed-point B storage).
138 typedef FieldOrderCB<RegType,nSpin,nColor,1,BOrder<bFloat,nSpin,nColor>::order,bFloat,bFloat,DISABLE_GHOST,isFixed<bFloat>::value> Vector;
139 GPU<Rotator,Vector>(tp,stream,B,std::make_index_sequence<nVec>());
141 errorQuda("Unsupported field order V=%d B=%d\n", V.FieldOrder(), B[0]->FieldOrder());
// Advances the swizzle factor param.aux.x (copied into arg.swizzle by GPU()); candidate
// values are kept below twice the device's multiprocessor count.
146 bool advanceAux(TuneParam &param) const
149 if (param.aux.x < 2*deviceProp.multiProcessorCount) {
// For CUDA-resident V, each tuning step first tries advanceSharedBytes(). Because `||`
// short-circuits, advanceAux() (the swizzle factor) is tried only when that returns false.
161 bool advanceTuneParam(TuneParam &param) const {
162 if (V.Location() == QUDA_CUDA_FIELD_LOCATION) {
163 return advanceSharedBytes(param) || advanceAux(param);
// Tune key: V.VolString(), the implementation-defined type name of this instantiation,
// and the aux string assembled in the constructor.
169 TuneKey tuneKey() const { return TuneKey(V.VolString(), typeid(*this).name(), aux); }
// Tuning starts from the same configuration that is used when tuning is disabled.
171 void initTuneParam(TuneParam &param) const { defaultTuneParam(param); }
173 /** sets default values for when tuning is disabled */
174 void defaultTuneParam(TuneParam &param) const {
// block.x = geoBlockSize/2 (integer division: an odd aggregate size would truncate);
// block.y = V.SiteSubset(), the parity dimension (see minThreads()).
175 param.block = dim3(geoBlockSize/2, V.SiteSubset(), 1);
// grid.x = ceil(minThreads() / block.x); grid.z = coarseSpin.
176 param.grid = dim3((minThreads() + param.block.x - 1) / param.block.x, 1, coarseSpin);
177 param.shared_bytes = 0;
178 param.aux.x = 1; // swizzle factor
// Floating-point operation estimate, scaled by the number of passes n_block_ortho.
181 long long flops() const
183 // FIXME: verify for staggered
// NOTE(review): the product up to "* nColor" is evaluated in the types of n_block_ortho (int),
// nBlock and geoBlockSize (declarations not shown) before being widened by the long factor;
// if those are int it could overflow for very large local volumes.
184 return n_block_ortho * nBlock * (geoBlockSize / 2) * (spinBlockSize == 0 ? 1 : 2 * spinBlockSize) / 2 * nColor
185 * (nVec * ((nVec - 1) * (8l + 8l)) + 6l);
// Byte-traffic estimate.
// First pass: nVec * B[0]->Bytes() (the B vectors), plus (nVec-1)*nVec/2 * V.Bytes()/nVec
// (one per-vector slice of V for each vector pair), plus V.Bytes().
// Each of the remaining n_block_ortho - 1 passes counts V.Bytes() in place of the B term.
188 long long bytes() const
190 return nVec * B[0]->Bytes() + (nVec - 1) * nVec / 2 * V.Bytes() / nVec + V.Bytes()
191 + (n_block_ortho - 1) * (V.Bytes() + (nVec - 1) * nVec / 2 * V.Bytes() / nVec + V.Bytes());
// NOTE(review): saveOut / saveOutNorm are declared but not used by any code visible here;
// backup/restore around tuning goes through V.backup() / V.restore(). Confirm whether they are still needed.
194 char *saveOut, *saveOutNorm;
// Saves V before tuning (presumably because trial launches overwrite V); undone by postTune().
196 void preTune() { V.backup(); }
// Restores the copy of V saved by preTune().
197 void postTune() { V.restore(); }
// Instantiates BlockOrtho for one fixed (precision, nSpin, spinBlockSize, nColor, nVec)
// combination and constructs it for V / B, with sumType fixed to double.
201 template <typename vFloat, typename bFloat, int nSpin, int spinBlockSize, int nColor, int nVec>
202 void BlockOrthogonalize(ColorSpinorField &V, const std::vector<ColorSpinorField *> &B, const int *fine_to_coarse,
203 const int *coarse_to_fine, const int *geo_bs, const int n_block_ortho)
// geo_blocksize: product of geo_bs over V.Ndim() dimensions.
206 int geo_blocksize = 1;
207 for (int d = 0; d < V.Ndim(); d++) geo_blocksize *= geo_bs[d];
// blocksize and numblocks are used only in the verbose message below.
209 int blocksize = geo_blocksize * V.Ncolor();
210 if (spinBlockSize == 0) { blocksize /= 2; } else { blocksize *= spinBlockSize; }
211 int chiralBlocks = (spinBlockSize == 0) ? 2 : V.Nspin() / spinBlockSize; //always 2 for staggered.
212 int numblocks = (V.Volume()/geo_blocksize) * chiralBlocks;
// Two coarse spin components for nSpin = 4, nSpin = 2 or spinBlockSize = 0; otherwise one.
213 constexpr int coarseSpin = (nSpin == 4 || nSpin == 2 || spinBlockSize == 0) ? 2 : 1;
215 if (getVerbosity() >= QUDA_VERBOSE)
216 printfQuda("Block Orthogonalizing %d blocks of %d length and width %d repeating %d times\n", numblocks, blocksize,
217 nVec, n_block_ortho);
219 V.Scale(1.0); // by definition this is true
// NOTE(review): no call to ortho.apply(...) is visible after construction. Presumably it
// follows on a line not shown here - confirm; otherwise no orthogonalization is performed.
220 BlockOrtho<double, vFloat, bFloat, nSpin, spinBlockSize, nColor, coarseSpin, nVec> ortho(
221 V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
// Runtime-to-compile-time dispatch. Maps V.Nspin(), spin_bs, the color count V.Ncolor()/Nvec
// and Nvec = B.size() onto one of the supported template instantiations. Unsupported
// combinations are reported with errorQuda.
// NOTE(review): Nvec is used as a divisor below, so B must be non-empty.
225 template <typename vFloat, typename bFloat>
226 void BlockOrthogonalize(ColorSpinorField &V, const std::vector<ColorSpinorField *> &B, const int *fine_to_coarse,
227 const int *coarse_to_fine, const int *geo_bs, int spin_bs, int n_block_ortho)
229 const int Nvec = B.size();
// Color count 3 (presumably the fine grid; V is assumed to carry nColor * Nvec colors).
230 if (V.Ncolor()/Nvec == 3) {
231 constexpr int nColor = 3;
// nSpin = 4 requires spin_bs == 2; instantiated for nVec = 6, 24, 32.
233 if (V.Nspin() == 4) {
234 constexpr int nSpin = 4;
235 if (spin_bs != 2) errorQuda("Unexpected spin block size = %d", spin_bs);
236 constexpr int spinBlockSize = 2;
238 if (Nvec == 6) { // for Wilson free field
239 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 6>(V, B, fine_to_coarse, coarse_to_fine,
240 geo_bs, n_block_ortho);
241 } else if (Nvec == 24) {
242 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 24>(V, B, fine_to_coarse, coarse_to_fine,
243 geo_bs, n_block_ortho);
244 } else if (Nvec == 32) {
245 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 32>(V, B, fine_to_coarse, coarse_to_fine,
246 geo_bs, n_block_ortho);
248 errorQuda("Unsupported nVec %d\n", Nvec);
// nSpin = 1 requires spin_bs == 0; instantiated for nVec = 24, 64, 96.
253 if (V.Nspin() == 1) {
254 constexpr int nSpin = 1;
255 if (spin_bs != 0) errorQuda("Unexpected spin block size = %d", spin_bs);
256 constexpr int spinBlockSize = 0;
259 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 24>(V, B, fine_to_coarse, coarse_to_fine,
260 geo_bs, n_block_ortho);
261 } else if (Nvec == 64) {
262 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 64>(V, B, fine_to_coarse, coarse_to_fine,
263 geo_bs, n_block_ortho);
264 } else if (Nvec == 96) {
265 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 96>(V, B, fine_to_coarse, coarse_to_fine,
266 geo_bs, n_block_ortho);
268 errorQuda("Unsupported nVec %d\n", Nvec);
274 errorQuda("Unexpected nSpin = %d", V.Nspin());
// Other color counts (presumably coarse grids): only nSpin = 2 with spin_bs == 1 is accepted.
// Instantiated (nColor, nVec) pairs: (6,6), (24,24), (24,32), (24,64), (24,96), (32,32),
// (64,64), (64,96), (96,96).
278 if (V.Nspin() != 2) errorQuda("Unexpected nSpin = %d", V.Nspin());
279 constexpr int nSpin = 2;
280 if (spin_bs != 1) errorQuda("Unexpected spin block size = %d", spin_bs);
281 constexpr int spinBlockSize = 1;
284 if (V.Ncolor()/Nvec == 6) {
285 constexpr int nColor = 6;
287 BlockOrthogonalize<vFloat, bFloat, nSpin, spinBlockSize, nColor, 6>(V, B, fine_to_coarse, coarse_to_fine,
288 geo_bs, n_block_ortho);
290 errorQuda("Unsupported nVec %d\n", Nvec);
294 if (V.Ncolor()/Nvec == 24) {
295 constexpr int nColor = 24;
297 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,24>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
299 } else if (Nvec == 32) {
300 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,32>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
303 } else if (Nvec == 64) {
304 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,64>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
305 } else if (Nvec == 96) {
306 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,96>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
309 errorQuda("Unsupported nVec %d\n", Nvec);
312 } else if (V.Ncolor()/Nvec == 32) {
313 constexpr int nColor = 32;
315 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,32>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
317 errorQuda("Unsupported nVec %d\n", Nvec);
321 } else if (V.Ncolor()/Nvec == 64) {
322 constexpr int nColor = 64;
324 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,64>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
325 } else if (Nvec == 96) {
326 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,96>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
328 errorQuda("Unsupported nVec %d\n", Nvec);
330 } else if (V.Ncolor()/Nvec == 96) {
331 constexpr int nColor = 96;
333 BlockOrthogonalize<vFloat,bFloat,nSpin,spinBlockSize,nColor,96>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, n_block_ortho);
335 errorQuda("Unsupported nVec %d\n", Nvec);
339 errorQuda("Unsupported nColor %d\n", V.Ncolor()/Nvec);
// Public entry point: dispatches on the (V, B[0]) precision pair to the templated
// BlockOrthogonalize<vFloat, bFloat>. Supported pairs:
//   double/double (requires GPU_MULTIGRID_DOUBLE),
//   single/single,
//   half/single and half/half (require QUDA_PRECISION & 2).
// Any other pair is reported via errorQuda.
344 void BlockOrthogonalize(ColorSpinorField &V, const std::vector<ColorSpinorField *> &B, const int *fine_to_coarse,
345 const int *coarse_to_fine, const int *geo_bs, const int spin_bs, const int n_block_ortho)
// B[0] is dereferenced unconditionally below, so B must be non-empty.
// A null B[0]->V() triggers the staggered-transform warning, whose text says the call is skipped.
348 if (B[0]->V() == nullptr) {
349 warningQuda("Trying to BlockOrthogonalize staggered transform, skipping...");
352 if (V.Precision() == QUDA_DOUBLE_PRECISION && B[0]->Precision() == QUDA_DOUBLE_PRECISION) {
353 #ifdef GPU_MULTIGRID_DOUBLE
// Both template arguments are spelled out: bFloat has no default and does not appear in the
// function parameter list, so it cannot be deduced (matches the other precision branches).
354 BlockOrthogonalize<double, double>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, spin_bs, n_block_ortho);
356 errorQuda("Double precision multigrid has not been enabled");
358 } else if (V.Precision() == QUDA_SINGLE_PRECISION && B[0]->Precision() == QUDA_SINGLE_PRECISION) {
359 BlockOrthogonalize<float, float>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, spin_bs, n_block_ortho);
360 } else if (V.Precision() == QUDA_HALF_PRECISION && B[0]->Precision() == QUDA_SINGLE_PRECISION) {
361 #if QUDA_PRECISION & 2
362 BlockOrthogonalize<short, float>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, spin_bs, n_block_ortho);
364 errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
366 } else if (V.Precision() == QUDA_HALF_PRECISION && B[0]->Precision() == QUDA_HALF_PRECISION) {
367 #if QUDA_PRECISION & 2
368 BlockOrthogonalize<short, short>(V, B, fine_to_coarse, coarse_to_fine, geo_bs, spin_bs, n_block_ortho);
370 errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION);
373 errorQuda("Unsupported precision combination V=%d B=%d\n", V.Precision(), B[0]->Precision());
376 errorQuda("Multigrid has not been built");
377 #endif // GPU_MULTIGRID