xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma.h (revision c42f38b112c169aa5035aee9bc56d3ad72b21cee)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
34444f328STzanio //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
54444f328STzanio //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
74444f328STzanio 
890104f39SStan Tomov // magma functions specific to ceed
9972b3d9dSNatalie Beams #ifndef _ceed_magma_h
103d576824SJeremy L Thompson #define _ceed_magma_h
1190104f39SStan Tomov 
12ec3da8bcSJed Brown #include <ceed/ceed.h>
13ec3da8bcSJed Brown #include <ceed/backend.h>
14e0582403Sabdelfattah83 #include <magma_v2.h>
15e0582403Sabdelfattah83 
// Thread-count limits for the 1D/2D/3D basis kernels (presumably passed as the
// `maxt` argument of the macros below -- confirm at the kernel call sites).
#define MAGMA_MAXTHREADS_1D 128
#define MAGMA_MAXTHREADS_2D 128
#define MAGMA_MAXTHREADS_3D 64

// Number of threads in the y-direction of a basis-kernel block: how many
// groups of `x` threads fit in `maxt`, but never fewer than 1.
// NOTE: `x` is evaluated more than once; pass side-effect-free expressions.
#define MAGMA_BASIS_NTCOL(x, maxt) (((maxt) < (x)) ? 1 : ((maxt) / (x)))

