xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma.h (revision f71aa81bd7d2e9c6555cba4570cf145ac8d1aa26)
1 // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2 // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3 // reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 // magma functions specific to ceed
18 #ifndef _ceed_magma_h
19 #define _ceed_magma_h
20 
21 #include <ceed/ceed.h>
22 #include <ceed/backend.h>
23 #include <magma_v2.h>
24 
// Run-time selector for how basis kernels are dispatched: a single generic
// kernel that handles any dimension, or a kernel specialized for a specific
// dimension (see magma_interp/magma_grad/magma_weight below, which take a
// magma_kernel_mode_t argument). The distinct numeric values (101/102) keep
// the constants recognizable in debuggers and logs.
typedef enum {
  MAGMA_KERNEL_DIM_GENERIC=101,
  MAGMA_KERNEL_DIM_SPECIFIC=102
} magma_kernel_mode_t;
29 
// Backend-wide context for the MAGMA backend (stored on the Ceed object).
typedef struct {
  magma_kernel_mode_t basis_kernel_mode;  // generic vs dimension-specific basis kernels
  magma_device_t device;                  // MAGMA device handle
  magma_queue_t queue;                    // MAGMA queue used to launch the kernels declared below
} Ceed_Magma;
35 
// Per-basis data for tensor-product H1 bases. The `d` prefix follows the
// MAGMA convention for device pointers (cf. magma_isdevptr below) —
// presumably device copies of the 1D quadrature/interp/grad/weight arrays
// passed to CeedBasisCreateTensorH1_Magma; confirm in the .c implementation.
typedef struct {
  CeedScalar *dqref1d;      // 1D quadrature point locations
  CeedScalar *dinterp1d;    // 1D interpolation matrix
  CeedScalar *dgrad1d;      // 1D gradient matrix
  CeedScalar *dqweight1d;   // 1D quadrature weights
} CeedBasis_Magma;
42 
// Per-basis data for non-tensor H1 bases; same role as CeedBasis_Magma but
// with full (dim-dimensional) matrices instead of 1D factors — see
// CeedBasisCreateH1_Magma below.
typedef struct {
  CeedScalar *dqref;      // quadrature point locations
  CeedScalar *dinterp;    // interpolation matrix
  CeedScalar *dgrad;      // gradient matrix
  CeedScalar *dqweight;   // quadrature weights
} CeedBasisNonTensor_Magma;
49 
// Per-restriction data: element-to-dof offset maps. NOTE(review): `offsets`
// appears to be the host copy and `doffsets` the device copy (d prefix);
// confirm against the restriction implementation.
typedef struct {
  CeedInt *offsets;     // offsets (host)
  CeedInt *doffsets;    // offsets (device)
  int  own_;            // nonzero when this struct owns the host memory
  int down_;            // cover a case where we own Device memory
} CeedElemRestriction_Magma;
56 
// Per-QFunction data: field pointer arrays handed to the user QFunction.
typedef struct {
  const CeedScalar **inputs;   // input field pointers
  CeedScalar **outputs;        // output field pointers
  bool setupdone;              // set once setup has completed, so it is not repeated
} CeedQFunction_Magma;
62 
// Compile-time switches enabling MAGMA "batch" code paths. NOTE(review):
// unconditionally defined here; the numbered variants presumably gate
// independent kernel groups in the .c/.cu sources — confirm there before
// removing any of them.
#define USE_MAGMA_BATCH
#define USE_MAGMA_BATCH2
#define USE_MAGMA_BATCH3
#define USE_MAGMA_BATCH4
67 
68 #ifdef __cplusplus
69 CEED_INTERN {
70 #endif
71 
  // Basis interpolation kernels. Parameter convention (inferred from the
  // signatures; confirm in the kernel sources): P/Q are the 1D basis and
  // quadrature sizes, ncomp the number of field components, dT the basis
  // matrix (device), tmode whether T or T^T is applied, dU/dV the input/
  // output field data (device) with estrd*/cstrd* the element and component
  // strides, and nelem the number of elements in the batch.

  // Dimension-specific interpolation, 1D elements.
  magma_int_t magma_interp_1d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Dimension-specific interpolation, 2D elements.
  magma_int_t magma_interp_2d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Dimension-specific interpolation, 3D elements.
  magma_int_t magma_interp_3d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Generic interpolation: handles any dim at run time.
  magma_int_t magma_interp_generic(magma_int_t P, magma_int_t Q,
                                   magma_int_t dim, magma_int_t ncomp,
                                   const CeedScalar *dT, CeedTransposeMode tmode,
                                   const CeedScalar *dU, magma_int_t u_elemstride,
                                   magma_int_t cstrdU,
                                   CeedScalar *dV, magma_int_t v_elemstride,
                                   magma_int_t cstrdV,
                                   magma_int_t nelem, magma_queue_t queue);

  // Dispatcher: routes to the generic or dimension-specific kernels above
  // according to kernel_mode (see magma_kernel_mode_t).
  magma_int_t magma_interp(
    magma_int_t P, magma_int_t Q,
    magma_int_t dim, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
109 
  // Basis gradient kernels. Same parameter convention as the interpolation
  // kernels above, plus dstrd* for the stride between derivative directions.
  // NOTE(review): the gradn_*/gradt_* pairs appear to be the no-transpose and
  // transpose variants (1D needs no split since tmode is passed directly) —
  // confirm against their definitions.

  magma_int_t magma_grad_1d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dTinterp, const CeedScalar *dTgrad, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  magma_int_t magma_gradn_2d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  magma_int_t magma_gradt_2d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  magma_int_t magma_gradn_3d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  magma_int_t magma_gradt_3d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Generic gradient: handles any dim at run time.
  magma_int_t magma_grad_generic(
    magma_int_t P, magma_int_t Q, magma_int_t dim, magma_int_t ncomp,
    const CeedScalar* dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Dispatcher: routes to the generic or dimension-specific kernels above
  // according to kernel_mode (see magma_kernel_mode_t).
  magma_int_t magma_grad(
    magma_int_t P, magma_int_t Q, magma_int_t dim, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t u_elemstride, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t v_elemstride, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
158 
  // Quadrature-weight kernels: write the quadrature weights (built from the
  // 1D weights dqweight1d) into dV for each of the nelem elements, v_stride
  // apart. Q is the number of 1D quadrature points.

  magma_int_t magma_weight_1d(
    magma_int_t Q, const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_queue_t queue);

  magma_int_t magma_weight_2d(
    magma_int_t Q, const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_queue_t queue);

  magma_int_t magma_weight_3d(
    magma_int_t Q, const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_queue_t queue);

  // Generic weight kernel: handles any dim at run time.
  magma_int_t magma_weight_generic(
    magma_int_t Q, magma_int_t dim,
    const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t vstride,
    magma_int_t nelem, magma_queue_t queue);

  // Dispatcher: routes to the generic or dimension-specific kernels above
  // according to kernel_mode (see magma_kernel_mode_t).
  magma_int_t magma_weight(
    magma_int_t Q, magma_int_t dim,
    const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);

  // Non-tensor weight kernel; grid/threads are the launch configuration.
  // NOTE(review): returns void unlike the tensor variants — callers cannot
  // observe failure here.
  void magma_weight_nontensor(magma_int_t grid, magma_int_t threads, magma_int_t nelem,
                              magma_int_t Q,
                              CeedScalar *dqweight, CeedScalar *dv, magma_queue_t queue);
189 
  // Element-restriction gather (read) / scatter (write) kernels. NCOMP is the
  // number of components, compstride the stride between components, esize the
  // per-element size, nelem the element count; du is the source and dv the
  // destination. The *Offset variants use an indirection array `offsets`,
  // the *Strided variants a 3-entry (presumably node/comp/elem — confirm)
  // `strides` array.

  void magma_readDofsOffset(const magma_int_t NCOMP,
                            const magma_int_t compstride,
                            const magma_int_t esize, const magma_int_t nelem,
                            magma_int_t *offsets, const CeedScalar *du, CeedScalar *dv,
                            magma_queue_t queue);

  void magma_readDofsStrided(const magma_int_t NCOMP, const magma_int_t esize,
                             const magma_int_t nelem, magma_int_t *strides,
                             const CeedScalar *du, CeedScalar *dv,
                             magma_queue_t queue);

  void magma_writeDofsOffset(const magma_int_t NCOMP,
                             const magma_int_t compstride,
                             const magma_int_t esize, const magma_int_t nelem,
                             magma_int_t *offsets,const CeedScalar *du, CeedScalar *dv,
                             magma_queue_t queue);

  void magma_writeDofsStrided(const magma_int_t NCOMP, const magma_int_t esize,
                              const magma_int_t nelem, magma_int_t *strides,
                              const CeedScalar *du, CeedScalar *dv,
                              magma_queue_t queue);
211 
  // GEMM wrappers for non-tensor basis application, one per precision.
  // Conventional BLAS parameters: C = alpha*op(A)*op(B) + beta*C with
  // leading dimensions ldda/lddb/lddc; dA/dB/dC are device matrices.

  int magma_dgemm_nontensor(
    magma_trans_t transA, magma_trans_t transB,
    magma_int_t m, magma_int_t n, magma_int_t k,
    double alpha, const double *dA, magma_int_t ldda,
    const double *dB, magma_int_t lddb,
    double beta,  double *dC, magma_int_t lddc,
    magma_queue_t queue );

  int magma_sgemm_nontensor(
    magma_trans_t transA, magma_trans_t transB,
    magma_int_t m, magma_int_t n, magma_int_t k,
    float alpha, const float *dA, magma_int_t ldda,
    const float *dB, magma_int_t lddb,
    float beta,  float *dC, magma_int_t lddc,
    magma_queue_t queue );

  // Tests whether A is a device pointer (see the magma_is_devptr #define
  // near the end of this header, which substitutes this implementation for
  // MAGMA's default).
  magma_int_t
  magma_isdevptr(const void *A);
230 
  // Backend constructors registered with the CEED interface. Each populates
  // the corresponding backend data struct defined above and returns a CEED
  // error code (0 on success, per libCEED convention).

  // Tensor-product H1 basis: 1D interp/grad matrices plus quadrature data.
  int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d,
                                    CeedInt Q1d,
                                    const CeedScalar *interp1d,
                                    const CeedScalar *grad1d,
                                    const CeedScalar *qref1d,
                                    const CeedScalar *qweight1d,
                                    CeedBasis basis);

  // Non-tensor H1 basis: full dim-dimensional matrices.
  int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim,
                              CeedInt ndof, CeedInt nqpts,
                              const CeedScalar *interp,
                              const CeedScalar *grad,
                              const CeedScalar *qref,
                              const CeedScalar *qweight,
                              CeedBasis basis);

  // Element restriction from an offsets array (mtype: where offsets live;
  // cmode: copy/own/use semantics for it).
  int CeedElemRestrictionCreate_Magma(CeedMemType mtype,
                                      CeedCopyMode cmode,
                                      const CeedInt *offsets,
                                      CeedElemRestriction r);

  int CeedElemRestrictionCreateBlocked_Magma(const CeedMemType mtype,
      const CeedCopyMode cmode,
      const CeedInt *offsets,
      const CeedElemRestriction res);

  int CeedOperatorCreate_Magma(CeedOperator op);
258 
259   #ifdef __cplusplus
260 }
261   #endif
262 
// Route magma_is_devptr to this backend's magma_isdevptr (declared above);
// comment the line below to fall back to MAGMA's default implementation.
#define magma_is_devptr magma_isdevptr

// If MAGMA and the cuda/ref backends both use the null stream, no explicit
// synchronization is needed, so ceed_magma_queue_sync expands to nothing.
#define ceed_magma_queue_sync(...)

// Batch stride; override at compile time with -DMAGMA_BATCH_STRIDE=<value>.
#ifndef MAGMA_BATCH_STRIDE
#define MAGMA_BATCH_STRIDE (1000)
#endif
274 
275 #endif  // _ceed_magma_h
276