// Yields the larger of a and b. This is a textual macro: whichever argument
// is selected is evaluated a second time, so do not pass expressions with
// side effects (e.g. i++ or function calls with state).
26 #define MAX(a,b) ((a)>(b)?(a):(b))
// Per-site size of a staggered spinor. NOTE(review): presumably counts
// floating-point values (3 color components x real/imag) -- confirm against
// the code that allocates/indexes spinors.
27 #define staggeredSpinorSiteSize 6
// Declared extern: the definition lives in another translation unit.
// Presumably prints command-line help for the given argv -- confirm.
30 extern void usage(
char** argv );
// File-scope const void** handles. NOTE(review): names suggest ghost-zone
// data for the fat and long links; allocation and ownership are handled
// outside this view -- confirm before relying on their lifetime.
51 const void **ghost_fatlink, **ghost_longlink;
113 int tmpint =
MAX(
X[1]*
X[2]*
X[3], X[0]*X[2]*X[3]);
114 tmpint =
MAX(tmpint, X[0]*X[1]*X[3]);
115 tmpint =
MAX(tmpint, X[0]*X[1]*X[2]);
125 for(
int d = 0; d < 4; d++) {
162 errorQuda(
"ERROR: malloc failed for fatlink/longlink");
180 int x_face_size = X[1]*X[2]*X[3]/2;
181 int y_face_size = X[0]*X[2]*X[3]/2;
182 int z_face_size = X[0]*X[1]*X[3]/2;
183 int t_face_size = X[0]*X[1]*X[2]/2;
184 int pad_size =
MAX(x_face_size, y_face_size);
185 pad_size =
MAX(pad_size, z_face_size);
186 pad_size =
MAX(pad_size, t_face_size);
231 cudaDeviceSynchronize();
236 printfQuda(
"Source CPU = %f, CUDA=%f\n", spinor_norm2, cuda_spinor_norm2);
286 cudaEvent_t start,
end;
287 cudaEventCreate(&start);
288 cudaEventRecord(start, 0);
289 cudaEventSynchronize(start);
291 for (
int i = 0; i <
niter; i++) {
318 cudaEventCreate(&end);
319 cudaEventRecord(end, 0);
320 cudaEventSynchronize(end);
322 cudaEventElapsedTime(&runTime, start, end);
323 cudaEventDestroy(start);
324 cudaEventDestroy(end);
326 double secs = runTime / 1000;
329 cudaError_t stat = cudaGetLastError();
330 if (stat != cudaSuccess)
331 errorQuda(
"with ERROR: %s\n", cudaGetErrorString(stat));
343 printfQuda(
"Calculating reference implementation...");
383 static int dslashTest()
385 int accuracy_level = 0;
401 #ifdef DSLASH_PROFILING
402 printDslashProfile();
412 int spinor_floats = 8*6*2 + 6;
413 int link_float_size =
prec;
414 int spinor_float_size = 0;
416 link_floats =
test_type ? (2*link_floats) : link_floats;
417 spinor_floats =
test_type ? (2*spinor_floats) : spinor_floats;
419 int bytes_for_one_site = link_floats * link_float_size + spinor_floats * spinor_float_size;
422 printfQuda(
"GFLOPS = %f\n", 1.0e-9*flops/secs);
429 printfQuda(
"Results: CPU=%f, CUDA=%f, CPU-CUDA=%f\n", spinor_ref_norm2, cuda_spinor_out_norm2,
434 printfQuda(
"Result: CPU=%f , CPU-CUDA=%f", spinor_ref_norm2, spinor_out_norm2);
441 return accuracy_level;
449 printfQuda(
"prec recon test_type dagger S_dim T_dimension\n");
475 int main(
int argc,
char **argv)
479 for (i =1;i < argc; i++){
485 fprintf(stderr,
"ERROR: Invalid option:%s\n", argv[i]);
494 int accuracy_level = dslashTest();
496 printfQuda(
"accuracy_level =%d\n", accuracy_level);
498 if (accuracy_level >= 1) ret = 0;