60 uint64_t arrayA_size = 0, arrayB_size = 0, arrayC_size = 0;
65 arrayA_size = blas_param->
lda * blas_param->
k;
68 arrayA_size = blas_param->
lda * blas_param->
m;
73 arrayB_size = blas_param->
ldb * blas_param->
n;
76 arrayB_size = blas_param->
ldb * blas_param->
k;
79 arrayC_size = blas_param->
ldc * blas_param->
n;
85 arrayA_size = blas_param->
lda * blas_param->
m;
88 arrayA_size = blas_param->
lda * blas_param->
k;
92 arrayB_size = blas_param->
ldb * blas_param->
k;
95 arrayB_size = blas_param->
ldb * blas_param->
n;
98 arrayC_size = blas_param->
ldc * blas_param->
m;
110 int batches_extra = 0;
112 int batches = blas_param->
batch_count + batches_extra;
114 size_t A_bytes = batches * arrayA_size * re_im * data_size;
115 size_t B_bytes = batches * arrayB_size * re_im * data_size;
116 size_t C_bytes = batches * arrayC_size * re_im * data_size;
118 printfQuda(
"A_Gbtyes = %f, B_Gbtyes = %f, C_Gbtyes = %f\n", 1.0 * A_bytes /
std::pow(1024, 3),
128 qudaMemcpy(A_d, arrayA, A_bytes, cudaMemcpyHostToDevice);
129 qudaMemcpy(B_d, arrayB, B_bytes, cudaMemcpyHostToDevice);
130 qudaMemcpy(C_d, arrayC, C_bytes, cudaMemcpyHostToDevice);
144 qudaMemcpy(arrayC, C_d, C_bytes, cudaMemcpyDeviceToHost);
TimeProfile & getProfileBLAS()
Profiler for covariant derivative.
void blasGEMMQuda(void *arrayA, void *arrayB, void *arrayC, QudaBoolean use_native, QudaBLASParam *blas_param)
Strided Batched GEMM.
void checkBLASParam(QudaBLASParam &param)
@ QUDA_CUDA_FIELD_LOCATION
@ QUDA_CPU_FIELD_LOCATION
enum QudaBoolean_s QudaBoolean
@ QUDA_BLAS_DATAORDER_COL
#define pool_device_malloc(size)
#define pool_device_free(ptr)
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
long long stridedBatchGEMM(void *A, void *B, void *C, QudaBLASParam blas_param, QudaFieldLocation location)
Strided Batch GEMM. This function performs N GEMM type operations in a strided batched fashion....
void saveTuneCache(bool error=false)
__host__ __device__ ValueType pow(ValueType x, ExponentType e)
Main header file for the QUDA library.
#define qudaMemcpy(dst, src, count, kind)
QudaBLASDataOrder data_order
QudaBLASOperation trans_a
QudaBLASDataType data_type
QudaBLASOperation trans_b
QudaVerbosity getVerbosity()