xref: /libCEED/backends/magma/ceed-magma-basis.c (revision 7113573b6efd54558bb98b919dff5d6d8ffcff54)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #include "ceed-magma.h"
14 #ifdef CEED_MAGMA_USE_HIP
15 #include "../hip/ceed-hip-common.h"
16 #include "../hip/ceed-hip-compile.h"
17 #else
18 #include "../cuda/ceed-cuda-common.h"
19 #include "../cuda/ceed-cuda-compile.h"
20 #endif
21 
// Apply a tensor-product basis operation for nelem elements by launching a
// JIT-compiled MAGMA kernel on the device.
//   emode selects the operation (CEED_EVAL_INTERP, CEED_EVAL_GRAD, or
//     CEED_EVAL_WEIGHT); DIV/CURL/NONE are rejected with a backend error.
//   tmode selects the direction: NOTRANSPOSE maps the E-vector U to the
//     Q-vector V; TRANSPOSE applies the adjoint (and accumulates into V).
//   U may be CEED_VECTOR_NONE only for CEED_EVAL_WEIGHT.
// Returns CEED_ERROR_SUCCESS or a backend error code.
#ifdef __cplusplus
CEED_INTERN "C"
#endif
    int
    CeedBasisApply_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
  Ceed ceed;
  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedInt dim, ncomp, ndof;
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));

  Ceed_Magma *data;
  CeedCallBackend(CeedGetData(ceed, &data));

  // Device pointers for input/output; an input vector is only optional in WEIGHT mode
  const CeedScalar *du;
  CeedScalar       *dv;
  if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
  else CeedCheck(emode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));

  CeedBasis_Magma *impl;
  CeedCallBackend(CeedBasisGetData(basis, &impl));

  CeedInt P1d, Q1d;
  CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P1d));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q1d));

  CeedDebug256(ceed, 4, "[CeedBasisApply_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * CeedIntPow(P1d, dim), ncomp);

  // In TRANSPOSE mode the kernels accumulate into V, so V must be zeroed first
  // (laset with 0 off-diagonal and 0 diagonal fills the length-x-1 "matrix" with zeros)
  if (tmode == CEED_TRANSPOSE) {
    CeedSize length;
    CeedCallBackend(CeedVectorGetLength(V, &length));
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  switch (emode) {
    case CEED_EVAL_INTERP: {
      // P/Q are the per-direction input/output sizes of the kernel; they swap
      // roles in TRANSPOSE mode (input lives on quadrature points instead of dofs)
      CeedInt P = P1d, Q = Q1d;
      if (tmode == CEED_TRANSPOSE) {
        P = Q1d;
        Q = P1d;
      }

      // Define element sizes for dofs/quad
      CeedInt elquadsize = CeedIntPow(Q1d, dim);
      CeedInt eldofssize = CeedIntPow(P1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //  component                        component
      //    elem                             elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (du) is E-vector, output (dv) is Q-vector

      // Element strides
      CeedInt u_elstride = eldofssize;
      CeedInt v_elstride = elquadsize;
      // Component strides
      CeedInt u_compstride = nelem * eldofssize;
      CeedInt v_compstride = nelem * elquadsize;

      // ---  Swap strides for TRANSPOSE mode: ---
      if (tmode == CEED_TRANSPOSE) {
        // Input (du) is Q-vector, output (dv) is E-vector
        // Element strides
        v_elstride = eldofssize;
        u_elstride = elquadsize;
        // Component strides
        v_compstride = nelem * eldofssize;
        u_compstride = nelem * elquadsize;
      }

      // Launch geometry: nthreads per element, ntcol elements per thread block,
      // shmem bytes of dynamic shared memory per block
      CeedInt nthreads = 1;
      CeedInt ntcol    = 1;
      CeedInt shmem    = 0;
      CeedInt maxPQ    = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
          shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
          shmem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
          shmem += P * Q * sizeof(CeedScalar);                // for sT
          shmem += ntcol * (P * maxPQ * sizeof(CeedScalar));  // for reforming rU we need PxP, and for the intermediate output we need PxQ
          break;
        case 3:
          nthreads = maxPQ * maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
          shmem += sizeof(CeedScalar) * (P * Q);  // for sT
          shmem += sizeof(CeedScalar) * ntcol *
                   (CeedIntMax(P * P * maxPQ,
                               P * Q * Q));  // rU needs P^2xP, the intermediate output needs max(P^2xQ,PQ^2)
      }
      CeedInt grid   = (nelem + ntcol - 1) / ntcol;  // ceil(nelem / ntcol)
      void   *args[] = {&impl->dinterp1d, &du, &u_elstride, &u_compstride, &dv, &v_elstride, &v_compstride, &nelem};

      if (tmode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp_tr, grid, nthreads, ntcol, 1, shmem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp, grid, nthreads, ntcol, 1, shmem, args));
      }
    } break;
    case CEED_EVAL_GRAD: {
      CeedInt P = P1d, Q = Q1d;
      // In CEED_NOTRANSPOSE mode:
      // du is (P^dim x nc), column-major layout (nc = ncomp)
      // dv is (Q^dim x nc x dim), column-major layout (nc = ncomp)
      // In CEED_TRANSPOSE mode, the sizes of du and dv are switched.
      if (tmode == CEED_TRANSPOSE) {
        P = Q1d, Q = P1d;
      }

      // Define element sizes for dofs/quad
      CeedInt elquadsize = CeedIntPow(Q1d, dim);
      CeedInt eldofssize = CeedIntPow(P1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //                                  dim
      //  component                        component
      //    elem                              elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (du) is E-vector, output (dv) is Q-vector

      // Element strides
      CeedInt u_elstride = eldofssize;
      CeedInt v_elstride = elquadsize;
      // Component strides
      CeedInt u_compstride = nelem * eldofssize;
      CeedInt v_compstride = nelem * elquadsize;
      // Dimension strides (the E-vector has no dim axis, hence stride 0 on input)
      CeedInt u_dimstride = 0;
      CeedInt v_dimstride = nelem * elquadsize * ncomp;

      // ---  Swap strides for TRANSPOSE mode: ---
      if (tmode == CEED_TRANSPOSE) {
        // Input (du) is Q-vector, output (dv) is E-vector
        // Element strides
        v_elstride = eldofssize;
        u_elstride = elquadsize;
        // Component strides
        v_compstride = nelem * eldofssize;
        u_compstride = nelem * elquadsize;
        // Dimension strides
        v_dimstride = 0;
        u_dimstride = nelem * elquadsize * ncomp;
      }

      // Launch geometry (see CEED_EVAL_INTERP above for the meaning of each value)
      CeedInt nthreads = 1;
      CeedInt ntcol    = 1;
      CeedInt shmem    = 0;
      CeedInt maxPQ    = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
          shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
          shmem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
          shmem += sizeof(CeedScalar) * 2 * P * Q;            // for sTinterp and sTgrad
          shmem += sizeof(CeedScalar) * ntcol * (P * maxPQ);  // for reforming rU we need PxP, and for the intermediate output we need PxQ
          break;
        case 3:
          nthreads = maxPQ * maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
          shmem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          shmem += sizeof(CeedScalar) * ntcol *
                   CeedIntMax(P * P * P,
                              (P * P * Q) + (P * Q * Q));  // rU needs P^2xP, the intermediate outputs need (P^2.Q + P.Q^2)
      }
      CeedInt grid   = (nelem + ntcol - 1) / ntcol;  // ceil(nelem / ntcol)
      void   *args[] = {&impl->dinterp1d, &impl->dgrad1d, &du,          &u_elstride, &u_compstride, &u_dimstride, &dv,
                        &v_elstride,      &v_compstride,  &v_dimstride, &nelem};

      if (tmode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad_tr, grid, nthreads, ntcol, 1, shmem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad, grid, nthreads, ntcol, 1, shmem, args));
      }
    } break;
    case CEED_EVAL_WEIGHT: {
      CeedCheck(tmode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
      CeedInt Q          = Q1d;
      CeedInt eldofssize = CeedIntPow(Q, dim);  // output size per element (Q^dim quadrature weights)
      CeedInt nthreads   = 1;
      CeedInt ntcol      = 1;
      CeedInt shmem      = 0;

      switch (dim) {
        case 1:
          nthreads = Q;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
          shmem += sizeof(CeedScalar) * Q;          // for dqweight1d
          shmem += sizeof(CeedScalar) * ntcol * Q;  // for output
          break;
        case 2:
          nthreads = Q;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
          shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
          break;
        case 3:
          nthreads = Q * Q;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
          shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
      }
      CeedInt grid   = (nelem + ntcol - 1) / ntcol;  // ceil(nelem / ntcol)
      void   *args[] = {&impl->dqweight1d, &dv, &eldofssize, &nelem};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_weight, grid, nthreads, ntcol, 1, shmem, args));
    } break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
      // LCOV_EXCL_STOP
  }

  // must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Release array access; du was only acquired when an input vector was supplied
  if (emode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
  }
  CeedCallBackend(CeedVectorRestoreArray(V, &dv));
  return CEED_ERROR_SUCCESS;
}
268 
// Apply a non-tensor H1 basis operation for nelem elements.
// Small problems (P and Q below the custom-kernel thresholds) use JIT-compiled
// MAGMA kernels selected from a set of instances tuned for different N values;
// larger problems fall back to batched GEMM calls.
// See CeedBasisApply_Magma for the meaning of tmode/emode/U/V.
#ifdef __cplusplus
CEED_INTERN "C"
#endif
    int
    CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
  Ceed ceed;
  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));

  Ceed_Magma *data;
  CeedCallBackend(CeedGetData(ceed, &data));

  // Device architecture is one input to the blocking-size (NB) tuning below
  magma_int_t arch = magma_getdevice_arch();

  CeedInt dim, ncomp, ndof, nqpt;
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &nqpt));
  // Device pointers for input/output; an input vector is only optional in WEIGHT mode
  const CeedScalar *du;
  CeedScalar       *dv;
  if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
  else CeedCheck(emode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));

  CeedBasisNonTensor_Magma *impl;
  CeedCallBackend(CeedBasisGetData(basis, &impl));

  CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * ndof, ncomp);

  // In TRANSPOSE mode the kernels accumulate into V, so V must be zeroed first
  if (tmode == CEED_TRANSPOSE) {
    CeedSize length;
    CeedCallBackend(CeedVectorGetLength(V, &length));
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  CeedInt            P = ndof, Q = nqpt, N = nelem * ncomp;
  CeedInt            NB = 1;
  CeedMagmaFunction *interp, *grad;

  // Each compiled kernel instance is tuned for one N value; pick the instance
  // whose tuned N is closest to the actual N
  CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
  CeedInt iN                                       = 0;
  CeedInt diff                                     = abs(Narray[iN] - N);
  for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
    CeedInt idiff = abs(Narray[in] - N);
    if (idiff < diff) {
      iN   = in;
      diff = idiff;
    }
  }

  // Blocking size and kernel handles for the chosen instance/direction
  NB     = nontensor_rtc_get_nb(arch, 'd', emode, tmode, P, Narray[iN], Q);
  interp = (tmode == CEED_TRANSPOSE) ? &impl->magma_interp_tr_nontensor[iN] : &impl->magma_interp_nontensor[iN];
  grad   = (tmode == CEED_TRANSPOSE) ? &impl->magma_grad_tr_nontensor[iN] : &impl->magma_grad_nontensor[iN];

  switch (emode) {
    case CEED_EVAL_INTERP: {
      CeedInt P = ndof, Q = nqpt;
      if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
        // Custom kernel path: gemm-like launch with M rows out, K rows in
        CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
        CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
        CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
        CeedInt shmem = 0, shmemA = 0, shmemB = 0;
        shmemB += ntcol * K * NB * sizeof(CeedScalar);
        shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
        shmem = (tmode == CEED_TRANSPOSE) ? (shmemA + shmemB) : CeedIntMax(shmemA, shmemB);

        CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
        magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
        magma_trans_t transB = MagmaNoTrans;
        CeedScalar    alpha = 1.0, beta = 0.0;

        void *args[] = {&transA, &transB, &N, &alpha, &impl->dinterp, &P, &du, &K, &beta, &dv, &M};
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *interp, grid, M, ntcol, 1, shmem, args));
      } else {
        // GEMM fallback for sizes above the custom-kernel limits
        if (tmode == CEED_TRANSPOSE)
          magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dinterp, P, du, Q, 0.0, dv, P, data->queue);
        else magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dinterp, P, du, P, 0.0, dv, Q, data->queue);
      }
    } break;

    case CEED_EVAL_GRAD: {
      CeedInt P = ndof, Q = nqpt;
      if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
        // Custom kernel path (the grad kernel handles all dim slices internally)
        CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
        CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
        CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
        CeedInt shmem = 0, shmemA = 0, shmemB = 0;
        shmemB += ntcol * K * NB * sizeof(CeedScalar);
        shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
        shmem = shmemA + shmemB;

        CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
        magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
        magma_trans_t transB = MagmaNoTrans;

        void *args[] = {&transA, &transB, &N, &impl->dgrad, &P, &du, &K, &dv, &M};
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *grad, grid, M, ntcol, 1, shmem, args));
      } else {
        // GEMM fallback: one gemm per dimension slice of dgrad
        if (tmode == CEED_TRANSPOSE) {
          // Accumulate the dim contributions into dv (beta = 1 after the first slice)
          CeedScalar beta = 0.0;
          for (int d = 0; d < dim; d++) {
            if (d > 0) beta = 1.0;
            magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dgrad + d * P * Q, P, du + d * nelem * ncomp * Q, Q,
                                 beta, dv, P, data->queue);
          }
        } else {
          for (int d = 0; d < dim; d++)
            magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dgrad + d * P * Q, P, du, P, 0.0,
                                 dv + d * nelem * ncomp * Q, Q, data->queue);
        }
      }
    } break;

    case CEED_EVAL_WEIGHT: {
      CeedCheck(tmode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");

      int elemsPerBlock = 1;  // basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
      int grid          = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);
      magma_weight_nontensor(grid, nqpt, nelem, nqpt, impl->dqweight, dv, data->queue);
    } break;

    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
      // LCOV_EXCL_STOP
  }

  // must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Release array access; du was only acquired when an input vector was supplied
  if (emode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
  }
  CeedCallBackend(CeedVectorRestoreArray(V, &dv));
  return CEED_ERROR_SUCCESS;
}
414 
415 #ifdef __cplusplus
416 CEED_INTERN "C"
417 #endif
418     int
419     CeedBasisDestroy_Magma(CeedBasis basis) {
420   CeedBasis_Magma *impl;
421   CeedCallBackend(CeedBasisGetData(basis, &impl));
422 
423   CeedCallBackend(magma_free(impl->dqref1d));
424   CeedCallBackend(magma_free(impl->dinterp1d));
425   CeedCallBackend(magma_free(impl->dgrad1d));
426   CeedCallBackend(magma_free(impl->dqweight1d));
427   Ceed ceed;
428   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
429 #ifdef CEED_MAGMA_USE_HIP
430   CeedCallHip(ceed, hipModuleUnload(impl->module));
431 #else
432   CeedCallCuda(ceed, cuModuleUnload(impl->module));
433 #endif
434 
435   CeedCallBackend(CeedFree(&impl));
436 
437   return CEED_ERROR_SUCCESS;
438 }
439 
440 #ifdef __cplusplus
441 CEED_INTERN "C"
442 #endif
443     int
444     CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
445   CeedBasisNonTensor_Magma *impl;
446   CeedCallBackend(CeedBasisGetData(basis, &impl));
447 
448   CeedCallBackend(magma_free(impl->dqref));
449   CeedCallBackend(magma_free(impl->dinterp));
450   CeedCallBackend(magma_free(impl->dgrad));
451   CeedCallBackend(magma_free(impl->dqweight));
452   Ceed ceed;
453   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
454 #ifdef CEED_MAGMA_USE_HIP
455   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
456     CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
457   }
458 #else
459   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
460     CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
461   }
462 #endif
463   CeedCallBackend(CeedFree(&impl));
464 
465   return CEED_ERROR_SUCCESS;
466 }
467 
468 #ifdef __cplusplus
469 CEED_INTERN "C"
470 #endif
471     int
472     CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d, CeedInt Q1d, const CeedScalar *interp1d, const CeedScalar *grad1d,
473                                   const CeedScalar *qref1d, const CeedScalar *qweight1d, CeedBasis basis) {
474   CeedBasis_Magma *impl;
475   CeedCallBackend(CeedCalloc(1, &impl));
476   Ceed ceed;
477   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
478 
479   // Check for supported parameters
480   CeedInt ncomp = 0;
481   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
482   Ceed_Magma *data;
483   CeedCallBackend(CeedGetData(ceed, &data));
484 
485   // Compile kernels
486   char *magma_common_path;
487   char *interp_path, *grad_path, *weight_path;
488   char *basis_kernel_source;
489   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
490   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
491   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
492   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_tensor.h", &magma_common_path));
493   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
494   char   *interp_name_base = "ceed/jit-source/magma/interp";
495   CeedInt interp_name_len  = strlen(interp_name_base) + 6;
496   char    interp_name[interp_name_len];
497   snprintf(interp_name, interp_name_len, "%s-%" CeedInt_FMT "d.h", interp_name_base, dim);
498   CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_name, &interp_path));
499   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
500   char   *grad_name_base = "ceed/jit-source/magma/grad";
501   CeedInt grad_name_len  = strlen(grad_name_base) + 6;
502   char    grad_name[grad_name_len];
503   snprintf(grad_name, grad_name_len, "%s-%" CeedInt_FMT "d.h", grad_name_base, dim);
504   CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_name, &grad_path));
505   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
506   char   *weight_name_base = "ceed/jit-source/magma/weight";
507   CeedInt weight_name_len  = strlen(weight_name_base) + 6;
508   char    weight_name[weight_name_len];
509   snprintf(weight_name, weight_name_len, "%s-%" CeedInt_FMT "d.h", weight_name_base, dim);
510   CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_name, &weight_path));
511   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_path, &basis_kernel_source));
512   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
513   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip
514   // data
515   Ceed delegate;
516   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
517   CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module, 5, "DIM", dim, "NCOMP", ncomp, "P", P1d, "Q", Q1d, "MAXPQ",
518                                    CeedIntMax(P1d, Q1d)));
519 
520   // Kernel setup
521   switch (dim) {
522     case 1:
523       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->magma_interp));
524       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->magma_interp_tr));
525       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->magma_grad));
526       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->magma_grad_tr));
527       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->magma_weight));
528       break;
529     case 2:
530       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->magma_interp));
531       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->magma_interp_tr));
532       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->magma_grad));
533       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->magma_grad_tr));
534       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->magma_weight));
535       break;
536     case 3:
537       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->magma_interp));
538       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->magma_interp_tr));
539       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->magma_grad));
540       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->magma_grad_tr));
541       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->magma_weight));
542   }
543 
544   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
545   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
546 
547   // Copy qref1d to the GPU
548   CeedCallBackend(magma_malloc((void **)&impl->dqref1d, Q1d * sizeof(qref1d[0])));
549   magma_setvector(Q1d, sizeof(qref1d[0]), qref1d, 1, impl->dqref1d, 1, data->queue);
550 
551   // Copy interp1d to the GPU
552   CeedCallBackend(magma_malloc((void **)&impl->dinterp1d, Q1d * P1d * sizeof(interp1d[0])));
553   magma_setvector(Q1d * P1d, sizeof(interp1d[0]), interp1d, 1, impl->dinterp1d, 1, data->queue);
554 
555   // Copy grad1d to the GPU
556   CeedCallBackend(magma_malloc((void **)&impl->dgrad1d, Q1d * P1d * sizeof(grad1d[0])));
557   magma_setvector(Q1d * P1d, sizeof(grad1d[0]), grad1d, 1, impl->dgrad1d, 1, data->queue);
558 
559   // Copy qweight1d to the GPU
560   CeedCallBackend(magma_malloc((void **)&impl->dqweight1d, Q1d * sizeof(qweight1d[0])));
561   magma_setvector(Q1d, sizeof(qweight1d[0]), qweight1d, 1, impl->dqweight1d, 1, data->queue);
562 
563   CeedCallBackend(CeedBasisSetData(basis, impl));
564   CeedCallBackend(CeedFree(&magma_common_path));
565   CeedCallBackend(CeedFree(&interp_path));
566   CeedCallBackend(CeedFree(&grad_path));
567   CeedCallBackend(CeedFree(&weight_path));
568   CeedCallBackend(CeedFree(&basis_kernel_source));
569 
570   return CEED_ERROR_SUCCESS;
571 }
572 
573 #ifdef __cplusplus
574 CEED_INTERN "C"
575 #endif
576     int
577     CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt ndof, CeedInt nqpts, const CeedScalar *interp, const CeedScalar *grad,
578                             const CeedScalar *qref, const CeedScalar *qweight, CeedBasis basis) {
579   CeedBasisNonTensor_Magma *impl;
580   Ceed                      ceed;
581   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
582 
583   Ceed_Magma *data;
584   CeedCallBackend(CeedGetData(ceed, &data));
585   magma_int_t arch = magma_getdevice_arch();
586   CeedCallBackend(CeedCalloc(1, &impl));
587   // Compile kernels
588   char *magma_common_path;
589   char *interp_path, *grad_path;
590   char *basis_kernel_source;
591   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
592   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
593   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
594   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_nontensor.h", &magma_common_path));
595   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
596   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/interp-nontensor.h", &interp_path));
597   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
598   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/grad-nontensor.h", &grad_path));
599   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
600 
601   // tuning parameters for nb
602   CeedInt nb_interp_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
603   CeedInt nb_interp_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
604   CeedInt nb_grad_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
605   CeedInt nb_grad_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
606   CeedInt P = ndof, Q = nqpts;
607   CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
608   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
609     nb_interp_n[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_NOTRANSPOSE, P, Narray[in], Q);
610     nb_interp_t[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_TRANSPOSE, P, Narray[in], Q);
611     nb_grad_n[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_NOTRANSPOSE, P, Narray[in], Q);
612     nb_grad_t[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_TRANSPOSE, P, Narray[in], Q);
613   }
614 
615   // compile
616   Ceed delegate;
617   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
618   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
619     CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module[in], 7, "DIM", dim, "P", P, "Q", Q, "NB_INTERP_N", nb_interp_n[in],
620                                      "NB_INTERP_T", nb_interp_t[in], "NB_GRAD_N", nb_grad_n[in], "NB_GRAD_T", nb_grad_t[in]));
621   }
622 
623   // get kernels
624   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
625     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_n", &impl->magma_interp_nontensor[in]));
626     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_t", &impl->magma_interp_tr_nontensor[in]));
627     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_n", &impl->magma_grad_nontensor[in]));
628     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_t", &impl->magma_grad_tr_nontensor[in]));
629   }
630 
631   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
632   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
633 
634   // Copy qref to the GPU
635   CeedCallBackend(magma_malloc((void **)&impl->dqref, nqpts * sizeof(qref[0])));
636   magma_setvector(nqpts, sizeof(qref[0]), qref, 1, impl->dqref, 1, data->queue);
637 
638   // Copy interp to the GPU
639   CeedCallBackend(magma_malloc((void **)&impl->dinterp, nqpts * ndof * sizeof(interp[0])));
640   magma_setvector(nqpts * ndof, sizeof(interp[0]), interp, 1, impl->dinterp, 1, data->queue);
641 
642   // Copy grad to the GPU
643   CeedCallBackend(magma_malloc((void **)&impl->dgrad, nqpts * ndof * dim * sizeof(grad[0])));
644   magma_setvector(nqpts * ndof * dim, sizeof(grad[0]), grad, 1, impl->dgrad, 1, data->queue);
645 
646   // Copy qweight to the GPU
647   CeedCallBackend(magma_malloc((void **)&impl->dqweight, nqpts * sizeof(qweight[0])));
648   magma_setvector(nqpts, sizeof(qweight[0]), qweight, 1, impl->dqweight, 1, data->queue);
649 
650   CeedCallBackend(CeedBasisSetData(basis, impl));
651   CeedCallBackend(CeedFree(&magma_common_path));
652   CeedCallBackend(CeedFree(&interp_path));
653   CeedCallBackend(CeedFree(&grad_path));
654   CeedCallBackend(CeedFree(&basis_kernel_source));
655   return CEED_ERROR_SUCCESS;
656 }
657