xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-basis.c (revision 53f7acb178914a16137d9c91c84843f149b8f9af)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #ifdef CEED_MAGMA_USE_HIP
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #else
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #endif
20 #include "ceed-magma-common.h"
21 #include "ceed-magma.h"
22 
23 #ifdef __cplusplus
24 CEED_INTERN "C"
25 #endif
26     int
27     CeedBasisApply_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
28   Ceed ceed;
29   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
30   CeedInt dim, ncomp, ndof;
31   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
32   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
33   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
34 
35   Ceed_Magma *data;
36   CeedCallBackend(CeedGetData(ceed, &data));
37 
38   const CeedScalar *du;
39   CeedScalar       *dv;
40   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
41   else CeedCheck(emode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
42   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
43 
44   CeedBasis_Magma *impl;
45   CeedCallBackend(CeedBasisGetData(basis, &impl));
46 
47   CeedInt P1d, Q1d;
48   CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P1d));
49   CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q1d));
50 
51   CeedDebug256(ceed, 4, "[CeedBasisApply_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * CeedIntPow(P1d, dim), ncomp);
52 
53   if (tmode == CEED_TRANSPOSE) {
54     CeedSize length;
55     CeedCallBackend(CeedVectorGetLength(V, &length));
56     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
57       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
58     } else {
59       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
60     }
61     ceed_magma_queue_sync(data->queue);
62   }
63 
64   switch (emode) {
65     case CEED_EVAL_INTERP: {
66       CeedInt P = P1d, Q = Q1d;
67       if (tmode == CEED_TRANSPOSE) {
68         P = Q1d;
69         Q = P1d;
70       }
71 
72       // Define element sizes for dofs/quad
73       CeedInt elquadsize = CeedIntPow(Q1d, dim);
74       CeedInt eldofssize = CeedIntPow(P1d, dim);
75 
76       // E-vector ordering -------------- Q-vector ordering
77       //  component                        component
78       //    elem                             elem
79       //       node                            node
80 
81       // ---  Define strides for NOTRANSPOSE mode: ---
82       // Input (du) is E-vector, output (dv) is Q-vector
83 
84       // Element strides
85       CeedInt u_elstride = eldofssize;
86       CeedInt v_elstride = elquadsize;
87       // Component strides
88       CeedInt u_compstride = nelem * eldofssize;
89       CeedInt v_compstride = nelem * elquadsize;
90 
91       // ---  Swap strides for TRANSPOSE mode: ---
92       if (tmode == CEED_TRANSPOSE) {
93         // Input (du) is Q-vector, output (dv) is E-vector
94         // Element strides
95         v_elstride = eldofssize;
96         u_elstride = elquadsize;
97         // Component strides
98         v_compstride = nelem * eldofssize;
99         u_compstride = nelem * elquadsize;
100       }
101 
102       CeedInt nthreads = 1;
103       CeedInt ntcol    = 1;
104       CeedInt shmem    = 0;
105       CeedInt maxPQ    = CeedIntMax(P, Q);
106 
107       switch (dim) {
108         case 1:
109           nthreads = maxPQ;
110           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
111           shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
112           shmem += sizeof(CeedScalar) * (P * Q);
113           break;
114         case 2:
115           nthreads = maxPQ;
116           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
117           shmem += P * Q * sizeof(CeedScalar);                // for sT
118           shmem += ntcol * (P * maxPQ * sizeof(CeedScalar));  // for reforming rU we need PxP, and for the intermediate output we need PxQ
119           break;
120         case 3:
121           nthreads = maxPQ * maxPQ;
122           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
123           shmem += sizeof(CeedScalar) * (P * Q);  // for sT
124           shmem += sizeof(CeedScalar) * ntcol *
125                    (CeedIntMax(P * P * maxPQ,
126                                P * Q * Q));  // rU needs P^2xP, the intermediate output needs max(P^2xQ,PQ^2)
127       }
128       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
129       void   *args[] = {&impl->dinterp1d, &du, &u_elstride, &u_compstride, &dv, &v_elstride, &v_compstride, &nelem};
130 
131       if (tmode == CEED_TRANSPOSE) {
132         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp_tr, grid, nthreads, ntcol, 1, shmem, args));
133       } else {
134         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp, grid, nthreads, ntcol, 1, shmem, args));
135       }
136     } break;
137     case CEED_EVAL_GRAD: {
138       CeedInt P = P1d, Q = Q1d;
139       // In CEED_NOTRANSPOSE mode:
140       // du is (P^dim x nc), column-major layout (nc = ncomp)
141       // dv is (Q^dim x nc x dim), column-major layout (nc = ncomp)
142       // In CEED_TRANSPOSE mode, the sizes of du and dv are switched.
143       if (tmode == CEED_TRANSPOSE) {
144         P = Q1d, Q = P1d;
145       }
146 
147       // Define element sizes for dofs/quad
148       CeedInt elquadsize = CeedIntPow(Q1d, dim);
149       CeedInt eldofssize = CeedIntPow(P1d, dim);
150 
151       // E-vector ordering -------------- Q-vector ordering
152       //                                  dim
153       //  component                        component
154       //    elem                              elem
155       //       node                            node
156 
157       // ---  Define strides for NOTRANSPOSE mode: ---
158       // Input (du) is E-vector, output (dv) is Q-vector
159 
160       // Element strides
161       CeedInt u_elstride = eldofssize;
162       CeedInt v_elstride = elquadsize;
163       // Component strides
164       CeedInt u_compstride = nelem * eldofssize;
165       CeedInt v_compstride = nelem * elquadsize;
166       // Dimension strides
167       CeedInt u_dimstride = 0;
168       CeedInt v_dimstride = nelem * elquadsize * ncomp;
169 
170       // ---  Swap strides for TRANSPOSE mode: ---
171       if (tmode == CEED_TRANSPOSE) {
172         // Input (du) is Q-vector, output (dv) is E-vector
173         // Element strides
174         v_elstride = eldofssize;
175         u_elstride = elquadsize;
176         // Component strides
177         v_compstride = nelem * eldofssize;
178         u_compstride = nelem * elquadsize;
179         // Dimension strides
180         v_dimstride = 0;
181         u_dimstride = nelem * elquadsize * ncomp;
182       }
183 
184       CeedInt nthreads = 1;
185       CeedInt ntcol    = 1;
186       CeedInt shmem    = 0;
187       CeedInt maxPQ    = CeedIntMax(P, Q);
188 
189       switch (dim) {
190         case 1:
191           nthreads = maxPQ;
192           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
193           shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
194           shmem += sizeof(CeedScalar) * (P * Q);
195           break;
196         case 2:
197           nthreads = maxPQ;
198           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
199           shmem += sizeof(CeedScalar) * 2 * P * Q;            // for sTinterp and sTgrad
200           shmem += sizeof(CeedScalar) * ntcol * (P * maxPQ);  // for reforming rU we need PxP, and for the intermediate output we need PxQ
201           break;
202         case 3:
203           nthreads = maxPQ * maxPQ;
204           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
205           shmem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
206           shmem += sizeof(CeedScalar) * ntcol *
207                    CeedIntMax(P * P * P,
208                               (P * P * Q) + (P * Q * Q));  // rU needs P^2xP, the intermediate outputs need (P^2.Q + P.Q^2)
209       }
210       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
211       void   *args[] = {&impl->dinterp1d, &impl->dgrad1d, &du,          &u_elstride, &u_compstride, &u_dimstride, &dv,
212                         &v_elstride,      &v_compstride,  &v_dimstride, &nelem};
213 
214       if (tmode == CEED_TRANSPOSE) {
215         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad_tr, grid, nthreads, ntcol, 1, shmem, args));
216       } else {
217         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad, grid, nthreads, ntcol, 1, shmem, args));
218       }
219     } break;
220     case CEED_EVAL_WEIGHT: {
221       CeedCheck(tmode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
222       CeedInt Q          = Q1d;
223       CeedInt eldofssize = CeedIntPow(Q, dim);
224       CeedInt nthreads   = 1;
225       CeedInt ntcol      = 1;
226       CeedInt shmem      = 0;
227 
228       switch (dim) {
229         case 1:
230           nthreads = Q;
231           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
232           shmem += sizeof(CeedScalar) * Q;          // for dqweight1d
233           shmem += sizeof(CeedScalar) * ntcol * Q;  // for output
234           break;
235         case 2:
236           nthreads = Q;
237           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
238           shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
239           break;
240         case 3:
241           nthreads = Q * Q;
242           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
243           shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
244       }
245       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
246       void   *args[] = {&impl->dqweight1d, &dv, &eldofssize, &nelem};
247 
248       CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_weight, grid, nthreads, ntcol, 1, shmem, args));
249     } break;
250     // LCOV_EXCL_START
251     case CEED_EVAL_DIV:
252       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
253     case CEED_EVAL_CURL:
254       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
255     case CEED_EVAL_NONE:
256       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
257       // LCOV_EXCL_STOP
258   }
259 
260   // must sync to ensure completeness
261   ceed_magma_queue_sync(data->queue);
262 
263   if (emode != CEED_EVAL_WEIGHT) {
264     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
265   }
266   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
267   return CEED_ERROR_SUCCESS;
268 }
269 
270 #ifdef __cplusplus
271 CEED_INTERN "C"
272 #endif
273     int
274     CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
275   Ceed ceed;
276   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
277 
278   Ceed_Magma *data;
279   CeedCallBackend(CeedGetData(ceed, &data));
280 
281   magma_int_t arch = magma_getdevice_arch();
282 
283   CeedInt dim, ncomp, ndof, nqpt;
284   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
285   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
286   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
287   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &nqpt));
288   const CeedScalar *du;
289   CeedScalar       *dv;
290   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
291   else CeedCheck(emode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
292   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
293 
294   CeedBasisNonTensor_Magma *impl;
295   CeedCallBackend(CeedBasisGetData(basis, &impl));
296 
297   CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * ndof, ncomp);
298 
299   if (tmode == CEED_TRANSPOSE) {
300     CeedSize length;
301     CeedCallBackend(CeedVectorGetLength(V, &length));
302     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
303       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
304     } else {
305       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
306     }
307     ceed_magma_queue_sync(data->queue);
308   }
309 
310   CeedInt            P = ndof, Q = nqpt, N = nelem * ncomp;
311   CeedInt            NB = 1;
312   CeedMagmaFunction *interp, *grad;
313 
314   CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
315   CeedInt iN                                       = 0;
316   CeedInt diff                                     = abs(Narray[iN] - N);
317   for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
318     CeedInt idiff = abs(Narray[in] - N);
319     if (idiff < diff) {
320       iN   = in;
321       diff = idiff;
322     }
323   }
324 
325   NB     = nontensor_rtc_get_nb(arch, 'd', emode, tmode, P, Narray[iN], Q);
326   interp = (tmode == CEED_TRANSPOSE) ? &impl->magma_interp_tr_nontensor[iN] : &impl->magma_interp_nontensor[iN];
327   grad   = (tmode == CEED_TRANSPOSE) ? &impl->magma_grad_tr_nontensor[iN] : &impl->magma_grad_nontensor[iN];
328 
329   switch (emode) {
330     case CEED_EVAL_INTERP: {
331       CeedInt P = ndof, Q = nqpt;
332       if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
333         CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
334         CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
335         CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
336         CeedInt shmem = 0, shmemA = 0, shmemB = 0;
337         shmemB += ntcol * K * NB * sizeof(CeedScalar);
338         shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
339         shmem = (tmode == CEED_TRANSPOSE) ? (shmemA + shmemB) : CeedIntMax(shmemA, shmemB);
340 
341         CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
342         magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
343         magma_trans_t transB = MagmaNoTrans;
344         CeedScalar    alpha = 1.0, beta = 0.0;
345 
346         void *args[] = {&transA, &transB, &N, &alpha, &impl->dinterp, &P, &du, &K, &beta, &dv, &M};
347         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *interp, grid, M, ntcol, 1, shmem, args));
348       } else {
349         if (tmode == CEED_TRANSPOSE)
350           magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dinterp, P, du, Q, 0.0, dv, P, data->queue);
351         else magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dinterp, P, du, P, 0.0, dv, Q, data->queue);
352       }
353     } break;
354 
355     case CEED_EVAL_GRAD: {
356       CeedInt P = ndof, Q = nqpt;
357       if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
358         CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
359         CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
360         CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
361         CeedInt shmem = 0, shmemA = 0, shmemB = 0;
362         shmemB += ntcol * K * NB * sizeof(CeedScalar);
363         shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
364         shmem = shmemA + shmemB;
365 
366         CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
367         magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
368         magma_trans_t transB = MagmaNoTrans;
369 
370         void *args[] = {&transA, &transB, &N, &impl->dgrad, &P, &du, &K, &dv, &M};
371         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *grad, grid, M, ntcol, 1, shmem, args));
372       } else {
373         if (tmode == CEED_TRANSPOSE) {
374           CeedScalar beta = 0.0;
375           for (int d = 0; d < dim; d++) {
376             if (d > 0) beta = 1.0;
377             magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dgrad + d * P * Q, P, du + d * nelem * ncomp * Q, Q,
378                                  beta, dv, P, data->queue);
379           }
380         } else {
381           for (int d = 0; d < dim; d++)
382             magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dgrad + d * P * Q, P, du, P, 0.0,
383                                  dv + d * nelem * ncomp * Q, Q, data->queue);
384         }
385       }
386     } break;
387 
388     case CEED_EVAL_WEIGHT: {
389       CeedCheck(tmode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
390 
391       int elemsPerBlock = 1;  // basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
392       int grid          = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);
393       magma_weight_nontensor(grid, nqpt, nelem, nqpt, impl->dqweight, dv, data->queue);
394     } break;
395 
396     // LCOV_EXCL_START
397     case CEED_EVAL_DIV:
398       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
399     case CEED_EVAL_CURL:
400       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
401     case CEED_EVAL_NONE:
402       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
403       // LCOV_EXCL_STOP
404   }
405 
406   // must sync to ensure completeness
407   ceed_magma_queue_sync(data->queue);
408 
409   if (emode != CEED_EVAL_WEIGHT) {
410     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
411   }
412   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
413   return CEED_ERROR_SUCCESS;
414 }
415 
416 #ifdef __cplusplus
417 CEED_INTERN "C"
418 #endif
419     int
420     CeedBasisDestroy_Magma(CeedBasis basis) {
421   CeedBasis_Magma *impl;
422   CeedCallBackend(CeedBasisGetData(basis, &impl));
423 
424   CeedCallBackend(magma_free(impl->dqref1d));
425   CeedCallBackend(magma_free(impl->dinterp1d));
426   CeedCallBackend(magma_free(impl->dgrad1d));
427   CeedCallBackend(magma_free(impl->dqweight1d));
428   Ceed ceed;
429   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
430 #ifdef CEED_MAGMA_USE_HIP
431   CeedCallHip(ceed, hipModuleUnload(impl->module));
432 #else
433   CeedCallCuda(ceed, cuModuleUnload(impl->module));
434 #endif
435 
436   CeedCallBackend(CeedFree(&impl));
437 
438   return CEED_ERROR_SUCCESS;
439 }
440 
441 #ifdef __cplusplus
442 CEED_INTERN "C"
443 #endif
444     int
445     CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
446   CeedBasisNonTensor_Magma *impl;
447   CeedCallBackend(CeedBasisGetData(basis, &impl));
448 
449   CeedCallBackend(magma_free(impl->dqref));
450   CeedCallBackend(magma_free(impl->dinterp));
451   CeedCallBackend(magma_free(impl->dgrad));
452   CeedCallBackend(magma_free(impl->dqweight));
453   Ceed ceed;
454   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
455 #ifdef CEED_MAGMA_USE_HIP
456   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
457     CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
458   }
459 #else
460   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
461     CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
462   }
463 #endif
464   CeedCallBackend(CeedFree(&impl));
465 
466   return CEED_ERROR_SUCCESS;
467 }
468 
469 #ifdef __cplusplus
470 CEED_INTERN "C"
471 #endif
472     int
473     CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d, CeedInt Q1d, const CeedScalar *interp1d, const CeedScalar *grad1d,
474                                   const CeedScalar *qref1d, const CeedScalar *qweight1d, CeedBasis basis) {
475   CeedBasis_Magma *impl;
476   CeedCallBackend(CeedCalloc(1, &impl));
477   Ceed ceed;
478   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
479 
480   // Check for supported parameters
481   CeedInt ncomp = 0;
482   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
483   Ceed_Magma *data;
484   CeedCallBackend(CeedGetData(ceed, &data));
485 
486   // Compile kernels
487   char *magma_common_path;
488   char *interp_path, *grad_path, *weight_path;
489   char *basis_kernel_source;
490   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
491   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
492   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
493   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_tensor.h", &magma_common_path));
494   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
495   char   *interp_name_base = "ceed/jit-source/magma/interp";
496   CeedInt interp_name_len  = strlen(interp_name_base) + 6;
497   char    interp_name[interp_name_len];
498   snprintf(interp_name, interp_name_len, "%s-%" CeedInt_FMT "d.h", interp_name_base, dim);
499   CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_name, &interp_path));
500   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
501   char   *grad_name_base = "ceed/jit-source/magma/grad";
502   CeedInt grad_name_len  = strlen(grad_name_base) + 6;
503   char    grad_name[grad_name_len];
504   snprintf(grad_name, grad_name_len, "%s-%" CeedInt_FMT "d.h", grad_name_base, dim);
505   CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_name, &grad_path));
506   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
507   char   *weight_name_base = "ceed/jit-source/magma/weight";
508   CeedInt weight_name_len  = strlen(weight_name_base) + 6;
509   char    weight_name[weight_name_len];
510   snprintf(weight_name, weight_name_len, "%s-%" CeedInt_FMT "d.h", weight_name_base, dim);
511   CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_name, &weight_path));
512   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_path, &basis_kernel_source));
513   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
514   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip
515   // data
516   Ceed delegate;
517   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
518   CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module, 5, "DIM", dim, "NCOMP", ncomp, "P", P1d, "Q", Q1d, "MAXPQ",
519                                    CeedIntMax(P1d, Q1d)));
520 
521   // Kernel setup
522   switch (dim) {
523     case 1:
524       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->magma_interp));
525       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->magma_interp_tr));
526       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->magma_grad));
527       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->magma_grad_tr));
528       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->magma_weight));
529       break;
530     case 2:
531       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->magma_interp));
532       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->magma_interp_tr));
533       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->magma_grad));
534       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->magma_grad_tr));
535       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->magma_weight));
536       break;
537     case 3:
538       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->magma_interp));
539       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->magma_interp_tr));
540       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->magma_grad));
541       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->magma_grad_tr));
542       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->magma_weight));
543   }
544 
545   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
546   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
547 
548   // Copy qref1d to the GPU
549   CeedCallBackend(magma_malloc((void **)&impl->dqref1d, Q1d * sizeof(qref1d[0])));
550   magma_setvector(Q1d, sizeof(qref1d[0]), qref1d, 1, impl->dqref1d, 1, data->queue);
551 
552   // Copy interp1d to the GPU
553   CeedCallBackend(magma_malloc((void **)&impl->dinterp1d, Q1d * P1d * sizeof(interp1d[0])));
554   magma_setvector(Q1d * P1d, sizeof(interp1d[0]), interp1d, 1, impl->dinterp1d, 1, data->queue);
555 
556   // Copy grad1d to the GPU
557   CeedCallBackend(magma_malloc((void **)&impl->dgrad1d, Q1d * P1d * sizeof(grad1d[0])));
558   magma_setvector(Q1d * P1d, sizeof(grad1d[0]), grad1d, 1, impl->dgrad1d, 1, data->queue);
559 
560   // Copy qweight1d to the GPU
561   CeedCallBackend(magma_malloc((void **)&impl->dqweight1d, Q1d * sizeof(qweight1d[0])));
562   magma_setvector(Q1d, sizeof(qweight1d[0]), qweight1d, 1, impl->dqweight1d, 1, data->queue);
563 
564   CeedCallBackend(CeedBasisSetData(basis, impl));
565   CeedCallBackend(CeedFree(&magma_common_path));
566   CeedCallBackend(CeedFree(&interp_path));
567   CeedCallBackend(CeedFree(&grad_path));
568   CeedCallBackend(CeedFree(&weight_path));
569   CeedCallBackend(CeedFree(&basis_kernel_source));
570 
571   return CEED_ERROR_SUCCESS;
572 }
573 
574 #ifdef __cplusplus
575 CEED_INTERN "C"
576 #endif
577     int
578     CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt ndof, CeedInt nqpts, const CeedScalar *interp, const CeedScalar *grad,
579                             const CeedScalar *qref, const CeedScalar *qweight, CeedBasis basis) {
580   CeedBasisNonTensor_Magma *impl;
581   Ceed                      ceed;
582   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
583 
584   Ceed_Magma *data;
585   CeedCallBackend(CeedGetData(ceed, &data));
586   magma_int_t arch = magma_getdevice_arch();
587   CeedCallBackend(CeedCalloc(1, &impl));
588   // Compile kernels
589   char *magma_common_path;
590   char *interp_path, *grad_path;
591   char *basis_kernel_source;
592   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
593   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
594   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
595   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_nontensor.h", &magma_common_path));
596   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
597   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/interp-nontensor.h", &interp_path));
598   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
599   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/grad-nontensor.h", &grad_path));
600   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
601 
602   // tuning parameters for nb
603   CeedInt nb_interp_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
604   CeedInt nb_interp_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
605   CeedInt nb_grad_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
606   CeedInt nb_grad_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
607   CeedInt P = ndof, Q = nqpts;
608   CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
609   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
610     nb_interp_n[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_NOTRANSPOSE, P, Narray[in], Q);
611     nb_interp_t[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_TRANSPOSE, P, Narray[in], Q);
612     nb_grad_n[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_NOTRANSPOSE, P, Narray[in], Q);
613     nb_grad_t[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_TRANSPOSE, P, Narray[in], Q);
614   }
615 
616   // compile
617   Ceed delegate;
618   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
619   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
620     CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module[in], 7, "DIM", dim, "P", P, "Q", Q, "NB_INTERP_N", nb_interp_n[in],
621                                      "NB_INTERP_T", nb_interp_t[in], "NB_GRAD_N", nb_grad_n[in], "NB_GRAD_T", nb_grad_t[in]));
622   }
623 
624   // get kernels
625   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
626     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_n", &impl->magma_interp_nontensor[in]));
627     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_t", &impl->magma_interp_tr_nontensor[in]));
628     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_n", &impl->magma_grad_nontensor[in]));
629     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_t", &impl->magma_grad_tr_nontensor[in]));
630   }
631 
632   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
633   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
634 
635   // Copy qref to the GPU
636   CeedCallBackend(magma_malloc((void **)&impl->dqref, nqpts * sizeof(qref[0])));
637   magma_setvector(nqpts, sizeof(qref[0]), qref, 1, impl->dqref, 1, data->queue);
638 
639   // Copy interp to the GPU
640   CeedCallBackend(magma_malloc((void **)&impl->dinterp, nqpts * ndof * sizeof(interp[0])));
641   magma_setvector(nqpts * ndof, sizeof(interp[0]), interp, 1, impl->dinterp, 1, data->queue);
642 
643   // Copy grad to the GPU
644   CeedCallBackend(magma_malloc((void **)&impl->dgrad, nqpts * ndof * dim * sizeof(grad[0])));
645   magma_setvector(nqpts * ndof * dim, sizeof(grad[0]), grad, 1, impl->dgrad, 1, data->queue);
646 
647   // Copy qweight to the GPU
648   CeedCallBackend(magma_malloc((void **)&impl->dqweight, nqpts * sizeof(qweight[0])));
649   magma_setvector(nqpts, sizeof(qweight[0]), qweight, 1, impl->dqweight, 1, data->queue);
650 
651   CeedCallBackend(CeedBasisSetData(basis, impl));
652   CeedCallBackend(CeedFree(&magma_common_path));
653   CeedCallBackend(CeedFree(&interp_path));
654   CeedCallBackend(CeedFree(&grad_path));
655   CeedCallBackend(CeedFree(&basis_kernel_source));
656   return CEED_ERROR_SUCCESS;
657 }
658