QUDA  v1.1.0
A library for QCD on GPUs
blas_interface.cpp
Go to the documentation of this file.
1 #include <quda.h>
2 #include <blas_lapack.h>
3 #include <tune_quda.h>
4 
5 using namespace quda;
6 
7 // Forward declarations for profiling and parameter checking
8 // The helper functions are defined in interface_quda.cpp
11 
12 void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean use_native, QudaBLASParam *blas_param)
13 {
15  checkBLASParam(*blas_param);
16 
17  // cuBLAS works exclusively in column major order. If the input data is in
18  // row major order, we may treat the A and B and C arrays as A^T, B^T, and C^T.
19  // We swap the order of the A * B multiplication and swap the
20  // operation types and other data to recover the the desired result in the
21  // desired order.
22  // E.g: in row major, the operation,
23  // C = a * A^T * B + b * C
24  //
25  // will become the column major operation
26  // C^T = a * B^T * A + b * C^T
27  //
28  // By inspection, one can see that transposition of the above column major
29  // operation will result in the desired row major answer:
30  //
31  // (C^T)^T = a * (B^T * A)^T + b * (C^T)^T
32  // --> C = a * A^T * B + b * C
33  //
34  // We must also swap around some parameters. The Row major indices,
35  // A_{m, lda}, B_{k, ldb}, C_{m, ldc}
36  // become
37  // A^T_{lda, m}, B^T_{ldb, k}, C^T_{ldc, m}.
38  // so the leading dimensions remain the same. However, we must change the actual
39  // matrix dims m,n,k to reflect the change to column major.
40  // m_{col} = n_{row}
41  // n_{col} = m_{row}
42  // k_{col} = k_{row}
43  // And because we are swapping the A and B arrays, we must also swap their
44  // leading dim values and any offsets. All this is done behind the scenes in the
45  // BatchGEMM function, and before function exit all pointers and values are
46  // restored to the values they had on entry.
47 
50  blas_lapack::generic::stridedBatchGEMM(arrayA, arrayB, arrayC, *blas_param, QUDA_CPU_FIELD_LOCATION);
52  } else {
54 
55  // The data in the arrays is on the host. We transfer the data to the device here
56  // for timing purposes. One can pass host pointers to the BatchGEMM function
57  // and it will handle the data movement for the user.
58 
59  // Extract data from the param struct for device malloc
60  uint64_t arrayA_size = 0, arrayB_size = 0, arrayC_size = 0;
61  if (blas_param->data_order == QUDA_BLAS_DATAORDER_COL) {
62  // leading dimension is in terms of consecutive data
63  // elements in a column, multiplied by number of rows
64  if (blas_param->trans_a == QUDA_BLAS_OP_N) {
65  arrayA_size = blas_param->lda * blas_param->k; // A_mk
66  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->lda, blas_param->k);
67  } else {
68  arrayA_size = blas_param->lda * blas_param->m; // A_km
69  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->lda, blas_param->m);
70  }
71 
72  if (blas_param->trans_b == QUDA_BLAS_OP_N) {
73  arrayB_size = blas_param->ldb * blas_param->n; // B_kn
74  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->ldb, blas_param->n);
75  } else {
76  arrayB_size = blas_param->ldb * blas_param->k; // B_nk
77  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->ldb, blas_param->k);
78  }
79  arrayC_size = blas_param->ldc * blas_param->n; // C_mn
80  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array C_{%d, %d}\n", blas_param->ldc, blas_param->n);
81  } else {
82  // leading dimension is in terms of consecutive data
83  // elements in a row, multiplied by number of columns.
84  if (blas_param->trans_a == QUDA_BLAS_OP_N) {
85  arrayA_size = blas_param->lda * blas_param->m; // A_mk
86  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->m, blas_param->lda);
87  } else {
88  arrayA_size = blas_param->lda * blas_param->k; // A_km
89  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array A_{%d, %d}\n", blas_param->k, blas_param->lda);
90  }
91  if (blas_param->trans_b == QUDA_BLAS_OP_N) {
92  arrayB_size = blas_param->ldb * blas_param->k; // B_nk
93  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->k, blas_param->ldb);
94  } else {
95  arrayB_size = blas_param->ldb * blas_param->n; // B_kn
96  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array B_{%d, %d}\n", blas_param->n, blas_param->ldb);
97  }
98  arrayC_size = blas_param->ldc * blas_param->m; // C_mn
99  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("array C_{%d, %d}\n", blas_param->m, blas_param->ldc);
100  }
101 
102  size_t data_size = (blas_param->data_type == QUDA_BLAS_DATATYPE_D || blas_param->data_type == QUDA_BLAS_DATATYPE_Z) ?
103  sizeof(double) :
104  sizeof(float);
105  int re_im = 1;
106  if (blas_param->data_type == QUDA_BLAS_DATATYPE_C || blas_param->data_type == QUDA_BLAS_DATATYPE_Z) { re_im *= 2; }
107 
108  // If the user passes non-zero offsets, add one extra
109  // matrix to the device array to accomodate it.
110  int batches_extra = 0;
111  if (blas_param->a_offset + blas_param->b_offset + blas_param->c_offset > 0) { batches_extra++; }
112  int batches = blas_param->batch_count + batches_extra;
113 
114  size_t A_bytes = batches * arrayA_size * re_im * data_size;
115  size_t B_bytes = batches * arrayB_size * re_im * data_size;
116  size_t C_bytes = batches * arrayC_size * re_im * data_size;
117  if (getVerbosity() >= QUDA_VERBOSE)
118  printfQuda("A_Gbtyes = %f, B_Gbtyes = %f, C_Gbtyes = %f\n", 1.0 * A_bytes / std::pow(1024, 3),
119  1.0 * B_bytes / std::pow(1024, 3), 1.0 * C_bytes / std::pow(1024, 3));
120  void *A_d = pool_device_malloc(A_bytes);
121  void *B_d = pool_device_malloc(B_bytes);
122  void *C_d = pool_device_malloc(C_bytes);
123  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("QUDA: arrays allocated sucessfully.\n");
125 
126  // Transfer host data to device
127  getProfileBLAS().TPSTART(QUDA_PROFILE_H2D);
128  qudaMemcpy(A_d, arrayA, A_bytes, cudaMemcpyHostToDevice);
129  qudaMemcpy(B_d, arrayB, B_bytes, cudaMemcpyHostToDevice);
130  qudaMemcpy(C_d, arrayC, C_bytes, cudaMemcpyHostToDevice);
131  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("QUDA: arrays copied susessfully.\n");
133 
134  // Compute Batched GEMM
136 
138 
139  if (getVerbosity() >= QUDA_VERBOSE) printfQuda("BatchGEMM success!\n");
141 
142  // Copy device C array back to host
143  getProfileBLAS().TPSTART(QUDA_PROFILE_D2H);
144  qudaMemcpy(arrayC, C_d, C_bytes, cudaMemcpyDeviceToHost);
146 
147  // Clean up
149  pool_device_free(A_d);
150  pool_device_free(B_d);
151  pool_device_free(C_d);
153  }
154 
156  saveTuneCache();
157 }
TimeProfile & getProfileBLAS()
Profiler for covariant derivative.
void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean use_native, QudaBLASParam *blas_param)
Strided Batched GEMM.
void checkBLASParam(QudaBLASParam &param)
@ QUDA_CUDA_FIELD_LOCATION
Definition: enum_quda.h:326
@ QUDA_CPU_FIELD_LOCATION
Definition: enum_quda.h:325
@ QUDA_VERBOSE
Definition: enum_quda.h:267
@ QUDA_BLAS_DATATYPE_Z
Definition: enum_quda.h:480
@ QUDA_BLAS_DATATYPE_D
Definition: enum_quda.h:478
@ QUDA_BLAS_DATATYPE_C
Definition: enum_quda.h:479
@ QUDA_BOOLEAN_FALSE
Definition: enum_quda.h:460
enum QudaBoolean_s QudaBoolean
@ QUDA_BLAS_DATAORDER_COL
Definition: enum_quda.h:486
@ QUDA_BLAS_OP_N
Definition: enum_quda.h:470
#define pool_device_malloc(size)
Definition: malloc_quda.h:170
#define pool_device_free(ptr)
Definition: malloc_quda.h:171
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
void saveTuneCache(bool error=false)
Definition: tune.cpp:439
@ QUDA_PROFILE_INIT
Definition: timer.h:106
@ QUDA_PROFILE_COMPUTE
Definition: timer.h:108
@ QUDA_PROFILE_TOTAL
Definition: timer.h:149
@ QUDA_PROFILE_FREE
Definition: timer.h:111
@ QUDA_PROFILE_H2D
Definition: timer.h:104
@ QUDA_PROFILE_D2H
Definition: timer.h:105
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Definition: complex_quda.h:111
QudaGaugeParam param
Definition: pack_test.cpp:18
Main header file for the QUDA library.
#define qudaMemcpy(dst, src, count, kind)
Definition: quda_api.h:204
int c_offset
Definition: quda.h:761
QudaBLASDataOrder data_order
Definition: quda.h:772
int b_offset
Definition: quda.h:760
QudaBLASOperation trans_a
Definition: quda.h:751
QudaBLASDataType data_type
Definition: quda.h:771
int a_offset
Definition: quda.h:759
int batch_count
Definition: quda.h:769
QudaBLASOperation trans_b
Definition: quda.h:752
#define printfQuda(...)
Definition: util_quda.h:114
QudaVerbosity getVerbosity()
Definition: util_quda.cpp:21