xref: /libCEED/backends/magma/ceed-magma-basis.c (revision 5fb68f377259d3910de46d787b7c5d1587fd01e1)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #include "ceed-magma.h"
14 #ifdef CEED_MAGMA_USE_HIP
15 #include "../hip/ceed-hip-common.h"
16 #include "../hip/ceed-hip-compile.h"
17 #else
18 #include "../cuda/ceed-cuda-common.h"
19 #include "../cuda/ceed-cuda-compile.h"
20 #endif
21 
22 #ifdef __cplusplus
23 CEED_INTERN "C"
24 #endif
25     int
26     CeedBasisApply_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
27   Ceed ceed;
28   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
29   CeedInt dim, ncomp, ndof;
30   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
31   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
32   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
33 
34   Ceed_Magma *data;
35   CeedCallBackend(CeedGetData(ceed, &data));
36 
37   const CeedScalar *u;
38   CeedScalar       *v;
39   if (emode != CEED_EVAL_WEIGHT) {
40     CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &u));
41   } else if (emode != CEED_EVAL_WEIGHT) {
42     // LCOV_EXCL_START
43     return CeedError(ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
44     // LCOV_EXCL_STOP
45   }
46   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &v));
47 
48   CeedBasis_Magma *impl;
49   CeedCallBackend(CeedBasisGetData(basis, &impl));
50 
51   CeedInt P1d, Q1d;
52   CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P1d));
53   CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q1d));
54 
55   CeedDebug256(ceed, 4, "[CeedBasisApply_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * CeedIntPow(P1d, dim), ncomp);
56 
57   if (tmode == CEED_TRANSPOSE) {
58     CeedSize length;
59     CeedCallBackend(CeedVectorGetLength(V, &length));
60     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
61       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)v, length, data->queue);
62     } else {
63       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)v, length, data->queue);
64     }
65     ceed_magma_queue_sync(data->queue);
66   }
67 
68   switch (emode) {
69     case CEED_EVAL_INTERP: {
70       CeedInt P = P1d, Q = Q1d;
71       if (tmode == CEED_TRANSPOSE) {
72         P = Q1d;
73         Q = P1d;
74       }
75 
76       // Define element sizes for dofs/quad
77       CeedInt elquadsize = CeedIntPow(Q1d, dim);
78       CeedInt eldofssize = CeedIntPow(P1d, dim);
79 
80       // E-vector ordering -------------- Q-vector ordering
81       //  component                        component
82       //    elem                             elem
83       //       node                            node
84 
85       // ---  Define strides for NOTRANSPOSE mode: ---
86       // Input (u) is E-vector, output (v) is Q-vector
87 
88       // Element strides
89       CeedInt u_elstride = eldofssize;
90       CeedInt v_elstride = elquadsize;
91       // Component strides
92       CeedInt u_compstride = nelem * eldofssize;
93       CeedInt v_compstride = nelem * elquadsize;
94 
95       // ---  Swap strides for TRANSPOSE mode: ---
96       if (tmode == CEED_TRANSPOSE) {
97         // Input (u) is Q-vector, output (v) is E-vector
98         // Element strides
99         v_elstride = eldofssize;
100         u_elstride = elquadsize;
101         // Component strides
102         v_compstride = nelem * eldofssize;
103         u_compstride = nelem * elquadsize;
104       }
105 
106       CeedInt nthreads = 1;
107       CeedInt ntcol    = 1;
108       CeedInt shmem    = 0;
109       CeedInt maxPQ    = CeedIntMax(P, Q);
110 
111       switch (dim) {
112         case 1:
113           nthreads = maxPQ;
114           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
115           shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
116           shmem += sizeof(CeedScalar) * (P * Q);
117           break;
118         case 2:
119           nthreads = maxPQ;
120           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
121           shmem += P * Q * sizeof(CeedScalar);                // for sT
122           shmem += ntcol * (P * maxPQ * sizeof(CeedScalar));  // for reforming rU we need PxP, and for the intermediate output we need PxQ
123           break;
124         case 3:
125           nthreads = maxPQ * maxPQ;
126           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
127           shmem += sizeof(CeedScalar) * (P * Q);  // for sT
128           shmem += sizeof(CeedScalar) * ntcol *
129                    (CeedIntMax(P * P * maxPQ,
130                                P * Q * Q));  // rU needs P^2xP, the intermediate output needs max(P^2xQ,PQ^2)
131       }
132       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
133       void   *args[] = {&impl->dinterp1d, &u, &u_elstride, &u_compstride, &v, &v_elstride, &v_compstride, &nelem};
134 
135       if (tmode == CEED_TRANSPOSE) {
136         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp_tr, grid, nthreads, ntcol, 1, shmem, args));
137       } else {
138         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp, grid, nthreads, ntcol, 1, shmem, args));
139       }
140     } break;
141     case CEED_EVAL_GRAD: {
142       CeedInt P = P1d, Q = Q1d;
143       // In CEED_NOTRANSPOSE mode:
144       // u is (P^dim x nc), column-major layout (nc = ncomp)
145       // v is (Q^dim x nc x dim), column-major layout (nc = ncomp)
146       // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
147       if (tmode == CEED_TRANSPOSE) {
148         P = Q1d, Q = P1d;
149       }
150 
151       // Define element sizes for dofs/quad
152       CeedInt elquadsize = CeedIntPow(Q1d, dim);
153       CeedInt eldofssize = CeedIntPow(P1d, dim);
154 
155       // E-vector ordering -------------- Q-vector ordering
156       //                                  dim
157       //  component                        component
158       //    elem                              elem
159       //       node                            node
160 
161       // ---  Define strides for NOTRANSPOSE mode: ---
162       // Input (u) is E-vector, output (v) is Q-vector
163 
164       // Element strides
165       CeedInt u_elstride = eldofssize;
166       CeedInt v_elstride = elquadsize;
167       // Component strides
168       CeedInt u_compstride = nelem * eldofssize;
169       CeedInt v_compstride = nelem * elquadsize;
170       // Dimension strides
171       CeedInt u_dimstride = 0;
172       CeedInt v_dimstride = nelem * elquadsize * ncomp;
173 
174       // ---  Swap strides for TRANSPOSE mode: ---
175       if (tmode == CEED_TRANSPOSE) {
176         // Input (u) is Q-vector, output (v) is E-vector
177         // Element strides
178         v_elstride = eldofssize;
179         u_elstride = elquadsize;
180         // Component strides
181         v_compstride = nelem * eldofssize;
182         u_compstride = nelem * elquadsize;
183         // Dimension strides
184         v_dimstride = 0;
185         u_dimstride = nelem * elquadsize * ncomp;
186       }
187 
188       CeedInt nthreads = 1;
189       CeedInt ntcol    = 1;
190       CeedInt shmem    = 0;
191       CeedInt maxPQ    = CeedIntMax(P, Q);
192 
193       switch (dim) {
194         case 1:
195           nthreads = maxPQ;
196           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
197           shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
198           shmem += sizeof(CeedScalar) * (P * Q);
199           break;
200         case 2:
201           nthreads = maxPQ;
202           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
203           shmem += sizeof(CeedScalar) * 2 * P * Q;            // for sTinterp and sTgrad
204           shmem += sizeof(CeedScalar) * ntcol * (P * maxPQ);  // for reforming rU we need PxP, and for the intermediate output we need PxQ
205           break;
206         case 3:
207           nthreads = maxPQ * maxPQ;
208           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
209           shmem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
210           shmem += sizeof(CeedScalar) * ntcol *
211                    CeedIntMax(P * P * P,
212                               (P * P * Q) + (P * Q * Q));  // rU needs P^2xP, the intermediate outputs need (P^2.Q + P.Q^2)
213       }
214       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
215       void   *args[] = {&impl->dinterp1d, &impl->dgrad1d, &u,           &u_elstride, &u_compstride, &u_dimstride, &v,
216                         &v_elstride,      &v_compstride,  &v_dimstride, &nelem};
217 
218       if (tmode == CEED_TRANSPOSE) {
219         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad_tr, grid, nthreads, ntcol, 1, shmem, args));
220       } else {
221         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad, grid, nthreads, ntcol, 1, shmem, args));
222       }
223     } break;
224     case CEED_EVAL_WEIGHT: {
225       if (tmode == CEED_TRANSPOSE)
226         // LCOV_EXCL_START
227         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
228       // LCOV_EXCL_STOP
229       CeedInt Q          = Q1d;
230       CeedInt eldofssize = CeedIntPow(Q, dim);
231       CeedInt nthreads   = 1;
232       CeedInt ntcol      = 1;
233       CeedInt shmem      = 0;
234 
235       switch (dim) {
236         case 1:
237           nthreads = Q;
238           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
239           shmem += sizeof(CeedScalar) * Q;          // for dqweight1d
240           shmem += sizeof(CeedScalar) * ntcol * Q;  // for output
241           break;
242         case 2:
243           nthreads = Q;
244           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
245           shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
246           break;
247         case 3:
248           nthreads = Q * Q;
249           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
250           shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
251       }
252       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
253       void   *args[] = {&impl->dqweight1d, &v, &eldofssize, &nelem};
254 
255       CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_weight, grid, nthreads, ntcol, 1, shmem, args));
256     } break;
257     // LCOV_EXCL_START
258     case CEED_EVAL_DIV:
259       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
260     case CEED_EVAL_CURL:
261       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
262     case CEED_EVAL_NONE:
263       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
264       // LCOV_EXCL_STOP
265   }
266 
267   // must sync to ensure completeness
268   ceed_magma_queue_sync(data->queue);
269 
270   if (emode != CEED_EVAL_WEIGHT) {
271     CeedCallBackend(CeedVectorRestoreArrayRead(U, &u));
272   }
273   CeedCallBackend(CeedVectorRestoreArray(V, &v));
274   return CEED_ERROR_SUCCESS;
275 }
276 
277 #ifdef __cplusplus
278 CEED_INTERN "C"
279 #endif
280     int
281     CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
282   Ceed ceed;
283   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
284 
285   Ceed_Magma *data;
286   CeedCallBackend(CeedGetData(ceed, &data));
287 
288   magma_int_t arch = magma_getdevice_arch();
289 
290   CeedInt dim, ncomp, ndof, nqpt;
291   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
292   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
293   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
294   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &nqpt));
295   const CeedScalar *du;
296   CeedScalar       *dv;
297   if (emode != CEED_EVAL_WEIGHT) {
298     CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
299   } else if (emode != CEED_EVAL_WEIGHT) {
300     // LCOV_EXCL_START
301     return CeedError(ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
302     // LCOV_EXCL_STOP
303   }
304   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
305 
306   CeedBasisNonTensor_Magma *impl;
307   CeedCallBackend(CeedBasisGetData(basis, &impl));
308 
309   CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * ndof, ncomp);
310 
311   if (tmode == CEED_TRANSPOSE) {
312     CeedSize length;
313     CeedCallBackend(CeedVectorGetLength(V, &length));
314     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
315       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
316     } else {
317       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
318     }
319     ceed_magma_queue_sync(data->queue);
320   }
321 
322   CeedInt            P = ndof, Q = nqpt, N = nelem * ncomp;
323   CeedInt            NB = 1;
324   CeedMagmaFunction *interp, *grad;
325 
326   CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
327   CeedInt iN                                       = 0;
328   CeedInt diff                                     = abs(Narray[iN] - N);
329   for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
330     CeedInt idiff = abs(Narray[in] - N);
331     if (idiff < diff) {
332       iN   = in;
333       diff = idiff;
334     }
335   }
336 
337   NB     = nontensor_rtc_get_nb(arch, 'd', emode, tmode, P, Narray[iN], Q);
338   interp = (tmode == CEED_TRANSPOSE) ? &impl->magma_interp_tr_nontensor[iN] : &impl->magma_interp_nontensor[iN];
339   grad   = (tmode == CEED_TRANSPOSE) ? &impl->magma_grad_tr_nontensor[iN] : &impl->magma_grad_nontensor[iN];
340 
341   switch (emode) {
342     case CEED_EVAL_INTERP: {
343       CeedInt P = ndof, Q = nqpt;
344       if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
345         CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
346         CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
347         CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
348         CeedInt shmem = 0, shmemA = 0, shmemB = 0;
349         shmemB += ntcol * K * NB * sizeof(CeedScalar);
350         shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
351         shmem = (tmode == CEED_TRANSPOSE) ? (shmemA + shmemB) : CeedIntMax(shmemA, shmemB);
352 
353         CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
354         magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
355         magma_trans_t transB = MagmaNoTrans;
356         CeedScalar    alpha = 1.0, beta = 0.0;
357 
358         void *args[] = {&transA, &transB, &N, &alpha, &impl->dinterp, &P, &du, &K, &beta, &dv, &M};
359         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *interp, grid, M, ntcol, 1, shmem, args));
360       } else {
361         if (tmode == CEED_TRANSPOSE)
362           magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dinterp, P, du, Q, 0.0, dv, P, data->queue);
363         else magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dinterp, P, du, P, 0.0, dv, Q, data->queue);
364       }
365     } break;
366 
367     case CEED_EVAL_GRAD: {
368       CeedInt P = ndof, Q = nqpt;
369       if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
370         CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
371         CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
372         CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
373         CeedInt shmem = 0, shmemA = 0, shmemB = 0;
374         shmemB += ntcol * K * NB * sizeof(CeedScalar);
375         shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
376         shmem = shmemA + shmemB;
377 
378         CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
379         magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
380         magma_trans_t transB = MagmaNoTrans;
381 
382         void *args[] = {&transA, &transB, &N, &impl->dgrad, &P, &du, &K, &dv, &M};
383         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *grad, grid, M, ntcol, 1, shmem, args));
384       } else {
385         if (tmode == CEED_TRANSPOSE) {
386           CeedScalar beta = 0.0;
387           for (int d = 0; d < dim; d++) {
388             if (d > 0) beta = 1.0;
389             magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dgrad + d * P * Q, P, du + d * nelem * ncomp * Q, Q,
390                                  beta, dv, P, data->queue);
391           }
392         } else {
393           for (int d = 0; d < dim; d++)
394             magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dgrad + d * P * Q, P, du, P, 0.0,
395                                  dv + d * nelem * ncomp * Q, Q, data->queue);
396         }
397       }
398     } break;
399 
400     case CEED_EVAL_WEIGHT: {
401       if (tmode == CEED_TRANSPOSE)
402         // LCOV_EXCL_START
403         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
404       // LCOV_EXCL_STOP
405 
406       int elemsPerBlock = 1;  // basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
407       int grid          = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);
408       magma_weight_nontensor(grid, nqpt, nelem, nqpt, impl->dqweight, dv, data->queue);
409     } break;
410 
411     // LCOV_EXCL_START
412     case CEED_EVAL_DIV:
413       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
414     case CEED_EVAL_CURL:
415       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
416     case CEED_EVAL_NONE:
417       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
418       // LCOV_EXCL_STOP
419   }
420 
421   // must sync to ensure completeness
422   ceed_magma_queue_sync(data->queue);
423 
424   if (emode != CEED_EVAL_WEIGHT) {
425     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
426   }
427   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
428   return CEED_ERROR_SUCCESS;
429 }
430 
431 #ifdef __cplusplus
432 CEED_INTERN "C"
433 #endif
434     int
435     CeedBasisDestroy_Magma(CeedBasis basis) {
436   CeedBasis_Magma *impl;
437   CeedCallBackend(CeedBasisGetData(basis, &impl));
438 
439   CeedCallBackend(magma_free(impl->dqref1d));
440   CeedCallBackend(magma_free(impl->dinterp1d));
441   CeedCallBackend(magma_free(impl->dgrad1d));
442   CeedCallBackend(magma_free(impl->dqweight1d));
443   Ceed ceed;
444   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
445 #ifdef CEED_MAGMA_USE_HIP
446   CeedCallHip(ceed, hipModuleUnload(impl->module));
447 #else
448   CeedCallCuda(ceed, cuModuleUnload(impl->module));
449 #endif
450 
451   CeedCallBackend(CeedFree(&impl));
452 
453   return CEED_ERROR_SUCCESS;
454 }
455 
456 #ifdef __cplusplus
457 CEED_INTERN "C"
458 #endif
459     int
460     CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
461   CeedBasisNonTensor_Magma *impl;
462   CeedCallBackend(CeedBasisGetData(basis, &impl));
463 
464   CeedCallBackend(magma_free(impl->dqref));
465   CeedCallBackend(magma_free(impl->dinterp));
466   CeedCallBackend(magma_free(impl->dgrad));
467   CeedCallBackend(magma_free(impl->dqweight));
468   Ceed ceed;
469   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
470 #ifdef CEED_MAGMA_USE_HIP
471   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
472     CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
473   }
474 #else
475   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
476     CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
477   }
478 #endif
479   CeedCallBackend(CeedFree(&impl));
480 
481   return CEED_ERROR_SUCCESS;
482 }
483 
484 #ifdef __cplusplus
485 CEED_INTERN "C"
486 #endif
487     int
488     CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d, CeedInt Q1d, const CeedScalar *interp1d, const CeedScalar *grad1d,
489                                   const CeedScalar *qref1d, const CeedScalar *qweight1d, CeedBasis basis) {
490   CeedBasis_Magma *impl;
491   CeedCallBackend(CeedCalloc(1, &impl));
492   Ceed ceed;
493   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
494 
495   // Check for supported parameters
496   CeedInt ncomp = 0;
497   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
498   Ceed_Magma *data;
499   CeedCallBackend(CeedGetData(ceed, &data));
500 
501   // Compile kernels
502   char *magma_common_path;
503   char *interp_path, *grad_path, *weight_path;
504   char *basis_kernel_source;
505   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
506   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source -----\n");
507   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
508   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_tensor.h", &magma_common_path));
509   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
510   char   *interp_name_base = "ceed/jit-source/magma/interp";
511   CeedInt interp_name_len  = strlen(interp_name_base) + 6;
512   char    interp_name[interp_name_len];
513   snprintf(interp_name, interp_name_len, "%s-%" CeedInt_FMT "d.h", interp_name_base, dim);
514   CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_name, &interp_path));
515   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
516   char   *grad_name_base = "ceed/jit-source/magma/grad";
517   CeedInt grad_name_len  = strlen(grad_name_base) + 6;
518   char    grad_name[grad_name_len];
519   snprintf(grad_name, grad_name_len, "%s-%" CeedInt_FMT "d.h", grad_name_base, dim);
520   CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_name, &grad_path));
521   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
522   char   *weight_name_base = "ceed/jit-source/magma/weight";
523   CeedInt weight_name_len  = strlen(weight_name_base) + 6;
524   char    weight_name[weight_name_len];
525   snprintf(weight_name, weight_name_len, "%s-%" CeedInt_FMT "d.h", weight_name_base, dim);
526   CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_name, &weight_path));
527   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_path, &basis_kernel_source));
528   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source Complete! -----\n");
529   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip
530   // data
531   Ceed delegate;
532   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
533   CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module, 5, "DIM", dim, "NCOMP", ncomp, "P", P1d, "Q", Q1d, "MAXPQ",
534                                    CeedIntMax(P1d, Q1d)));
535 
536   // Kernel setup
537   switch (dim) {
538     case 1:
539       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->magma_interp));
540       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->magma_interp_tr));
541       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->magma_grad));
542       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->magma_grad_tr));
543       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->magma_weight));
544       break;
545     case 2:
546       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->magma_interp));
547       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->magma_interp_tr));
548       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->magma_grad));
549       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->magma_grad_tr));
550       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->magma_weight));
551       break;
552     case 3:
553       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->magma_interp));
554       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->magma_interp_tr));
555       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->magma_grad));
556       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->magma_grad_tr));
557       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->magma_weight));
558   }
559 
560   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
561   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
562 
563   // Copy qref1d to the GPU
564   CeedCallBackend(magma_malloc((void **)&impl->dqref1d, Q1d * sizeof(qref1d[0])));
565   magma_setvector(Q1d, sizeof(qref1d[0]), qref1d, 1, impl->dqref1d, 1, data->queue);
566 
567   // Copy interp1d to the GPU
568   CeedCallBackend(magma_malloc((void **)&impl->dinterp1d, Q1d * P1d * sizeof(interp1d[0])));
569   magma_setvector(Q1d * P1d, sizeof(interp1d[0]), interp1d, 1, impl->dinterp1d, 1, data->queue);
570 
571   // Copy grad1d to the GPU
572   CeedCallBackend(magma_malloc((void **)&impl->dgrad1d, Q1d * P1d * sizeof(grad1d[0])));
573   magma_setvector(Q1d * P1d, sizeof(grad1d[0]), grad1d, 1, impl->dgrad1d, 1, data->queue);
574 
575   // Copy qweight1d to the GPU
576   CeedCallBackend(magma_malloc((void **)&impl->dqweight1d, Q1d * sizeof(qweight1d[0])));
577   magma_setvector(Q1d, sizeof(qweight1d[0]), qweight1d, 1, impl->dqweight1d, 1, data->queue);
578 
579   CeedCallBackend(CeedBasisSetData(basis, impl));
580   CeedCallBackend(CeedFree(&magma_common_path));
581   CeedCallBackend(CeedFree(&interp_path));
582   CeedCallBackend(CeedFree(&grad_path));
583   CeedCallBackend(CeedFree(&weight_path));
584   CeedCallBackend(CeedFree(&basis_kernel_source));
585 
586   return CEED_ERROR_SUCCESS;
587 }
588 
589 #ifdef __cplusplus
590 CEED_INTERN "C"
591 #endif
592     int
593     CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt ndof, CeedInt nqpts, const CeedScalar *interp, const CeedScalar *grad,
594                             const CeedScalar *qref, const CeedScalar *qweight, CeedBasis basis) {
595   CeedBasisNonTensor_Magma *impl;
596   Ceed                      ceed;
597   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
598 
599   Ceed_Magma *data;
600   CeedCallBackend(CeedGetData(ceed, &data));
601   magma_int_t arch = magma_getdevice_arch();
602   CeedCallBackend(CeedCalloc(1, &impl));
603   // Compile kernels
604   char *magma_common_path;
605   char *interp_path, *grad_path;
606   char *basis_kernel_source;
607   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
608   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source -----\n");
609   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
610   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_nontensor.h", &magma_common_path));
611   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
612   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/interp-nontensor.h", &interp_path));
613   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
614   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/grad-nontensor.h", &grad_path));
615   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
616 
617   // tuning parameters for nb
618   CeedInt nb_interp_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
619   CeedInt nb_interp_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
620   CeedInt nb_grad_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
621   CeedInt nb_grad_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
622   CeedInt P = ndof, Q = nqpts;
623   CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
624   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
625     nb_interp_n[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_NOTRANSPOSE, P, Narray[in], Q);
626     nb_interp_t[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_TRANSPOSE, P, Narray[in], Q);
627     nb_grad_n[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_NOTRANSPOSE, P, Narray[in], Q);
628     nb_grad_t[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_TRANSPOSE, P, Narray[in], Q);
629   }
630 
631   // compile
632   Ceed delegate;
633   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
634   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
635     CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module[in], 7, "DIM", dim, "P", P, "Q", Q, "NB_INTERP_N", nb_interp_n[in],
636                                      "NB_INTERP_T", nb_interp_t[in], "NB_GRAD_N", nb_grad_n[in], "NB_GRAD_T", nb_grad_t[in]));
637   }
638 
639   // get kernels
640   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
641     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_n", &impl->magma_interp_nontensor[in]));
642     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_t", &impl->magma_interp_tr_nontensor[in]));
643     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_n", &impl->magma_grad_nontensor[in]));
644     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_t", &impl->magma_grad_tr_nontensor[in]));
645   }
646 
647   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
648   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
649 
650   // Copy qref to the GPU
651   CeedCallBackend(magma_malloc((void **)&impl->dqref, nqpts * sizeof(qref[0])));
652   magma_setvector(nqpts, sizeof(qref[0]), qref, 1, impl->dqref, 1, data->queue);
653 
654   // Copy interp to the GPU
655   CeedCallBackend(magma_malloc((void **)&impl->dinterp, nqpts * ndof * sizeof(interp[0])));
656   magma_setvector(nqpts * ndof, sizeof(interp[0]), interp, 1, impl->dinterp, 1, data->queue);
657 
658   // Copy grad to the GPU
659   CeedCallBackend(magma_malloc((void **)&impl->dgrad, nqpts * ndof * dim * sizeof(grad[0])));
660   magma_setvector(nqpts * ndof * dim, sizeof(grad[0]), grad, 1, impl->dgrad, 1, data->queue);
661 
662   // Copy qweight to the GPU
663   CeedCallBackend(magma_malloc((void **)&impl->dqweight, nqpts * sizeof(qweight[0])));
664   magma_setvector(nqpts, sizeof(qweight[0]), qweight, 1, impl->dqweight, 1, data->queue);
665 
666   CeedCallBackend(CeedBasisSetData(basis, impl));
667   CeedCallBackend(CeedFree(&magma_common_path));
668   CeedCallBackend(CeedFree(&interp_path));
669   CeedCallBackend(CeedFree(&grad_path));
670   CeedCallBackend(CeedFree(&basis_kernel_source));
671   return CEED_ERROR_SUCCESS;
672 }
673