xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma.h (revision e5f091eb2082fd2e5a436aed5b3c40dee25ac3c3)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 // magma functions specific to ceed
// NOTE(review): identifiers beginning with an underscore are reserved at file
// scope (C11 7.1.3 / CERT DCL37-C); consider renaming the guard to CEED_MAGMA_H.
#ifndef _ceed_magma_h
#define _ceed_magma_h
11 
12 #include <ceed/ceed.h>
13 #include <ceed/backend.h>
14 #include <magma_v2.h>
15 
// Maximum threads per block used by the basis kernels, per element dimension.
#define MAGMA_MAXTHREADS_1D 128
#define MAGMA_MAXTHREADS_2D 128
#define MAGMA_MAXTHREADS_3D 64
// Define macro for determining number of threads in y-direction
// for basis kernels (1 when x alone exceeds the thread budget maxt)
#define MAGMA_BASIS_NTCOL(x, maxt) (((maxt) < (x)) ? 1 : ((maxt) / (x)))
// Define macro for computing the total threads in a block
// for use with __launch_bounds__()
// FIX: (x) must be parenthesized so expression arguments (e.g. P+1) are not
// mis-grouped by operator precedence in the expansion.
#define MAGMA_BASIS_BOUNDS(x, maxt) ((x) * MAGMA_BASIS_NTCOL(x, maxt))
25 
// Map the backend's RTC (runtime compilation) helper names onto either the
// HIP or the CUDA implementations, selected at build time via CEED_MAGMA_USE_HIP.
#ifdef CEED_MAGMA_USE_HIP
// HIP module/kernel handle types and the Ceed HIP JIT helpers
#define MAGMA_RTC_MODULE hipModule_t
#define MAGMA_RTC_FUNCTION hipFunction_t
#define MAGMA_RTC_COMPILE CeedCompileHip
#define MAGMA_RTC_GET_KERNEL CeedGetKernelHip
#define MAGMA_RTC_RUN_KERNEL CeedRunKernelHip
#define MAGMA_RTC_RUN_KERNEL_DIM CeedRunKernelDimHip
#define MAGMA_RTC_RUN_KERNEL_DIM_SH CeedRunKernelDimSharedHip
#else
// CUDA driver-API handle types and the Ceed CUDA JIT helpers
#define MAGMA_RTC_MODULE CUmodule
#define MAGMA_RTC_FUNCTION CUfunction
#define MAGMA_RTC_COMPILE CeedCompileCuda
#define MAGMA_RTC_GET_KERNEL CeedGetKernelCuda
#define MAGMA_RTC_RUN_KERNEL CeedRunKernelCuda
#define MAGMA_RTC_RUN_KERNEL_DIM CeedRunKernelDimCuda
#define MAGMA_RTC_RUN_KERNEL_DIM_SH CeedRunKernelDimSharedCuda
#endif
43 
// Selects which flavor of basis kernel the backend launches: a generic kernel
// that handles any dimension, or one specialized per dimension (see the
// magma_interp/magma_grad/magma_weight dispatch entry points below).
typedef enum {
  MAGMA_KERNEL_DIM_GENERIC=101,   // dimension-independent kernel
  MAGMA_KERNEL_DIM_SPECIFIC=102   // kernel specialized for the element dim
} magma_kernel_mode_t;
48 
// Per-Ceed backend context for the MAGMA backend.
typedef struct {
  magma_kernel_mode_t basis_kernel_mode;  // generic vs dim-specific basis kernels
  magma_device_t device;                  // MAGMA device handle
  magma_queue_t queue;                    // MAGMA execution queue (stream)
} Ceed_Magma;
54 
// Tensor-product basis data: the JIT-compiled module, its kernels, and the
// 1D basis arrays.  The d- prefix suggests device-resident pointers (cf.
// magma_isdevptr below) -- TODO confirm against the basis setup code.
typedef struct {
  MAGMA_RTC_MODULE module;             // compiled RTC module holding the kernels below
  MAGMA_RTC_FUNCTION magma_interp;     // interpolation kernel
  MAGMA_RTC_FUNCTION magma_interp_tr;  // transposed interpolation kernel
  MAGMA_RTC_FUNCTION magma_grad;       // gradient kernel
  MAGMA_RTC_FUNCTION magma_grad_tr;    // transposed gradient kernel
  MAGMA_RTC_FUNCTION magma_weight;     // quadrature-weight kernel
  CeedScalar *dqref1d;                 // 1D reference quadrature points
  CeedScalar *dinterp1d;               // 1D interpolation matrix
  CeedScalar *dgrad1d;                 // 1D gradient matrix
  CeedScalar *dqweight1d;              // 1D quadrature weights
} CeedBasis_Magma;
67 
// Non-tensor basis data: full (non-factored) basis arrays; d- prefix suggests
// device-resident pointers -- TODO confirm against the basis setup code.
typedef struct {
  CeedScalar *dqref;      // reference quadrature points
  CeedScalar *dinterp;    // interpolation matrix
  CeedScalar *dgrad;      // gradient matrix
  CeedScalar *dqweight;   // quadrature weights
} CeedBasisNonTensor_Magma;
74 
// Who owns a buffer and how it was allocated (used by the element
// restriction below to decide whether/how to free memory).
typedef enum {
  OWNED_NONE = 0,   // memory borrowed from the caller; do not free
  OWNED_UNPINNED,   // owned, ordinary (pageable) host allocation
  OWNED_PINNED,     // owned, pinned (page-locked) host allocation
} OwnershipMode;
80 
// Element-restriction data: JIT module with the four gather/scatter kernels
// (strided vs offset-indexed, each with and without transpose) plus the
// offset arrays and their ownership flags.
typedef struct {
  MAGMA_RTC_MODULE module;
  MAGMA_RTC_FUNCTION StridedTranspose;
  MAGMA_RTC_FUNCTION StridedNoTranspose;
  MAGMA_RTC_FUNCTION OffsetTranspose;
  MAGMA_RTC_FUNCTION OffsetNoTranspose;
  CeedInt *offsets;     // offsets array (host copy, by naming convention -- confirm)
  CeedInt *doffsets;    // offsets array (device copy, by naming convention -- confirm)
  OwnershipMode own_;   // ownership of the host offsets
  int down_;            // cover a case where we own Device memory
} CeedElemRestriction_Magma;
92 
// QFunction data: pointer arrays for the input/output field vectors and a
// flag recording whether one-time setup has been performed.
typedef struct {
  const CeedScalar **inputs;    // per-field input data pointers
  CeedScalar **outputs;         // per-field output data pointers
  bool setupdone;               // true once setup has run
} CeedQFunction_Magma;
98 
// Feature switches enabling the batched MAGMA code paths; presumably tested
// with #ifdef in the backend sources (the macros carry no value).
#define USE_MAGMA_BATCH
#define USE_MAGMA_BATCH2
#define USE_MAGMA_BATCH3
#define USE_MAGMA_BATCH4
103 
104 #ifdef __cplusplus
105 CEED_INTERN {
106 #endif
107 
  // Tensor interpolation on 1D elements: applies the 1D basis matrix dT to
  // each of ncomp components of nelem elements (direction reversed when
  // tmode is a transpose mode).  estrd*/cstrd* are element/component strides
  // into dU/dV; d-prefixed pointers are presumably device memory (see
  // magma_isdevptr) -- confirm against the kernel sources.
  magma_int_t magma_interp_1d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // 2D variant of magma_interp_1d (same argument conventions).
  magma_int_t magma_interp_2d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // 3D variant of magma_interp_1d (same argument conventions).
  magma_int_t magma_interp_3d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Dimension-generic interpolation (any dim); used for
  // MAGMA_KERNEL_DIM_GENERIC mode.
  magma_int_t magma_interp_generic(magma_int_t P, magma_int_t Q,
                                   magma_int_t dim, magma_int_t ncomp,
                                   const CeedScalar *dT, CeedTransposeMode tmode,
                                   const CeedScalar *dU, magma_int_t u_elemstride,
                                   magma_int_t cstrdU,
                                   CeedScalar *dV, magma_int_t v_elemstride,
                                   magma_int_t cstrdV,
                                   magma_int_t nelem, magma_queue_t queue);

  // Dispatch entry point: selects the dim-specific or the generic
  // interpolation kernel according to kernel_mode.
  magma_int_t magma_interp(
    magma_int_t P, magma_int_t Q,
    magma_int_t dim, magma_int_t ncomp,
    const CeedScalar *dT, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
145 
  // Tensor gradient on 1D elements; takes both the interpolation and the
  // gradient 1D matrices.  dstrd* (2D/3D variants) is the stride between
  // gradient dimensions in dU/dV.  The n/t suffixes below presumably denote
  // the non-transposed and transposed application -- confirm in the sources.
  magma_int_t magma_grad_1d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dTinterp, const CeedScalar *dTgrad, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // 2D gradient, non-transposed variant.
  magma_int_t magma_gradn_2d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // 2D gradient, transposed variant.
  magma_int_t magma_gradt_2d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // 3D gradient, non-transposed variant.
  magma_int_t magma_gradn_3d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // 3D gradient, transposed variant.
  magma_int_t magma_gradt_3d(
    magma_int_t P, magma_int_t Q, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Dimension-generic gradient; used for MAGMA_KERNEL_DIM_GENERIC mode.
  magma_int_t magma_grad_generic(
    magma_int_t P, magma_int_t Q, magma_int_t dim, magma_int_t ncomp,
    const CeedScalar* dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_queue_t queue);

  // Dispatch entry point: selects the dim-specific or the generic gradient
  // kernel according to kernel_mode.
  magma_int_t magma_grad(
    magma_int_t P, magma_int_t Q, magma_int_t dim, magma_int_t ncomp,
    const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
    const CeedScalar *dU, magma_int_t u_elemstride, magma_int_t cstrdU, magma_int_t dstrdU,
    CeedScalar *dV, magma_int_t v_elemstride, magma_int_t cstrdV, magma_int_t dstrdV,
    magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
194 
  // Write quadrature weights (built from the 1D weights dqweight1d) into dV
  // for nelem 1D elements; v_stride is the per-element stride in dV.
  magma_int_t magma_weight_1d(
    magma_int_t Q, const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_queue_t queue);

  // 2D variant of magma_weight_1d.
  magma_int_t magma_weight_2d(
    magma_int_t Q, const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_queue_t queue);

  // 3D variant of magma_weight_1d.
  magma_int_t magma_weight_3d(
    magma_int_t Q, const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_queue_t queue);

  // Dimension-generic weight kernel; used for MAGMA_KERNEL_DIM_GENERIC mode.
  magma_int_t magma_weight_generic(
    magma_int_t Q, magma_int_t dim,
    const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t vstride,
    magma_int_t nelem, magma_queue_t queue);

  // Dispatch entry point: selects the dim-specific or the generic weight
  // kernel according to kernel_mode.
  magma_int_t magma_weight(
    magma_int_t Q, magma_int_t dim,
    const CeedScalar *dqweight1d,
    CeedScalar *dV, magma_int_t v_stride,
    magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);

  // Non-tensor weight kernel; caller supplies the launch geometry
  // (grid/threads) explicitly.  Returns nothing (void), unlike the
  // magma_int_t-returning tensor variants above.
  void magma_weight_nontensor(magma_int_t grid, magma_int_t threads, magma_int_t nelem,
                              magma_int_t Q,
                              CeedScalar *dqweight, CeedScalar *dv, magma_queue_t queue);
225 
  // Double-precision GEMM used by the non-tensor basis path:
  // dC = alpha*op(dA)*op(dB) + beta*dC, with MAGMA transpose flags and
  // leading dimensions ldda/lddb/lddc.
  int magma_dgemm_nontensor(
    magma_trans_t transA, magma_trans_t transB,
    magma_int_t m, magma_int_t n, magma_int_t k,
    double alpha, const double *dA, magma_int_t ldda,
    const double *dB, magma_int_t lddb,
    double beta,  double *dC, magma_int_t lddc,
    magma_queue_t queue );

  // Single-precision counterpart of magma_dgemm_nontensor.
  int magma_sgemm_nontensor(
    magma_trans_t transA, magma_trans_t transB,
    magma_int_t m, magma_int_t n, magma_int_t k,
    float alpha, const float *dA, magma_int_t ldda,
    const float *dB, magma_int_t lddb,
    float beta,  float *dC, magma_int_t lddc,
    magma_queue_t queue );

  // CEED-specific test for whether A points to device memory (aliased over
  // magma_is_devptr at the bottom of this header).
  magma_int_t
  magma_isdevptr(const void *A);
244 
  // Backend constructor: tensor-product H1 basis from 1D interp/grad
  // matrices and 1D quadrature points/weights.
  int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d,
                                    CeedInt Q1d,
                                    const CeedScalar *interp1d,
                                    const CeedScalar *grad1d,
                                    const CeedScalar *qref1d,
                                    const CeedScalar *qweight1d,
                                    CeedBasis basis);

  // Backend constructor: general (non-tensor) H1 basis from full matrices.
  int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim,
                              CeedInt ndof, CeedInt nqpts,
                              const CeedScalar *interp,
                              const CeedScalar *grad,
                              const CeedScalar *qref,
                              const CeedScalar *qweight,
                              CeedBasis basis);

  // Backend constructor: element restriction from an offsets array located
  // in mtype memory, copied/owned per cmode.
  int CeedElemRestrictionCreate_Magma(CeedMemType mtype,
                                      CeedCopyMode cmode,
                                      const CeedInt *offsets,
                                      CeedElemRestriction r);

  // Backend constructor: blocked element restriction (same conventions).
  int CeedElemRestrictionCreateBlocked_Magma(const CeedMemType mtype,
      const CeedCopyMode cmode,
      const CeedInt *offsets,
      const CeedElemRestriction res);

  // Backend constructor: operator.
  int CeedOperatorCreate_Magma(CeedOperator op);
272 
273   #ifdef __cplusplus
274 }
275   #endif
276 
// Route magma_is_devptr to the CEED-specific magma_isdevptr declared above;
// comment the line below to use the default magma_is_devptr function.
#define magma_is_devptr magma_isdevptr

// if magma and cuda/ref are using the null stream, then ceed_magma_queue_sync
// should do nothing (hence the empty variadic expansion)
#define ceed_magma_queue_sync(...)

// batch stride, override using -DMAGMA_BATCH_STRIDE=<desired-value>
#ifndef MAGMA_BATCH_STRIDE
#define MAGMA_BATCH_STRIDE (1000)
#endif
288 
289 #endif  // _ceed_magma_h
290