xref: /libCEED/backends/magma/ceed-magma-basis.c (revision bb229da952f7e9779ba6cb3cd1ca2ebeac5feb1f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #ifdef CEED_MAGMA_USE_HIP
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #else
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #endif
20 #include "ceed-magma-common.h"
21 #include "ceed-magma.h"
22 
#ifdef __cplusplus
CEED_INTERN "C"
#endif
    int
    CeedBasisApply_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
  // Apply a tensor-product H1 basis on the device via MAGMA JIT kernels.
  //   basis - basis to apply
  //   nelem - number of elements the kernels batch over
  //   tmode - CEED_NOTRANSPOSE (E-vector -> Q-vector) or CEED_TRANSPOSE (Q-vector -> E-vector, accumulating)
  //   emode - CEED_EVAL_INTERP / CEED_EVAL_GRAD / CEED_EVAL_WEIGHT (others unsupported)
  //   U     - input vector (may be CEED_VECTOR_NONE only for CEED_EVAL_WEIGHT)
  //   V     - output vector
  // Returns CEED_ERROR_SUCCESS or a backend error code.
  Ceed ceed;
  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedInt dim, ncomp, ndof;
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));  // NOTE(review): ndof is fetched but not used below

  Ceed_Magma *data;
  CeedCallBackend(CeedGetData(ceed, &data));

  // Get device pointers for input/output; CEED_EVAL_WEIGHT takes no input vector.
  const CeedScalar *du;
  CeedScalar       *dv;
  if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
  else CeedCheck(emode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));

  CeedBasis_Magma *impl;
  CeedCallBackend(CeedBasisGetData(basis, &impl));

  CeedInt P1d, Q1d;
  CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P1d));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q1d));

  CeedDebug256(ceed, 4, "[CeedBasisApply_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * CeedIntPow(P1d, dim), ncomp);

  // Transpose mode accumulates into V, so V must be zeroed first.
  // laset with alpha=beta=0 writes zeros over the whole length-by-1 "matrix".
  if (tmode == CEED_TRANSPOSE) {
    CeedSize length;
    CeedCallBackend(CeedVectorGetLength(V, &length));
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  switch (emode) {
    case CEED_EVAL_INTERP: {
      // P/Q are the per-kernel "input"/"output" 1D sizes; they swap under transpose.
      CeedInt P = P1d, Q = Q1d;
      if (tmode == CEED_TRANSPOSE) {
        P = Q1d;
        Q = P1d;
      }

      // Define element sizes for dofs/quad
      CeedInt elquadsize = CeedIntPow(Q1d, dim);
      CeedInt eldofssize = CeedIntPow(P1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //  component                        component
      //    elem                             elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (du) is E-vector, output (dv) is Q-vector

      // Element strides
      CeedInt u_elstride = eldofssize;
      CeedInt v_elstride = elquadsize;
      // Component strides
      CeedInt u_compstride = nelem * eldofssize;
      CeedInt v_compstride = nelem * elquadsize;

      // ---  Swap strides for TRANSPOSE mode: ---
      if (tmode == CEED_TRANSPOSE) {
        // Input (du) is Q-vector, output (dv) is E-vector
        // Element strides
        v_elstride = eldofssize;
        u_elstride = elquadsize;
        // Component strides
        v_compstride = nelem * eldofssize;
        u_compstride = nelem * elquadsize;
      }

      // Launch geometry: ntcol elements are processed per thread block (ntcol
      // thread-columns), and shmem is the dynamic shared memory the JIT kernel needs.
      CeedInt nthreads = 1;
      CeedInt ntcol    = 1;
      CeedInt shmem    = 0;
      CeedInt maxPQ    = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
          shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));  // per-element input and output staging
          shmem += sizeof(CeedScalar) * (P * Q);                            // shared copy of the interp matrix
          break;
        case 2:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
          shmem += P * Q * sizeof(CeedScalar);                // for sT
          shmem += ntcol * (P * maxPQ * sizeof(CeedScalar));  // for reforming rU we need PxP, and for the intermediate output we need PxQ
          break;
        case 3:
          nthreads = maxPQ * maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
          shmem += sizeof(CeedScalar) * (P * Q);  // for sT
          shmem += sizeof(CeedScalar) * ntcol *
                   (CeedIntMax(P * P * maxPQ,
                               P * Q * Q));  // rU needs P^2xP, the intermediate output needs max(P^2xQ,PQ^2)
      }
      CeedInt grid   = (nelem + ntcol - 1) / ntcol;  // ceil(nelem / ntcol) blocks
      void   *args[] = {&impl->dinterp1d, &du, &u_elstride, &u_compstride, &dv, &v_elstride, &v_compstride, &nelem};

      if (tmode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp_tr, grid, nthreads, ntcol, 1, shmem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp, grid, nthreads, ntcol, 1, shmem, args));
      }
    } break;
    case CEED_EVAL_GRAD: {
      CeedInt P = P1d, Q = Q1d;
      // In CEED_NOTRANSPOSE mode:
      // du is (P^dim x nc), column-major layout (nc = ncomp)
      // dv is (Q^dim x nc x dim), column-major layout (nc = ncomp)
      // In CEED_TRANSPOSE mode, the sizes of du and dv are switched.
      if (tmode == CEED_TRANSPOSE) {
        P = Q1d;
        Q = P1d;
      }

      // Define element sizes for dofs/quad
      CeedInt elquadsize = CeedIntPow(Q1d, dim);
      CeedInt eldofssize = CeedIntPow(P1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //                                  dim
      //  component                        component
      //    elem                              elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (du) is E-vector, output (dv) is Q-vector

      // Element strides
      CeedInt u_elstride = eldofssize;
      CeedInt v_elstride = elquadsize;
      // Component strides
      CeedInt u_compstride = nelem * eldofssize;
      CeedInt v_compstride = nelem * elquadsize;
      // Dimension strides (the E-vector has no dim axis, hence stride 0 on the input side)
      CeedInt u_dimstride = 0;
      CeedInt v_dimstride = nelem * elquadsize * ncomp;

      // ---  Swap strides for TRANSPOSE mode: ---
      if (tmode == CEED_TRANSPOSE) {
        // Input (du) is Q-vector, output (dv) is E-vector
        // Element strides
        v_elstride = eldofssize;
        u_elstride = elquadsize;
        // Component strides
        v_compstride = nelem * eldofssize;
        u_compstride = nelem * elquadsize;
        // Dimension strides
        v_dimstride = 0;
        u_dimstride = nelem * elquadsize * ncomp;
      }

      CeedInt nthreads = 1;
      CeedInt ntcol    = 1;
      CeedInt shmem    = 0;
      CeedInt maxPQ    = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
          shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));  // per-element input and output staging
          shmem += sizeof(CeedScalar) * (P * Q);                            // shared copy of the grad matrix
          break;
        case 2:
          nthreads = maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
          shmem += sizeof(CeedScalar) * 2 * P * Q;            // for sTinterp and sTgrad
          shmem += sizeof(CeedScalar) * ntcol * (P * maxPQ);  // for reforming rU we need PxP, and for the intermediate output we need PxQ
          break;
        case 3:
          nthreads = maxPQ * maxPQ;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
          shmem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          shmem += sizeof(CeedScalar) * ntcol *
                   CeedIntMax(P * P * P,
                              (P * P * Q) + (P * Q * Q));  // rU needs P^2xP, the intermediate outputs need (P^2.Q + P.Q^2)
      }
      CeedInt grid   = (nelem + ntcol - 1) / ntcol;
      void   *args[] = {&impl->dinterp1d, &impl->dgrad1d, &du,          &u_elstride, &u_compstride, &u_dimstride, &dv,
                        &v_elstride,      &v_compstride,  &v_dimstride, &nelem};

      if (tmode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad_tr, grid, nthreads, ntcol, 1, shmem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad, grid, nthreads, ntcol, 1, shmem, args));
      }
    } break;
    case CEED_EVAL_WEIGHT: {
      CeedCheck(tmode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
      CeedInt Q          = Q1d;
      CeedInt eldofssize = CeedIntPow(Q, dim);  // output stride per element (Q^dim quadrature points)
      CeedInt nthreads   = 1;
      CeedInt ntcol      = 1;
      CeedInt shmem      = 0;

      switch (dim) {
        case 1:
          nthreads = Q;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
          shmem += sizeof(CeedScalar) * Q;          // for dqweight1d
          shmem += sizeof(CeedScalar) * ntcol * Q;  // for output
          break;
        case 2:
          nthreads = Q;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
          shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
          break;
        case 3:
          nthreads = Q * Q;
          ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
          shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
      }
      CeedInt grid   = (nelem + ntcol - 1) / ntcol;
      void   *args[] = {&impl->dqweight1d, &dv, &eldofssize, &nelem};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_weight, grid, nthreads, ntcol, 1, shmem, args));
    } break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
      // LCOV_EXCL_STOP
  }

  // must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // WEIGHT is the only mode with no input vector to restore.
  if (emode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
  }
  CeedCallBackend(CeedVectorRestoreArray(V, &dv));
  return CEED_ERROR_SUCCESS;
}
270 
#ifdef __cplusplus
CEED_INTERN "C"
#endif
    int
    CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
  // Apply a non-tensor H1 basis on the device. Small P/Q use custom JIT
  // kernels (one RTC instance per tuned N); larger sizes fall back to
  // MAGMA batched-gemm-style calls (magma_gemm_nontensor).
  Ceed ceed;
  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));

  Ceed_Magma *data;
  CeedCallBackend(CeedGetData(ceed, &data));

  magma_int_t arch = magma_getdevice_arch();  // used to pick tuned blocking factor NB

  CeedInt dim, ncomp, ndof, nqpt;
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &nqpt));
  const CeedScalar *du;
  CeedScalar       *dv;
  if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
  else CeedCheck(emode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));

  CeedBasisNonTensor_Magma *impl;
  CeedCallBackend(CeedBasisGetData(basis, &impl));

  CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * ndof, ncomp);

  // Transpose mode accumulates into V, so V must be zeroed first.
  if (tmode == CEED_TRANSPOSE) {
    CeedSize length;
    CeedCallBackend(CeedVectorGetLength(V, &length));
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  CeedInt            P = ndof, Q = nqpt, N = nelem * ncomp;
  CeedInt            NB = 1;
  CeedMagmaFunction *interp, *grad;

  // Pick the RTC kernel instance whose tuned N value is closest to the actual N.
  CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
  CeedInt iN                                       = 0;
  CeedInt diff                                     = abs(Narray[iN] - N);
  for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
    CeedInt idiff = abs(Narray[in] - N);
    if (idiff < diff) {
      iN   = in;
      diff = idiff;
    }
  }

  // NOTE(review): precision char is hard-coded to 'd' (double) here even though
  // the laset above branches on CEED_SCALAR_TYPE — confirm this is intended for FP32 builds.
  NB     = nontensor_rtc_get_nb(arch, 'd', emode, tmode, P, Narray[iN], Q);
  interp = (tmode == CEED_TRANSPOSE) ? &impl->magma_interp_tr_nontensor[iN] : &impl->magma_interp_nontensor[iN];
  grad   = (tmode == CEED_TRANSPOSE) ? &impl->magma_grad_tr_nontensor[iN] : &impl->magma_grad_nontensor[iN];

  switch (emode) {
    case CEED_EVAL_INTERP: {
      CeedInt P = ndof, Q = nqpt;
      if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
        // Custom kernel path: gemm-shaped launch with M rows of output and K-deep inner product.
        CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
        CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
        CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
        CeedInt shmem = 0, shmemA = 0, shmemB = 0;
        shmemB += ntcol * K * NB * sizeof(CeedScalar);                       // staging for the input panel
        shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);  // shared copy of the basis matrix (non-transpose only)
        shmem = (tmode == CEED_TRANSPOSE) ? (shmemA + shmemB) : CeedIntMax(shmemA, shmemB);

        CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
        magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
        magma_trans_t transB = MagmaNoTrans;
        CeedScalar    alpha = 1.0, beta = 0.0;

        void *args[] = {&transA, &transB, &N, &alpha, &impl->dinterp, &P, &du, &K, &beta, &dv, &M};
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *interp, grid, M, ntcol, 1, shmem, args));
      } else {
        // Fallback: plain gemm against the P x Q interp matrix.
        if (tmode == CEED_TRANSPOSE)
          magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dinterp, P, du, Q, 0.0, dv, P, data->queue);
        else magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dinterp, P, du, P, 0.0, dv, Q, data->queue);
      }
    } break;

    case CEED_EVAL_GRAD: {
      CeedInt P = ndof, Q = nqpt;
      if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
        CeedInt M     = (tmode == CEED_TRANSPOSE) ? P : Q;
        CeedInt K     = (tmode == CEED_TRANSPOSE) ? Q : P;
        CeedInt ntcol = MAGMA_NONTENSOR_BASIS_NTCOL(M);
        CeedInt shmem = 0, shmemA = 0, shmemB = 0;
        shmemB += ntcol * K * NB * sizeof(CeedScalar);
        shmemA += (tmode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
        shmem = shmemA + shmemB;

        CeedInt       grid   = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), ntcol);
        magma_trans_t transA = (tmode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
        magma_trans_t transB = MagmaNoTrans;

        // The grad kernel handles all dim slices internally, so no alpha/beta args here.
        void *args[] = {&transA, &transB, &N, &impl->dgrad, &P, &du, &K, &dv, &M};
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *grad, grid, M, ntcol, 1, shmem, args));
      } else {
        // Fallback: one gemm per dimension slice of the (P x Q x dim) grad matrix.
        if (tmode == CEED_TRANSPOSE) {
          // Accumulate contributions across dimensions: beta=0 on the first, 1 afterwards.
          CeedScalar beta = 0.0;
          for (int d = 0; d < dim; d++) {
            if (d > 0) beta = 1.0;
            magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, impl->dgrad + d * P * Q, P, du + d * nelem * ncomp * Q, Q,
                                 beta, dv, P, data->queue);
          }
        } else {
          for (int d = 0; d < dim; d++)
            magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, impl->dgrad + d * P * Q, P, du, P, 0.0,
                                 dv + d * nelem * ncomp * Q, Q, data->queue);
        }
      }
    } break;

    case CEED_EVAL_WEIGHT: {
      CeedCheck(tmode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");

      int elemsPerBlock = 1;  // basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
      int grid          = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);  // ceil(nelem/elemsPerBlock)
      magma_weight_nontensor(grid, nqpt, nelem, nqpt, impl->dqweight, dv, data->queue);
    } break;

    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
      // LCOV_EXCL_STOP
  }

  // must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  if (emode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
  }
  CeedCallBackend(CeedVectorRestoreArray(V, &dv));
  return CEED_ERROR_SUCCESS;
}
416 
417 #ifdef __cplusplus
418 CEED_INTERN "C"
419 #endif
420     int
421     CeedBasisDestroy_Magma(CeedBasis basis) {
422   CeedBasis_Magma *impl;
423   CeedCallBackend(CeedBasisGetData(basis, &impl));
424 
425   CeedCallBackend(magma_free(impl->dqref1d));
426   CeedCallBackend(magma_free(impl->dinterp1d));
427   CeedCallBackend(magma_free(impl->dgrad1d));
428   CeedCallBackend(magma_free(impl->dqweight1d));
429   Ceed ceed;
430   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
431 #ifdef CEED_MAGMA_USE_HIP
432   CeedCallHip(ceed, hipModuleUnload(impl->module));
433 #else
434   CeedCallCuda(ceed, cuModuleUnload(impl->module));
435 #endif
436 
437   CeedCallBackend(CeedFree(&impl));
438 
439   return CEED_ERROR_SUCCESS;
440 }
441 
442 #ifdef __cplusplus
443 CEED_INTERN "C"
444 #endif
445     int
446     CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
447   CeedBasisNonTensor_Magma *impl;
448   CeedCallBackend(CeedBasisGetData(basis, &impl));
449 
450   CeedCallBackend(magma_free(impl->dqref));
451   CeedCallBackend(magma_free(impl->dinterp));
452   CeedCallBackend(magma_free(impl->dgrad));
453   CeedCallBackend(magma_free(impl->dqweight));
454   Ceed ceed;
455   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
456 #ifdef CEED_MAGMA_USE_HIP
457   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
458     CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
459   }
460 #else
461   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
462     CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
463   }
464 #endif
465   CeedCallBackend(CeedFree(&impl));
466 
467   return CEED_ERROR_SUCCESS;
468 }
469 
470 #ifdef __cplusplus
471 CEED_INTERN "C"
472 #endif
473     int
474     CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d, CeedInt Q1d, const CeedScalar *interp1d, const CeedScalar *grad1d,
475                                   const CeedScalar *qref1d, const CeedScalar *qweight1d, CeedBasis basis) {
476   CeedBasis_Magma *impl;
477   CeedCallBackend(CeedCalloc(1, &impl));
478   Ceed ceed;
479   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
480 
481   // Check for supported parameters
482   CeedInt ncomp = 0;
483   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
484   Ceed_Magma *data;
485   CeedCallBackend(CeedGetData(ceed, &data));
486 
487   // Compile kernels
488   char *magma_common_path;
489   char *interp_path, *grad_path, *weight_path;
490   char *basis_kernel_source;
491   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
492   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
493   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
494   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_tensor.h", &magma_common_path));
495   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
496   char   *interp_name_base = "ceed/jit-source/magma/interp";
497   CeedInt interp_name_len  = strlen(interp_name_base) + 6;
498   char    interp_name[interp_name_len];
499   snprintf(interp_name, interp_name_len, "%s-%" CeedInt_FMT "d.h", interp_name_base, dim);
500   CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_name, &interp_path));
501   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
502   char   *grad_name_base = "ceed/jit-source/magma/grad";
503   CeedInt grad_name_len  = strlen(grad_name_base) + 6;
504   char    grad_name[grad_name_len];
505   snprintf(grad_name, grad_name_len, "%s-%" CeedInt_FMT "d.h", grad_name_base, dim);
506   CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_name, &grad_path));
507   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
508   char   *weight_name_base = "ceed/jit-source/magma/weight";
509   CeedInt weight_name_len  = strlen(weight_name_base) + 6;
510   char    weight_name[weight_name_len];
511   snprintf(weight_name, weight_name_len, "%s-%" CeedInt_FMT "d.h", weight_name_base, dim);
512   CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_name, &weight_path));
513   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_path, &basis_kernel_source));
514   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
515   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip
516   // data
517   Ceed delegate;
518   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
519   CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module, 5, "DIM", dim, "NCOMP", ncomp, "P", P1d, "Q", Q1d, "MAXPQ",
520                                    CeedIntMax(P1d, Q1d)));
521 
522   // Kernel setup
523   switch (dim) {
524     case 1:
525       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->magma_interp));
526       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->magma_interp_tr));
527       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->magma_grad));
528       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->magma_grad_tr));
529       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->magma_weight));
530       break;
531     case 2:
532       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->magma_interp));
533       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->magma_interp_tr));
534       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->magma_grad));
535       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->magma_grad_tr));
536       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->magma_weight));
537       break;
538     case 3:
539       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->magma_interp));
540       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->magma_interp_tr));
541       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->magma_grad));
542       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->magma_grad_tr));
543       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->magma_weight));
544   }
545 
546   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
547   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
548 
549   // Copy qref1d to the GPU
550   CeedCallBackend(magma_malloc((void **)&impl->dqref1d, Q1d * sizeof(qref1d[0])));
551   magma_setvector(Q1d, sizeof(qref1d[0]), qref1d, 1, impl->dqref1d, 1, data->queue);
552 
553   // Copy interp1d to the GPU
554   CeedCallBackend(magma_malloc((void **)&impl->dinterp1d, Q1d * P1d * sizeof(interp1d[0])));
555   magma_setvector(Q1d * P1d, sizeof(interp1d[0]), interp1d, 1, impl->dinterp1d, 1, data->queue);
556 
557   // Copy grad1d to the GPU
558   CeedCallBackend(magma_malloc((void **)&impl->dgrad1d, Q1d * P1d * sizeof(grad1d[0])));
559   magma_setvector(Q1d * P1d, sizeof(grad1d[0]), grad1d, 1, impl->dgrad1d, 1, data->queue);
560 
561   // Copy qweight1d to the GPU
562   CeedCallBackend(magma_malloc((void **)&impl->dqweight1d, Q1d * sizeof(qweight1d[0])));
563   magma_setvector(Q1d, sizeof(qweight1d[0]), qweight1d, 1, impl->dqweight1d, 1, data->queue);
564 
565   CeedCallBackend(CeedBasisSetData(basis, impl));
566   CeedCallBackend(CeedFree(&magma_common_path));
567   CeedCallBackend(CeedFree(&interp_path));
568   CeedCallBackend(CeedFree(&grad_path));
569   CeedCallBackend(CeedFree(&weight_path));
570   CeedCallBackend(CeedFree(&basis_kernel_source));
571 
572   return CEED_ERROR_SUCCESS;
573 }
574 
575 #ifdef __cplusplus
576 CEED_INTERN "C"
577 #endif
578     int
579     CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt ndof, CeedInt nqpts, const CeedScalar *interp, const CeedScalar *grad,
580                             const CeedScalar *qref, const CeedScalar *qweight, CeedBasis basis) {
581   CeedBasisNonTensor_Magma *impl;
582   Ceed                      ceed;
583   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
584 
585   Ceed_Magma *data;
586   CeedCallBackend(CeedGetData(ceed, &data));
587   magma_int_t arch = magma_getdevice_arch();
588   CeedCallBackend(CeedCalloc(1, &impl));
589   // Compile kernels
590   char *magma_common_path;
591   char *interp_path, *grad_path;
592   char *basis_kernel_source;
593   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
594   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
595   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
596   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_nontensor.h", &magma_common_path));
597   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
598   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/interp-nontensor.h", &interp_path));
599   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
600   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/grad-nontensor.h", &grad_path));
601   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
602 
603   // tuning parameters for nb
604   CeedInt nb_interp_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
605   CeedInt nb_interp_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
606   CeedInt nb_grad_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
607   CeedInt nb_grad_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
608   CeedInt P = ndof, Q = nqpts;
609   CeedInt Narray[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
610   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
611     nb_interp_n[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_NOTRANSPOSE, P, Narray[in], Q);
612     nb_interp_t[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_TRANSPOSE, P, Narray[in], Q);
613     nb_grad_n[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_NOTRANSPOSE, P, Narray[in], Q);
614     nb_grad_t[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_TRANSPOSE, P, Narray[in], Q);
615   }
616 
617   // compile
618   Ceed delegate;
619   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
620   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
621     CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module[in], 7, "DIM", dim, "P", P, "Q", Q, "NB_INTERP_N", nb_interp_n[in],
622                                      "NB_INTERP_T", nb_interp_t[in], "NB_GRAD_N", nb_grad_n[in], "NB_GRAD_T", nb_grad_t[in]));
623   }
624 
625   // get kernels
626   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
627     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_n", &impl->magma_interp_nontensor[in]));
628     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_t", &impl->magma_interp_tr_nontensor[in]));
629     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_n", &impl->magma_grad_nontensor[in]));
630     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_t", &impl->magma_grad_tr_nontensor[in]));
631   }
632 
633   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
634   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
635 
636   // Copy qref to the GPU
637   CeedCallBackend(magma_malloc((void **)&impl->dqref, nqpts * sizeof(qref[0])));
638   magma_setvector(nqpts, sizeof(qref[0]), qref, 1, impl->dqref, 1, data->queue);
639 
640   // Copy interp to the GPU
641   CeedCallBackend(magma_malloc((void **)&impl->dinterp, nqpts * ndof * sizeof(interp[0])));
642   magma_setvector(nqpts * ndof, sizeof(interp[0]), interp, 1, impl->dinterp, 1, data->queue);
643 
644   // Copy grad to the GPU
645   CeedCallBackend(magma_malloc((void **)&impl->dgrad, nqpts * ndof * dim * sizeof(grad[0])));
646   magma_setvector(nqpts * ndof * dim, sizeof(grad[0]), grad, 1, impl->dgrad, 1, data->queue);
647 
648   // Copy qweight to the GPU
649   CeedCallBackend(magma_malloc((void **)&impl->dqweight, nqpts * sizeof(qweight[0])));
650   magma_setvector(nqpts, sizeof(qweight[0]), qweight, 1, impl->dqweight, 1, data->queue);
651 
652   CeedCallBackend(CeedBasisSetData(basis, impl));
653   CeedCallBackend(CeedFree(&magma_common_path));
654   CeedCallBackend(CeedFree(&interp_path));
655   CeedCallBackend(CeedFree(&grad_path));
656   CeedCallBackend(CeedFree(&basis_kernel_source));
657   return CEED_ERROR_SUCCESS;
658 }
659