QUDA  1.0.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
blas_reference.cpp
Go to the documentation of this file.
1 #include <blas_reference.h>
2 #include <stdio.h>
3 #include <comm_quda.h>
4 
// Element-wise update y[k] = y[k] + a * x[k] for k in [0, len).
//   a   : scalar multiplier applied to x
//   x   : input vector (read only)
//   y   : vector updated in place
//   len : number of elements; a non-positive value leaves y untouched
template <typename Float>
inline void aXpY(Float a, Float *x, Float *y, int len)
{
  for (int k = 0; k < len; ++k) {
    y[k] += a * x[k];
  }
}
10 
11 void axpy(double a, void *x, void *y, int len, QudaPrecision precision) {
12  if( precision == QUDA_DOUBLE_PRECISION ) aXpY(a, (double *)x, (double *)y, len);
13  else aXpY((float)a, (float *)x, (float *)y, len);
14 }
15 
16 // performs the operation x[i] *= a
17 template <typename Float>
18 inline void aX(Float a, Float *x, int len) {
19  for (int i=0; i<len; i++) x[i] *= a;
20 }
21 
22 void ax(double a, void *x, int len, QudaPrecision precision) {
23  if (precision == QUDA_DOUBLE_PRECISION) aX(a, (double*)x, len);
24  else aX((float)a, (float*)x, len);
25 }
26 
// Subtracts x from y in place ("minus x plus y"): y[k] = y[k] - x[k].
//   x   : subtrahend vector (read only)
//   y   : vector updated in place
//   len : element count; non-positive means no-op
template <typename Float>
inline void mXpY(Float *x, Float *y, int len) {
  for (int k = 0; k < len; ++k) {
    y[k] = y[k] - x[k];
  }
}
32 
33 void mxpy(void* x, void* y, int len, QudaPrecision precision) {
34  if (precision == QUDA_DOUBLE_PRECISION) mXpY((double*)x, (double*)y, len);
35  else mXpY((float*)x, (float*)y, len);
36 }
37 
38 
39 // returns the square of the L2 norm of the vector
40 template <typename Float>
41 inline double norm2(Float *v, int len) {
42  double sum=0.0;
43  for (int i=0; i<len; i++) sum += v[i]*v[i];
44  comm_allreduce(&sum);
45  return sum;
46 }
47 
48 double norm_2(void *v, int len, QudaPrecision precision) {
49  if (precision == QUDA_DOUBLE_PRECISION) return norm2((double*)v, len);
50  else return norm2((float*)v, len);
51 }
52 
// Element-wise y[k] = x[k] + a * y[k] for k in [0, len).
// Also instantiated with GNU _Complex element types by cxpay in this file.
//   x   : additive vector (read only)
//   a   : scalar multiplying the old value of y
//   y   : vector overwritten with the result
//   len : element count; non-positive means no-op
template <typename Float>
static inline void xpay(Float *x, Float a, Float *y, int len) {
  for (int k = 0; k < len; ++k) {
    y[k] = x[k] + a * y[k];
  }
}
58 
59 void xpay(void *x, double a, void *y, int length, QudaPrecision precision) {
60  if (precision == QUDA_DOUBLE_PRECISION) xpay((double*)x, a, (double*)y, length);
61  else xpay((float*)x, (float)a, (float*)y, length);
62 }
63 
64 void cxpay(void *x, double _Complex a, void *y, int length, QudaPrecision precision)
65 {
66  if (precision == QUDA_DOUBLE_PRECISION) {
67  xpay((double _Complex *)x, (double _Complex)a, (double _Complex *)y, length / 2);
68  } else {
69  xpay((float _Complex *)x, (float _Complex)a, (float _Complex *)y, length / 2);
70  }
71 }
static void sum(Float *dst, Float *a, Float *b, int cnt)
Definition: dslash_util.h:8
enum QudaPrecision_s QudaPrecision
int length[]
void ax(double a, void *x, int len, QudaPrecision precision)
void aXpY(Float a, Float *x, Float *y, int len)
void mXpY(Float *x, Float *y, int len)
static void xpay(Float *x, Float a, Float *y, int len)
void cxpay(void *x, double _Complex a, void *y, int length, QudaPrecision precision)
double norm2(Float *v, int len)
void mxpy(void *x, void *y, int len, QudaPrecision precision)
double norm_2(void *v, int len, QudaPrecision precision)
void aX(Float a, Float *x, int len)
void axpy(double a, void *x, void *y, int len, QudaPrecision precision)
void comm_allreduce(double *data)
Definition: comm_mpi.cpp:242