xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-basis.c (revision 2b730f8b5a9c809740a0b3b302db43a719c636b1)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #include "ceed-magma.h"
14 #ifdef CEED_MAGMA_USE_HIP
15 #include "../hip/ceed-hip-common.h"
16 #include "../hip/ceed-hip-compile.h"
17 #else
18 #include "../cuda/ceed-cuda-common.h"
19 #include "../cuda/ceed-cuda-compile.h"
20 #endif
21 
22 #ifdef __cplusplus
23 CEED_INTERN "C"
24 #endif
25     int
26     CeedBasisApply_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
27   Ceed ceed;
28   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
29   CeedInt dim, ncomp, ndof;
30   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
31   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
32   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
33 
34   Ceed_Magma *data;
35   CeedCallBackend(CeedGetData(ceed, &data));
36 
37   const CeedScalar *u;
38   CeedScalar       *v;
39   if (emode != CEED_EVAL_WEIGHT) {
40     CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &u));
41   } else if (emode != CEED_EVAL_WEIGHT) {
42     // LCOV_EXCL_START
43     return CeedError(ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
44     // LCOV_EXCL_STOP
45   }
46   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &v));
47 
48   CeedBasis_Magma *impl;
49   CeedCallBackend(CeedBasisGetData(basis, &impl));
50 
51   CeedInt P1d, Q1d;
52   CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P1d));
53   CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q1d));
54 
55   CeedDebug256(ceed, 4, "[CeedBasisApply_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * CeedIntPow(P1d, dim), ncomp);
56 
57   if (tmode == CEED_TRANSPOSE) {
58     CeedSize length;
59     CeedCallBackend(CeedVectorGetLength(V, &length));
60     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
61       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)v, length, data->queue);
62     } else {
63       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)v, length, data->queue);
64     }
65     ceed_magma_queue_sync(data->queue);
66   }
67 
68   switch (emode) {
69     case CEED_EVAL_INTERP: {
70       CeedInt P = P1d, Q = Q1d;
71       if (tmode == CEED_TRANSPOSE) {
72         P = Q1d;
73         Q = P1d;
74       }
75 
76       // Define element sizes for dofs/quad
77       CeedInt elquadsize = CeedIntPow(Q1d, dim);
78       CeedInt eldofssize = CeedIntPow(P1d, dim);
79 
80       // E-vector ordering -------------- Q-vector ordering
81       //  component                        component
82       //    elem                             elem
83       //       node                            node
84 
85       // ---  Define strides for NOTRANSPOSE mode: ---
86       // Input (u) is E-vector, output (v) is Q-vector
87 
88       // Element strides
89       CeedInt u_elstride = eldofssize;
90       CeedInt v_elstride = elquadsize;
91       // Component strides
92       CeedInt u_compstride = nelem * eldofssize;
93       CeedInt v_compstride = nelem * elquadsize;
94 
95       // ---  Swap strides for TRANSPOSE mode: ---
96       if (tmode == CEED_TRANSPOSE) {
97         // Input (u) is Q-vector, output (v) is E-vector
98         // Element strides
99         v_elstride = eldofssize;
100         u_elstride = elquadsize;
101         // Component strides
102         v_compstride = nelem * eldofssize;
103         u_compstride = nelem * elquadsize;
104       }
105 
106       CeedInt nthreads = 1;
107       CeedInt ntcol    = 1;
108       CeedInt shmem    = 0;
109       CeedInt maxPQ    = CeedIntMax(P, Q);
110 
111       switch (dim) {
112         case 1:
113           nthreads = maxPQ;
114           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
115           shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
116           shmem += sizeof(CeedScalar) * (P * Q);
117           break;
118         case 2:
119           nthreads = maxPQ;
120           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
121           shmem += P * Q * sizeof(CeedScalar);                // for sT
122           shmem += ntcol * (P * maxPQ * sizeof(CeedScalar));  // for reforming rU we need PxP, and for the intermediate output we need PxQ
123           break;
124         case 3:
125           nthreads = maxPQ * maxPQ;
126           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
127           shmem += sizeof(CeedScalar) * (P * Q);  // for sT
128           shmem += sizeof(CeedScalar) * ntcol *
129                    (CeedIntMax(P * P * maxPQ,
130                                P * Q * Q));  // rU needs P^2xP, the intermediate output needs max(P^2xQ,PQ^2)
131       }
132       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
133       void   *args[] = {&impl->dinterp1d, &u, &u_elstride, &u_compstride, &v, &v_elstride, &v_compstride, &nelem};
134 
135       if (tmode == CEED_TRANSPOSE) {
136         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp_tr, grid, nthreads, ntcol, 1, shmem, args));
137       } else {
138         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp, grid, nthreads, ntcol, 1, shmem, args));
139       }
140     } break;
141     case CEED_EVAL_GRAD: {
142       CeedInt P = P1d, Q = Q1d;
143       // In CEED_NOTRANSPOSE mode:
144       // u is (P^dim x nc), column-major layout (nc = ncomp)
145       // v is (Q^dim x nc x dim), column-major layout (nc = ncomp)
146       // In CEED_TRANSPOSE mode, the sizes of u and v are switched.
147       if (tmode == CEED_TRANSPOSE) {
148         P = Q1d, Q = P1d;
149       }
150 
151       // Define element sizes for dofs/quad
152       CeedInt elquadsize = CeedIntPow(Q1d, dim);
153       CeedInt eldofssize = CeedIntPow(P1d, dim);
154 
155       // E-vector ordering -------------- Q-vector ordering
156       //                                  dim
157       //  component                        component
158       //    elem                              elem
159       //       node                            node
160 
161       // ---  Define strides for NOTRANSPOSE mode: ---
162       // Input (u) is E-vector, output (v) is Q-vector
163 
164       // Element strides
165       CeedInt u_elstride = eldofssize;
166       CeedInt v_elstride = elquadsize;
167       // Component strides
168       CeedInt u_compstride = nelem * eldofssize;
169       CeedInt v_compstride = nelem * elquadsize;
170       // Dimension strides
171       CeedInt u_dimstride = 0;
172       CeedInt v_dimstride = nelem * elquadsize * ncomp;
173 
174       // ---  Swap strides for TRANSPOSE mode: ---
175       if (tmode == CEED_TRANSPOSE) {
176         // Input (u) is Q-vector, output (v) is E-vector
177         // Element strides
178         v_elstride = eldofssize;
179         u_elstride = elquadsize;
180         // Component strides
181         v_compstride = nelem * eldofssize;
182         u_compstride = nelem * elquadsize;
183         // Dimension strides
184         v_dimstride = 0;
185         u_dimstride = nelem * elquadsize * ncomp;
186       }
187 
188       CeedInt nthreads = 1;
189       CeedInt ntcol    = 1;
190       CeedInt shmem    = 0;
191       CeedInt maxPQ    = CeedIntMax(P, Q);
192 
193       switch (dim) {
194         case 1:
195           nthreads = maxPQ;
196           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
197           shmem += sizeof(CeedScalar) * ntcol * (ncomp * (1 * P + 1 * Q));
198           shmem += sizeof(CeedScalar) * (P * Q);
199           break;
200         case 2:
201           nthreads = maxPQ;
202           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
203           shmem += sizeof(CeedScalar) * 2 * P * Q;            // for sTinterp and sTgrad
204           shmem += sizeof(CeedScalar) * ntcol * (P * maxPQ);  // for reforming rU we need PxP, and for the intermediate output we need PxQ
205           break;
206         case 3:
207           nthreads = maxPQ * maxPQ;
208           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
209           shmem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
210           shmem += sizeof(CeedScalar) * ntcol *
211                    CeedIntMax(P * P * P,
212                               (P * P * Q) + (P * Q * Q));  // rU needs P^2xP, the intermediate outputs need (P^2.Q + P.Q^2)
213       }
214       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
215       void   *args[] = {&impl->dinterp1d, &impl->dgrad1d, &u,           &u_elstride, &u_compstride, &u_dimstride, &v,
216                         &v_elstride,      &v_compstride,  &v_dimstride, &nelem};
217 
218       if (tmode == CEED_TRANSPOSE) {
219         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad_tr, grid, nthreads, ntcol, 1, shmem, args));
220       } else {
221         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad, grid, nthreads, ntcol, 1, shmem, args));
222       }
223     } break;
224     case CEED_EVAL_WEIGHT: {
225       if (tmode == CEED_TRANSPOSE)
226         // LCOV_EXCL_START
227         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
228       // LCOV_EXCL_STOP
229       CeedInt Q          = Q1d;
230       CeedInt eldofssize = CeedIntPow(Q, dim);
231       CeedInt nthreads   = 1;
232       CeedInt ntcol      = 1;
233       CeedInt shmem      = 0;
234 
235       switch (dim) {
236         case 1:
237           nthreads = Q;
238           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_1D);
239           shmem += sizeof(CeedScalar) * Q;          // for dqweight1d
240           shmem += sizeof(CeedScalar) * ntcol * Q;  // for output
241           break;
242         case 2:
243           nthreads = Q;
244           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_2D);
245           shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
246           break;
247         case 3:
248           nthreads = Q * Q;
249           ntcol    = MAGMA_BASIS_NTCOL(nthreads, MAGMA_MAXTHREADS_3D);
250           shmem += sizeof(CeedScalar) * Q;  // for dqweight1d
251       }
252       CeedInt grid   = (nelem + ntcol - 1) / ntcol;
253       void   *args[] = {&impl->dqweight1d, &v, &eldofssize, &nelem};
254 
255       CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_weight, grid, nthreads, ntcol, 1, shmem, args));
256     } break;
257     // LCOV_EXCL_START
258     case CEED_EVAL_DIV:
259       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
260     case CEED_EVAL_CURL:
261       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
262     case CEED_EVAL_NONE:
263       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
264       // LCOV_EXCL_STOP
265   }
266 
267   // must sync to ensure completeness
268   ceed_magma_queue_sync(data->queue);
269 
270   if (emode != CEED_EVAL_WEIGHT) {
271     CeedCallBackend(CeedVectorRestoreArrayRead(U, &u));
272   }
273   CeedCallBackend(CeedVectorRestoreArray(V, &v));
274   return CEED_ERROR_SUCCESS;
275 }
276 
277 #ifdef __cplusplus
278 CEED_INTERN "C"
279 #endif
280     int
281     CeedBasisApplyNonTensor_f64_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
282   Ceed ceed;
283   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
284 
285   Ceed_Magma *data;
286   CeedCallBackend(CeedGetData(ceed, &data));
287 
288   CeedInt dim, ncomp, ndof, nqpt;
289   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
290   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
291   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
292   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &nqpt));
293   const CeedScalar *du;
294   CeedScalar       *dv;
295   if (emode != CEED_EVAL_WEIGHT) {
296     CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
297   } else if (emode != CEED_EVAL_WEIGHT) {
298     // LCOV_EXCL_START
299     return CeedError(ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
300     // LCOV_EXCL_STOP
301   }
302   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
303 
304   CeedBasisNonTensor_Magma *impl;
305   CeedCallBackend(CeedBasisGetData(basis, &impl));
306 
307   CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * ndof, ncomp);
308 
309   if (tmode == CEED_TRANSPOSE) {
310     CeedSize length;
311     CeedCallBackend(CeedVectorGetLength(V, &length));
312     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
313       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
314     } else {
315       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
316     }
317     ceed_magma_queue_sync(data->queue);
318   }
319 
320   switch (emode) {
321     case CEED_EVAL_INTERP: {
322       CeedInt P = ndof, Q = nqpt;
323       if (tmode == CEED_TRANSPOSE)
324         magma_dgemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, (double *)impl->dinterp, P, (double *)du, Q, 0.0, (double *)dv, P,
325                               data->queue);
326       else
327         magma_dgemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, (double *)impl->dinterp, P, (double *)du, P, 0.0, (double *)dv, Q,
328                               data->queue);
329     } break;
330 
331     case CEED_EVAL_GRAD: {
332       CeedInt P = ndof, Q = nqpt;
333       if (tmode == CEED_TRANSPOSE) {
334         CeedScalar beta = 0.0;
335         for (int d = 0; d < dim; d++) {
336           if (d > 0) beta = 1.0;
337           magma_dgemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, (double *)(impl->dgrad + d * P * Q), P,
338                                 (double *)(du + d * nelem * ncomp * Q), Q, beta, (double *)dv, P, data->queue);
339         }
340       } else {
341         for (int d = 0; d < dim; d++)
342           magma_dgemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, (double *)(impl->dgrad + d * P * Q), P, (double *)du, P, 0.0,
343                                 (double *)(dv + d * nelem * ncomp * Q), Q, data->queue);
344       }
345     } break;
346 
347     case CEED_EVAL_WEIGHT: {
348       if (tmode == CEED_TRANSPOSE)
349         // LCOV_EXCL_START
350         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
351       // LCOV_EXCL_STOP
352 
353       int elemsPerBlock = 1;  // basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
354       int grid          = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);
355       magma_weight_nontensor(grid, nqpt, nelem, nqpt, impl->dqweight, dv, data->queue);
356     } break;
357 
358     // LCOV_EXCL_START
359     case CEED_EVAL_DIV:
360       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
361     case CEED_EVAL_CURL:
362       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
363     case CEED_EVAL_NONE:
364       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
365       // LCOV_EXCL_STOP
366   }
367 
368   // must sync to ensure completeness
369   ceed_magma_queue_sync(data->queue);
370 
371   if (emode != CEED_EVAL_WEIGHT) {
372     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
373   }
374   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
375   return CEED_ERROR_SUCCESS;
376 }
377 
378 int CeedBasisApplyNonTensor_f32_Magma(CeedBasis basis, CeedInt nelem, CeedTransposeMode tmode, CeedEvalMode emode, CeedVector U, CeedVector V) {
379   Ceed ceed;
380   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
381 
382   Ceed_Magma *data;
383   CeedCallBackend(CeedGetData(ceed, &data));
384 
385   CeedInt dim, ncomp, ndof, nqpt;
386   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
387   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
388   CeedCallBackend(CeedBasisGetNumNodes(basis, &ndof));
389   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &nqpt));
390   const CeedScalar *du;
391   CeedScalar       *dv;
392   if (emode != CEED_EVAL_WEIGHT) {
393     CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
394   } else if (emode != CEED_EVAL_WEIGHT) {
395     // LCOV_EXCL_START
396     return CeedError(ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
397     // LCOV_EXCL_STOP
398   }
399   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
400 
401   CeedBasisNonTensor_Magma *impl;
402   CeedCallBackend(CeedBasisGetData(basis, &impl));
403 
404   CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, ncomp * ndof, ncomp);
405 
406   if (tmode == CEED_TRANSPOSE) {
407     CeedSize length;
408     CeedCallBackend(CeedVectorGetLength(V, &length));
409     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
410       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
411     } else {
412       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
413     }
414     ceed_magma_queue_sync(data->queue);
415   }
416 
417   switch (emode) {
418     case CEED_EVAL_INTERP: {
419       CeedInt P = ndof, Q = nqpt;
420       if (tmode == CEED_TRANSPOSE)
421         magma_sgemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, (float *)impl->dinterp, P, (float *)du, Q, 0.0, (float *)dv, P,
422                               data->queue);
423       else
424         magma_sgemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, (float *)impl->dinterp, P, (float *)du, P, 0.0, (float *)dv, Q,
425                               data->queue);
426     } break;
427 
428     case CEED_EVAL_GRAD: {
429       CeedInt P = ndof, Q = nqpt;
430       if (tmode == CEED_TRANSPOSE) {
431         CeedScalar beta = 0.0;
432         for (int d = 0; d < dim; d++) {
433           if (d > 0) beta = 1.0;
434           magma_sgemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, nelem * ncomp, Q, 1.0, (float *)(impl->dgrad + d * P * Q), P,
435                                 (float *)(du + d * nelem * ncomp * Q), Q, beta, (float *)dv, P, data->queue);
436         }
437       } else {
438         for (int d = 0; d < dim; d++)
439           magma_sgemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem * ncomp, P, 1.0, (float *)(impl->dgrad + d * P * Q), P, (float *)du, P, 0.0,
440                                 (float *)(dv + d * nelem * ncomp * Q), Q, data->queue);
441       }
442     } break;
443 
444     case CEED_EVAL_WEIGHT: {
445       if (tmode == CEED_TRANSPOSE)
446         // LCOV_EXCL_START
447         return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
448       // LCOV_EXCL_STOP
449 
450       int elemsPerBlock = 1;  // basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
451       int grid          = nelem / elemsPerBlock + ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);
452       magma_weight_nontensor(grid, nqpt, nelem, nqpt, impl->dqweight, dv, data->queue);
453     } break;
454 
455     // LCOV_EXCL_START
456     case CEED_EVAL_DIV:
457       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
458     case CEED_EVAL_CURL:
459       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
460     case CEED_EVAL_NONE:
461       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
462       // LCOV_EXCL_STOP
463   }
464 
465   // must sync to ensure completeness
466   ceed_magma_queue_sync(data->queue);
467 
468   if (emode != CEED_EVAL_WEIGHT) {
469     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
470   }
471   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
472   return CEED_ERROR_SUCCESS;
473 }
474 
475 #ifdef __cplusplus
476 CEED_INTERN "C"
477 #endif
478     int
479     CeedBasisDestroy_Magma(CeedBasis basis) {
480   CeedBasis_Magma *impl;
481   CeedCallBackend(CeedBasisGetData(basis, &impl));
482 
483   CeedCallBackend(magma_free(impl->dqref1d));
484   CeedCallBackend(magma_free(impl->dinterp1d));
485   CeedCallBackend(magma_free(impl->dgrad1d));
486   CeedCallBackend(magma_free(impl->dqweight1d));
487   Ceed ceed;
488   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
489 #ifdef CEED_MAGMA_USE_HIP
490   CeedCallHip(ceed, hipModuleUnload(impl->module));
491 #else
492   CeedCallCuda(ceed, cuModuleUnload(impl->module));
493 #endif
494 
495   CeedCallBackend(CeedFree(&impl));
496 
497   return CEED_ERROR_SUCCESS;
498 }
499 
500 #ifdef __cplusplus
501 CEED_INTERN "C"
502 #endif
503     int
504     CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
505   CeedBasisNonTensor_Magma *impl;
506   CeedCallBackend(CeedBasisGetData(basis, &impl));
507 
508   CeedCallBackend(magma_free(impl->dqref));
509   CeedCallBackend(magma_free(impl->dinterp));
510   CeedCallBackend(magma_free(impl->dgrad));
511   CeedCallBackend(magma_free(impl->dqweight));
512 
513   CeedCallBackend(CeedFree(&impl));
514 
515   return CEED_ERROR_SUCCESS;
516 }
517 
518 #ifdef __cplusplus
519 CEED_INTERN "C"
520 #endif
521     int
522     CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P1d, CeedInt Q1d, const CeedScalar *interp1d, const CeedScalar *grad1d,
523                                   const CeedScalar *qref1d, const CeedScalar *qweight1d, CeedBasis basis) {
524   CeedBasis_Magma *impl;
525   CeedCallBackend(CeedCalloc(1, &impl));
526   Ceed ceed;
527   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
528 
529   // Check for supported parameters
530   CeedInt ncomp = 0;
531   CeedCallBackend(CeedBasisGetNumComponents(basis, &ncomp));
532   Ceed_Magma *data;
533   CeedCallBackend(CeedGetData(ceed, &data));
534 
535   // Compile kernels
536   char *magma_common_path;
537   char *interp_path, *grad_path, *weight_path;
538   char *basis_kernel_source;
539   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_device.h", &magma_common_path));
540   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source -----\n");
541   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
542   char   *interp_name_base = "ceed/jit-source/magma/interp";
543   CeedInt interp_name_len  = strlen(interp_name_base) + 6;
544   char    interp_name[interp_name_len];
545   snprintf(interp_name, interp_name_len, "%s-%" CeedInt_FMT "d.h", interp_name_base, dim);
546   CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_name, &interp_path));
547   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
548   char   *grad_name_base = "ceed/jit-source/magma/grad";
549   CeedInt grad_name_len  = strlen(grad_name_base) + 6;
550   char    grad_name[grad_name_len];
551   snprintf(grad_name, grad_name_len, "%s-%" CeedInt_FMT "d.h", grad_name_base, dim);
552   CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_name, &grad_path));
553   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
554   char   *weight_name_base = "ceed/jit-source/magma/weight";
555   CeedInt weight_name_len  = strlen(weight_name_base) + 6;
556   char    weight_name[weight_name_len];
557   snprintf(weight_name, weight_name_len, "%s-%" CeedInt_FMT "d.h", weight_name_base, dim);
558   CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_name, &weight_path));
559   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_path, &basis_kernel_source));
560   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source Complete! -----\n");
561   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip
562   // data
563   Ceed delegate;
564   CeedCallBackend(CeedGetDelegate(ceed, &delegate));
565   CeedCallBackend(CeedCompileMagma(delegate, basis_kernel_source, &impl->module, 5, "DIM", dim, "NCOMP", ncomp, "P", P1d, "Q", Q1d, "MAXPQ",
566                                    CeedIntMax(P1d, Q1d)));
567 
568   // Kernel setup
569   switch (dim) {
570     case 1:
571       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->magma_interp));
572       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->magma_interp_tr));
573       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->magma_grad));
574       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->magma_grad_tr));
575       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->magma_weight));
576       break;
577     case 2:
578       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->magma_interp));
579       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->magma_interp_tr));
580       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->magma_grad));
581       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->magma_grad_tr));
582       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->magma_weight));
583       break;
584     case 3:
585       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->magma_interp));
586       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->magma_interp_tr));
587       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->magma_grad));
588       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->magma_grad_tr));
589       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->magma_weight));
590   }
591 
592   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
593   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
594 
595   // Copy qref1d to the GPU
596   CeedCallBackend(magma_malloc((void **)&impl->dqref1d, Q1d * sizeof(qref1d[0])));
597   magma_setvector(Q1d, sizeof(qref1d[0]), qref1d, 1, impl->dqref1d, 1, data->queue);
598 
599   // Copy interp1d to the GPU
600   CeedCallBackend(magma_malloc((void **)&impl->dinterp1d, Q1d * P1d * sizeof(interp1d[0])));
601   magma_setvector(Q1d * P1d, sizeof(interp1d[0]), interp1d, 1, impl->dinterp1d, 1, data->queue);
602 
603   // Copy grad1d to the GPU
604   CeedCallBackend(magma_malloc((void **)&impl->dgrad1d, Q1d * P1d * sizeof(grad1d[0])));
605   magma_setvector(Q1d * P1d, sizeof(grad1d[0]), grad1d, 1, impl->dgrad1d, 1, data->queue);
606 
607   // Copy qweight1d to the GPU
608   CeedCallBackend(magma_malloc((void **)&impl->dqweight1d, Q1d * sizeof(qweight1d[0])));
609   magma_setvector(Q1d, sizeof(qweight1d[0]), qweight1d, 1, impl->dqweight1d, 1, data->queue);
610 
611   CeedCallBackend(CeedBasisSetData(basis, impl));
612   CeedCallBackend(CeedBasisSetData(basis, impl));
613   CeedCallBackend(CeedFree(&magma_common_path));
614   CeedCallBackend(CeedFree(&interp_path));
615   CeedCallBackend(CeedFree(&grad_path));
616   CeedCallBackend(CeedFree(&weight_path));
617   CeedCallBackend(CeedFree(&basis_kernel_source));
618 
619   return CEED_ERROR_SUCCESS;
620 }
621 
622 #ifdef __cplusplus
623 CEED_INTERN "C"
624 #endif
625     int
626     CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt ndof, CeedInt nqpts, const CeedScalar *interp, const CeedScalar *grad,
627                             const CeedScalar *qref, const CeedScalar *qweight, CeedBasis basis) {
628   CeedBasisNonTensor_Magma *impl;
629   Ceed                      ceed;
630   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
631 
632   Ceed_Magma *data;
633   CeedCallBackend(CeedGetData(ceed, &data));
634 
635   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) {
636     CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_f64_Magma));
637   } else {
638     CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_f32_Magma));
639   }
640   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
641 
642   CeedCallBackend(CeedCalloc(1, &impl));
643   CeedCallBackend(CeedBasisSetData(basis, impl));
644 
645   // Copy qref to the GPU
646   CeedCallBackend(magma_malloc((void **)&impl->dqref, nqpts * sizeof(qref[0])));
647   magma_setvector(nqpts, sizeof(qref[0]), qref, 1, impl->dqref, 1, data->queue);
648 
649   // Copy interp to the GPU
650   CeedCallBackend(magma_malloc((void **)&impl->dinterp, nqpts * ndof * sizeof(interp[0])));
651   magma_setvector(nqpts * ndof, sizeof(interp[0]), interp, 1, impl->dinterp, 1, data->queue);
652 
653   // Copy grad to the GPU
654   CeedCallBackend(magma_malloc((void **)&impl->dgrad, nqpts * ndof * dim * sizeof(grad[0])));
655   magma_setvector(nqpts * ndof * dim, sizeof(grad[0]), grad, 1, impl->dgrad, 1, data->queue);
656 
657   // Copy qweight to the GPU
658   CeedCallBackend(magma_malloc((void **)&impl->dqweight, nqpts * sizeof(qweight[0])));
659   magma_setvector(nqpts, sizeof(qweight[0]), qweight, 1, impl->dqweight, 1, data->queue);
660 
661   return CEED_ERROR_SUCCESS;
662 }
663