// Total number of threads in a basis-kernel block (x * NTCOL), for use with
// __launch_bounds__(). The result can exceed `maxt` when x > maxt (NTCOL is 1).
// `x` is parenthesized so compound arguments such as `P + 1` are multiplied
// as a whole rather than being split by operator precedence.
#define MAGMA_BASIS_BOUNDS(x, maxt) ((x) * MAGMA_BASIS_NTCOL(x, maxt))
25f6af633fSnbeams 
26e5f091ebSnbeams #ifdef CEED_MAGMA_USE_HIP
27*c42f38b1Snbeams typedef hipModule_t CeedMagmaModule;
28*c42f38b1Snbeams typedef hipFunction_t CeedMagmaFunction;
29*c42f38b1Snbeams #define CeedCompileMagma CeedCompileHip
30*c42f38b1Snbeams #define CeedGetKernelMagma CeedGetKernelHip
31*c42f38b1Snbeams #define CeedRunKernelMagma CeedRunKernelHip
32*c42f38b1Snbeams #define CeedRunKernelDimMagma CeedRunKernelDimHip
33*c42f38b1Snbeams #define CeedRunKernelDimSharedMagma CeedRunKernelDimSharedHip
34f6af633fSnbeams #else
35*c42f38b1Snbeams typedef CUmodule CeedMagmaModule;
36*c42f38b1Snbeams typedef CUfunction CeedMagmaFunction;
37*c42f38b1Snbeams #define CeedCompileMagma CeedCompileCuda
38*c42f38b1Snbeams #define CeedGetKernelMagma CeedGetKernelCuda
39*c42f38b1Snbeams #define CeedRunKernelMagma CeedRunKernelCuda
40*c42f38b1Snbeams #define CeedRunKernelDimMagma CeedRunKernelDimCuda
41*c42f38b1Snbeams #define CeedRunKernelDimSharedMagma CeedRunKernelDimSharedCuda
42f6af633fSnbeams #endif
43f6af633fSnbeams 
// Basis-kernel selection mode; this type is the `kernel_mode` argument of
// magma_interp / magma_grad / magma_weight. The names suggest dimension-generic
// vs. dimension-specialized kernels -- confirm in the implementation.
typedef enum {
  MAGMA_KERNEL_DIM_GENERIC  = 101,
  MAGMA_KERNEL_DIM_SPECIFIC = 102,
} magma_kernel_mode_t;
48e0582403Sabdelfattah83 
49e0582403Sabdelfattah83 typedef struct {
50e0582403Sabdelfattah83   magma_kernel_mode_t basis_kernel_mode;
51e0582403Sabdelfattah83   magma_device_t device;
52e0582403Sabdelfattah83   magma_queue_t queue;
53e0582403Sabdelfattah83 } Ceed_Magma;
545a9ca9adSVeselin Dobrev 
557f5b9731SStan Tomov typedef struct {
56*c42f38b1Snbeams   CeedMagmaModule module;
57*c42f38b1Snbeams   CeedMagmaFunction magma_interp;
58*c42f38b1Snbeams   CeedMagmaFunction magma_interp_tr;
59*c42f38b1Snbeams   CeedMagmaFunction magma_grad;
60*c42f38b1Snbeams   CeedMagmaFunction magma_grad_tr;
61*c42f38b1Snbeams   CeedMagmaFunction magma_weight;
627f5b9731SStan Tomov   CeedScalar *dqref1d;
637f5b9731SStan Tomov   CeedScalar *dinterp1d;
647f5b9731SStan Tomov   CeedScalar *dgrad1d;
657f5b9731SStan Tomov   CeedScalar *dqweight1d;
667f5b9731SStan Tomov } CeedBasis_Magma;
677f5b9731SStan Tomov 
687f5b9731SStan Tomov typedef struct {
69868539c2SNatalie Beams   CeedScalar *dqref;
70868539c2SNatalie Beams   CeedScalar *dinterp;
71868539c2SNatalie Beams   CeedScalar *dgrad;
72868539c2SNatalie Beams   CeedScalar *dqweight;
73868539c2SNatalie Beams } CeedBasisNonTensor_Magma;
74868539c2SNatalie Beams 
// Ownership state of a buffer held by the backend; the names suggest
// "not owned", "owned and pageable", and "owned and pinned" -- confirm at use sites.
typedef enum {
  OWNED_NONE     = 0,
  OWNED_UNPINNED = 1,
  OWNED_PINNED   = 2,
} OwnershipMode;
80c8b3a627SJed Brown 
81868539c2SNatalie Beams typedef struct {
82*c42f38b1Snbeams   CeedMagmaModule module;
83*c42f38b1Snbeams   CeedMagmaFunction StridedTranspose;
84*c42f38b1Snbeams   CeedMagmaFunction StridedNoTranspose;
85*c42f38b1Snbeams   CeedMagmaFunction OffsetTranspose;
86*c42f38b1Snbeams   CeedMagmaFunction OffsetNoTranspose;
87d655899aSNatalie Beams   CeedInt *offsets;
88d655899aSNatalie Beams   CeedInt *doffsets;
89c8b3a627SJed Brown   OwnershipMode own_;
90868539c2SNatalie Beams   int down_;            // cover a case where we own Device memory
91868539c2SNatalie Beams } CeedElemRestriction_Magma;
92868539c2SNatalie Beams 
93868539c2SNatalie Beams typedef struct {
947f5b9731SStan Tomov   const CeedScalar **inputs;
957f5b9731SStan Tomov   CeedScalar **outputs;
967f5b9731SStan Tomov   bool setupdone;
977f5b9731SStan Tomov } CeedQFunction_Magma;
987f5b9731SStan Tomov 
// Compile-time switches. They expand to nothing, so they are only meaningful
// with #ifdef / defined(). They presumably enable batched code paths in the
// implementation files -- confirm there.
#define USE_MAGMA_BATCH
#define USE_MAGMA_BATCH2
#define USE_MAGMA_BATCH3
#define USE_MAGMA_BATCH4
10390104f39SStan Tomov 
1047f5b9731SStan Tomov #ifdef __cplusplus
1057f5b9731SStan Tomov CEED_INTERN {
1067f5b9731SStan Tomov #endif
107e0582403Sabdelfattah83 
108e0582403Sabdelfattah83   magma_int_t magma_interp_1d(
109e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
110e0582403Sabdelfattah83     const CeedScalar *dT, CeedTransposeMode tmode,
111e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
112e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
113f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
114e0582403Sabdelfattah83 
115e0582403Sabdelfattah83   magma_int_t magma_interp_2d(
116e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
117e0582403Sabdelfattah83     const CeedScalar *dT, CeedTransposeMode tmode,
118e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
119e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
120f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
121e0582403Sabdelfattah83 
122e0582403Sabdelfattah83   magma_int_t magma_interp_3d(
123e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
124e0582403Sabdelfattah83     const CeedScalar *dT, CeedTransposeMode tmode,
125e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
126e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
127f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
128e0582403Sabdelfattah83 
129e0582403Sabdelfattah83   magma_int_t magma_interp_generic(magma_int_t P, magma_int_t Q,
130868539c2SNatalie Beams                                    magma_int_t dim, magma_int_t ncomp,
13180a9ef05SNatalie Beams                                    const CeedScalar *dT, CeedTransposeMode tmode,
13280a9ef05SNatalie Beams                                    const CeedScalar *dU, magma_int_t u_elemstride,
133e0582403Sabdelfattah83                                    magma_int_t cstrdU,
13480a9ef05SNatalie Beams                                    CeedScalar *dV, magma_int_t v_elemstride,
135e0582403Sabdelfattah83                                    magma_int_t cstrdV,
136e0582403Sabdelfattah83                                    magma_int_t nelem, magma_queue_t queue);
1377f5b9731SStan Tomov 
138e0582403Sabdelfattah83   magma_int_t magma_interp(
139e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q,
140868539c2SNatalie Beams     magma_int_t dim, magma_int_t ncomp,
14180a9ef05SNatalie Beams     const CeedScalar *dT, CeedTransposeMode tmode,
14280a9ef05SNatalie Beams     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
14380a9ef05SNatalie Beams     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
144f71aa81bSnbeams     magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
1457f5b9731SStan Tomov 
146e0582403Sabdelfattah83   magma_int_t magma_grad_1d(
147e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
148e0582403Sabdelfattah83     const CeedScalar *dTinterp, const CeedScalar *dTgrad, CeedTransposeMode tmode,
149e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU,
150e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV,
151f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
152868539c2SNatalie Beams 
153e0582403Sabdelfattah83   magma_int_t magma_gradn_2d(
154e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
155e0582403Sabdelfattah83     const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
156e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
157e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
158f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
159e0582403Sabdelfattah83 
160e0582403Sabdelfattah83   magma_int_t magma_gradt_2d(
161e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
162e0582403Sabdelfattah83     const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
163e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
164e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
165f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
166e0582403Sabdelfattah83 
167e0582403Sabdelfattah83   magma_int_t magma_gradn_3d(
168e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
169e0582403Sabdelfattah83     const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
170e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
171e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
172f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
173e0582403Sabdelfattah83 
174e0582403Sabdelfattah83   magma_int_t magma_gradt_3d(
175e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t ncomp,
176e0582403Sabdelfattah83     const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
177e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
178e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
179f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
180e0582403Sabdelfattah83 
181e0582403Sabdelfattah83   magma_int_t magma_grad_generic(
182e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t dim, magma_int_t ncomp,
183e0582403Sabdelfattah83     const CeedScalar* dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
184e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t estrdU, magma_int_t cstrdU, magma_int_t dstrdU,
185e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t estrdV, magma_int_t cstrdV, magma_int_t dstrdV,
186e0582403Sabdelfattah83     magma_int_t nelem, magma_queue_t queue);
187e0582403Sabdelfattah83 
188e0582403Sabdelfattah83   magma_int_t magma_grad(
189e0582403Sabdelfattah83     magma_int_t P, magma_int_t Q, magma_int_t dim, magma_int_t ncomp,
190e0582403Sabdelfattah83     const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, CeedTransposeMode tmode,
191e0582403Sabdelfattah83     const CeedScalar *dU, magma_int_t u_elemstride, magma_int_t cstrdU, magma_int_t dstrdU,
192e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t v_elemstride, magma_int_t cstrdV, magma_int_t dstrdV,
193f71aa81bSnbeams     magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
194e0582403Sabdelfattah83 
195e0582403Sabdelfattah83   magma_int_t magma_weight_1d(
196e0582403Sabdelfattah83     magma_int_t Q, const CeedScalar *dqweight1d,
197e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t v_stride,
198f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
199e0582403Sabdelfattah83 
200e0582403Sabdelfattah83   magma_int_t magma_weight_2d(
201e0582403Sabdelfattah83     magma_int_t Q, const CeedScalar *dqweight1d,
202e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t v_stride,
203f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
204e0582403Sabdelfattah83 
205e0582403Sabdelfattah83   magma_int_t magma_weight_3d(
206e0582403Sabdelfattah83     magma_int_t Q, const CeedScalar *dqweight1d,
207e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t v_stride,
208f71aa81bSnbeams     magma_int_t nelem, magma_queue_t queue);
209e0582403Sabdelfattah83 
210e0582403Sabdelfattah83   magma_int_t magma_weight_generic(
211e0582403Sabdelfattah83     magma_int_t Q, magma_int_t dim,
212e0582403Sabdelfattah83     const CeedScalar *dqweight1d,
213e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t vstride,
214e0582403Sabdelfattah83     magma_int_t nelem, magma_queue_t queue);
215e0582403Sabdelfattah83 
216e0582403Sabdelfattah83   magma_int_t magma_weight(
217e0582403Sabdelfattah83     magma_int_t Q, magma_int_t dim,
218e0582403Sabdelfattah83     const CeedScalar *dqweight1d,
219e0582403Sabdelfattah83     CeedScalar *dV, magma_int_t v_stride,
220f71aa81bSnbeams     magma_int_t nelem, magma_kernel_mode_t kernel_mode, magma_queue_t queue);
221e0582403Sabdelfattah83 
222e0582403Sabdelfattah83   void magma_weight_nontensor(magma_int_t grid, magma_int_t threads, magma_int_t nelem,
223868539c2SNatalie Beams                               magma_int_t Q,
22480a9ef05SNatalie Beams                               CeedScalar *dqweight, CeedScalar *dv, magma_queue_t queue);
225e0582403Sabdelfattah83 
226e0582403Sabdelfattah83   int magma_dgemm_nontensor(
227e0582403Sabdelfattah83     magma_trans_t transA, magma_trans_t transB,
228e0582403Sabdelfattah83     magma_int_t m, magma_int_t n, magma_int_t k,
229e0582403Sabdelfattah83     double alpha, const double *dA, magma_int_t ldda,
230e0582403Sabdelfattah83     const double *dB, magma_int_t lddb,
231e0582403Sabdelfattah83     double beta,  double *dC, magma_int_t lddc,
232e0582403Sabdelfattah83     magma_queue_t queue );
233e0582403Sabdelfattah83 
23480a9ef05SNatalie Beams   int magma_sgemm_nontensor(
23580a9ef05SNatalie Beams     magma_trans_t transA, magma_trans_t transB,
23680a9ef05SNatalie Beams     magma_int_t m, magma_int_t n, magma_int_t k,
23780a9ef05SNatalie Beams     float alpha, const float *dA, magma_int_t ldda,
23880a9ef05SNatalie Beams     const float *dB, magma_int_t lddb,
23980a9ef05SNatalie Beams     float beta,  float *dC, magma_int_t lddc,
24080a9ef05SNatalie Beams     magma_queue_t queue );
24180a9ef05SNatalie Beams 
2427f5b9731SStan Tomov   magma_int_t
2437f5b9731SStan Tomov   magma_isdevptr(const void *A);
2447f5b9731SStan Tomov 
245868539c2SNatalie Beams   int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d,
246868539c2SNatalie Beams                                     CeedInt Q1d,
247868539c2SNatalie Beams                                     const CeedScalar *interp1d,
248868539c2SNatalie Beams                                     const CeedScalar *grad1d,
249868539c2SNatalie Beams                                     const CeedScalar *qref1d,
250868539c2SNatalie Beams                                     const CeedScalar *qweight1d,
251868539c2SNatalie Beams                                     CeedBasis basis);
2527f5b9731SStan Tomov 
253868539c2SNatalie Beams   int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim,
254d4f68153Sjeremylt                               CeedInt ndof, CeedInt nqpts,
255d4f68153Sjeremylt                               const CeedScalar *interp,
256d4f68153Sjeremylt                               const CeedScalar *grad,
257d4f68153Sjeremylt                               const CeedScalar *qref,
258d4f68153Sjeremylt                               const CeedScalar *qweight,
259d4f68153Sjeremylt                               CeedBasis basis);
260868539c2SNatalie Beams 
261868539c2SNatalie Beams   int CeedElemRestrictionCreate_Magma(CeedMemType mtype,
262868539c2SNatalie Beams                                       CeedCopyMode cmode,
263d655899aSNatalie Beams                                       const CeedInt *offsets,
264868539c2SNatalie Beams                                       CeedElemRestriction r);
265868539c2SNatalie Beams 
266868539c2SNatalie Beams   int CeedElemRestrictionCreateBlocked_Magma(const CeedMemType mtype,
267868539c2SNatalie Beams       const CeedCopyMode cmode,
268d655899aSNatalie Beams       const CeedInt *offsets,
269868539c2SNatalie Beams       const CeedElemRestriction res);
270a8c028e3SNatalie Beams 
271a8c028e3SNatalie Beams   int CeedOperatorCreate_Magma(CeedOperator op);
272a8c028e3SNatalie Beams 
2737f5b9731SStan Tomov   #ifdef __cplusplus
2747f5b9731SStan Tomov }
2757f5b9731SStan Tomov   #endif
2767f5b9731SStan Tomov 
// Route magma_is_devptr to this backend's magma_isdevptr (declared in this
// header). Comment out the definition below to use MAGMA's default
// magma_is_devptr instead.
#define magma_is_devptr magma_isdevptr
2797f5b9731SStan Tomov 
// If MAGMA and the CUDA/ref backends both use the null (default) stream, no
// explicit queue synchronization is needed, so ceed_magma_queue_sync is a no-op.
// It expands to ((void)0) rather than to nothing so that it is a valid
// expression and never leaves an empty statement body (e.g.
// `if (c) ceed_magma_queue_sync(q);` would otherwise trigger -Wempty-body).
// Its arguments are not evaluated.
#define ceed_magma_queue_sync(...) ((void)0)
283e0582403Sabdelfattah83 
// Batch stride; a default of 1000 is used unless overridden at build time with
// -DMAGMA_BATCH_STRIDE=<desired-value>.
#if !defined(MAGMA_BATCH_STRIDE)
#  define MAGMA_BATCH_STRIDE (1000)
#endif
288e0582403Sabdelfattah83 
2893d576824SJeremy L Thompson #endif  // _ceed_magma_h
290