xref: /libCEED/backends/magma/ceed-magma-basis.c (revision 6eb06d7cb0f5787c494a4969c0aa6769f3bcfbd0)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #ifdef CEED_MAGMA_USE_HIP
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #else
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #endif
20 #include "ceed-magma-common.h"
21 #include "ceed-magma.h"
22 
23 #include "ceed-magma-gemm-nontensor.h"
24 #include "ceed-magma-gemm-selector.h"
25 
26 //------------------------------------------------------------------------------
27 // Basis apply - tensor
28 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Core tensor-basis apply shared by CeedBasisApply_Magma and
// CeedBasisApplyAdd_Magma. Launches the MAGMA JiT kernels for interp, grad, or
// quadrature weights over num_elem elements. When apply_add is true the output
// vector v is accumulated into (fetched with GetArray and the *Add kernels);
// otherwise it is overwritten (GetArrayWrite). Returns CEED_ERROR_SUCCESS or a
// backend error.
//------------------------------------------------------------------------------
static int CeedBasisApplyCore_Magma(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
                                    CeedVector v) {
  Ceed              ceed;
  Ceed_Magma       *data;
  CeedInt           dim, num_comp, num_nodes, P_1d, Q_1d, P, Q;
  const CeedScalar *d_u;
  CeedScalar       *d_v;
  CeedBasis_Magma  *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  // NOTE(review): num_nodes is retrieved but unused in this function
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
  // P is the input size and Q the output size of the 1D operation; transpose
  // application swaps the roles of nodes and quadrature points
  P = P_1d;
  Q = Q_1d;
  if (t_mode == CEED_TRANSPOSE) {
    P = Q_1d;
    Q = P_1d;
  }

  // Read vectors (CEED_EVAL_WEIGHT takes no input vector; d_u stays unset)
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Apply basis operation
  switch (e_mode) {
    case CEED_EVAL_INTERP: {
      // Define element sizes for dofs/quad
      CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
      CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //  component                        component
      //    elem                             elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (d_u) is E-vector, output (d_v) is Q-vector

      // Element strides
      CeedInt u_elem_stride = elem_dofs_size;
      CeedInt v_elem_stride = elem_qpts_size;
      // Component strides
      CeedInt u_comp_stride = num_elem * elem_dofs_size;
      CeedInt v_comp_stride = num_elem * elem_qpts_size;
      if (t_mode == CEED_TRANSPOSE) {
        // Input (d_u) is Q-vector, output (d_v) is E-vector
        // Element strides
        v_elem_stride = elem_dofs_size;
        u_elem_stride = elem_qpts_size;
        // Component strides
        v_comp_stride = num_elem * elem_dofs_size;
        u_comp_stride = num_elem * elem_qpts_size;
      }
      // Per-dimension launch geometry: num_t_col elements are processed per
      // thread block; shared_mem is sized for the basis matrix plus per-element
      // scratch (see per-case comments below)
      CeedInt num_threads = 1;
      CeedInt num_t_col   = 1;
      CeedInt shared_mem  = 0;
      CeedInt max_P_Q     = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
          shared_mem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += P * Q * sizeof(CeedScalar);  // for sT
          // for reforming rU we need P x P, and for the intermediate output we need P x Q
          shared_mem += num_t_col * (P * max_P_Q * sizeof(CeedScalar));
          break;
        case 3:
          num_threads = max_P_Q * max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * (P * Q);  // for sT
          // rU needs P^2 x P, the intermediate output needs max(P^2 x Q, P x Q^2)
          shared_mem += sizeof(CeedScalar) * num_t_col * (CeedIntMax(P * P * max_P_Q, P * Q * Q));
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_interp_1d, &d_u, &u_elem_stride, &u_comp_stride, &d_v, &v_elem_stride, &v_comp_stride, &num_elem};

      if (t_mode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, apply_add ? impl->InterpTransposeAdd : impl->InterpTranspose, NULL, grid, num_threads,
                                                    num_t_col, 1, shared_mem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp, NULL, grid, num_threads, num_t_col, 1, shared_mem, args));
      }
    } break;
    case CEED_EVAL_GRAD: {
      // Define element sizes for dofs/quad
      CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
      CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);

      // In CEED_NOTRANSPOSE mode:
      // d_u is (P^dim x nc), column-major layout (nc = num_comp)
      // d_v is (Q^dim x nc x dim), column-major layout (nc = num_comp)
      // In CEED_TRANSPOSE mode, the sizes of d_u and d_v are switched.

      // E-vector ordering -------------- Q-vector ordering
      //                                  dim
      //  component                        component
      //    elem                              elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (d_u) is E-vector, output (d_v) is Q-vector

      // Element strides
      CeedInt u_elem_stride = elem_dofs_size;
      CeedInt v_elem_stride = elem_qpts_size;
      // Component strides
      CeedInt u_comp_stride = num_elem * elem_dofs_size;
      CeedInt v_comp_stride = num_elem * elem_qpts_size;
      // Dimension strides (the E-vector side has no dimension axis, stride 0)
      CeedInt u_dim_stride = 0;
      CeedInt v_dim_stride = num_elem * elem_qpts_size * num_comp;
      if (t_mode == CEED_TRANSPOSE) {
        // Input (d_u) is Q-vector, output (d_v) is E-vector
        // Element strides
        v_elem_stride = elem_dofs_size;
        u_elem_stride = elem_qpts_size;
        // Component strides
        v_comp_stride = num_elem * elem_dofs_size;
        u_comp_stride = num_elem * elem_qpts_size;
        // Dimension strides
        v_dim_stride = 0;
        u_dim_stride = num_elem * elem_qpts_size * num_comp;
      }
      // Launch geometry, same scheme as the INTERP case above
      CeedInt num_threads = 1;
      CeedInt num_t_col   = 1;
      CeedInt shared_mem  = 0;
      CeedInt max_P_Q     = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
          shared_mem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          // for reforming rU we need P x P, and for the intermediate output we need P x Q
          shared_mem += sizeof(CeedScalar) * num_t_col * (P * max_P_Q);
          break;
        case 3:
          num_threads = max_P_Q * max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          // rU needs P^2 x P, the intermediate outputs need (P^2 x Q + P x Q^2)
          shared_mem += sizeof(CeedScalar) * num_t_col * CeedIntMax(P * P * P, (P * P * Q) + (P * Q * Q));
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_interp_1d, &impl->d_grad_1d, &d_u,          &u_elem_stride, &u_comp_stride, &u_dim_stride, &d_v,
                        &v_elem_stride,     &v_comp_stride,   &v_dim_stride, &num_elem};

      if (t_mode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, apply_add ? impl->GradTransposeAdd : impl->GradTranspose, NULL, grid, num_threads,
                                                    num_t_col, 1, shared_mem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad, NULL, grid, num_threads, num_t_col, 1, shared_mem, args));
      }
    } break;
    case CEED_EVAL_WEIGHT: {
      CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
      CeedCheck(impl->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight_1d not set", CeedEvalModes[e_mode]);
      // NOTE(review): despite the name this is Q^dim, the number of quadrature
      // points per element (t_mode != CEED_TRANSPOSE is checked above, so Q == Q_1d)
      CeedInt elem_dofs_size = CeedIntPow(Q, dim);
      CeedInt num_threads    = 1;
      CeedInt num_t_col      = 1;
      CeedInt shared_mem     = 0;

      switch (dim) {
        case 1:
          num_threads = Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * Q;              // for d_q_weight_1d
          shared_mem += sizeof(CeedScalar) * num_t_col * Q;  // for output
          break;
        case 2:
          num_threads = Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
          break;
        case 3:
          num_threads = Q * Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_q_weight_1d, &d_v, &elem_dofs_size, &num_elem};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, NULL, grid, num_threads, num_t_col, 1, shared_mem, args));
    } break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[e_mode]);
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
      // LCOV_EXCL_STOP
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors (WEIGHT took no input vector, so d_u was never acquired)
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
255 
256 static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
257   CeedCallBackend(CeedBasisApplyCore_Magma(basis, false, num_elem, t_mode, e_mode, u, v));
258   return CEED_ERROR_SUCCESS;
259 }
260 
261 static int CeedBasisApplyAdd_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
262   CeedCallBackend(CeedBasisApplyCore_Magma(basis, true, num_elem, t_mode, e_mode, u, v));
263   return CEED_ERROR_SUCCESS;
264 }
265 
266 //------------------------------------------------------------------------------
267 // Basis apply - tensor AtPoints
268 //------------------------------------------------------------------------------
269 int CeedBasisApplyAtPoints_Magma(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
270                                  CeedVector x_ref, CeedVector u, CeedVector v) {
271   return CeedError(CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "Backend does not implement CeedBasisApplyAtPoints");
272 }
273 
274 //------------------------------------------------------------------------------
275 // Basis apply - non-tensor
276 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Core non-tensor basis apply shared by CeedBasisApplyNonTensor_Magma and
// CeedBasisApplyAddNonTensor_Magma. For small P/Q this lazily JiT-compiles and
// launches custom kernels tuned per problem-size bucket N; otherwise it falls
// back to MAGMA GEMM calls. apply_add accumulates into v instead of
// overwriting it.
//------------------------------------------------------------------------------
static int CeedBasisApplyNonTensorCore_Magma(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode,
                                             CeedVector u, CeedVector v) {
  Ceed                      ceed;
  Ceed_Magma               *data;
  CeedInt                   num_comp, num_nodes, num_qpts, P, Q, N;
  const CeedScalar         *d_u;
  CeedScalar               *d_v;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
  P = num_nodes;
  Q = num_qpts;
  N = num_elem * num_comp;  // total number of columns the basis is applied to

  // Read vectors (CEED_EVAL_WEIGHT takes no input vector; d_u stays unset)
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Compile kernels for N as needed
  // iN indexes the precompiled kernel instance whose tuning value in n_array
  // is closest to this call's N
  CeedInt iN = 0;
  if (P <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q && (e_mode != CEED_EVAL_WEIGHT || !impl->Weight)) {
    CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_KERNEL_N_VALUES};
    CeedInt diff                                      = abs(n_array[iN] - N), idiff;

    for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
      idiff = abs(n_array[in] - N);
      if (idiff < diff) {
        iN   = in;
        diff = idiff;
      }
    }

    // Lazily compile the kernel module for this N bucket on first use
    if (!impl->NB_interp[iN]) {
      CeedFESpace fe_space;
      CeedInt     q_comp_interp, q_comp_deriv;
      Ceed        ceed_delegate;
      char       *basis_kernel_source;
      const char *basis_kernel_path, *weight_kernel_path;
      char      **file_paths     = NULL;
      CeedInt     num_file_paths = 0;
      magma_int_t arch           = magma_getdevice_arch();

      // Tuning parameters for NB
      CeedCallBackend(CeedBasisGetFESpace(basis, &fe_space));
      CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
      // NOTE(review): q_comp_deriv is left unset if fe_space is none of
      // H1/HDIV/HCURL — confirm CeedFESpace has no other values
      switch (fe_space) {
        case CEED_FE_SPACE_H1:
          CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_deriv));
          break;
        case CEED_FE_SPACE_HDIV:
          CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_deriv));
          break;
        case CEED_FE_SPACE_HCURL:
          CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_deriv));
          break;
      }
      // Architecture-dependent blocking factors for each kernel variant
      impl->NB_interp[iN]   = nontensor_rtc_get_nb(arch, 'n', q_comp_interp, P, Q, n_array[iN]);
      impl->NB_interp_t[iN] = nontensor_rtc_get_nb(arch, 't', q_comp_interp, P, Q, n_array[iN]);
      impl->NB_deriv[iN]    = nontensor_rtc_get_nb(arch, 'n', q_comp_deriv, P, Q, n_array[iN]);
      impl->NB_deriv_t[iN]  = nontensor_rtc_get_nb(arch, 't', q_comp_deriv, P, Q, n_array[iN]);

      // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
      CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));

      // Compile kernels
      CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h", &basis_kernel_path));
      CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
      CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, basis_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
      // The weight kernel is shared across all N buckets; only compile it once
      if (!impl->Weight) {
        CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
        CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
      }
      CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
      CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[iN], 8, "BASIS_Q_COMP_INTERP", q_comp_interp,
                                       "BASIS_Q_COMP_DERIV", q_comp_deriv, "BASIS_P", P, "BASIS_Q", Q, "BASIS_NB_INTERP_N", impl->NB_interp[iN],
                                       "BASIS_NB_INTERP_T", impl->NB_interp_t[iN], "BASIS_NB_DERIV_N", impl->NB_deriv[iN], "BASIS_NB_DERIV_T",
                                       impl->NB_deriv_t[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_n", &impl->Interp[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_t", &impl->InterpTranspose[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_ta", &impl->InterpTransposeAdd[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_n", &impl->Deriv[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_t", &impl->DerivTranspose[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_ta", &impl->DerivTransposeAdd[iN]));
      if (!impl->Weight) {
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_weight_nontensor", &impl->Weight));
        CeedCallBackend(CeedFree(&weight_kernel_path));
      }
      CeedCallBackend(CeedFree(&basis_kernel_path));
      CeedCallBackend(CeedFree(&basis_kernel_source));
      for (CeedInt i = 0; i < num_file_paths; i++) CeedCallBackend(CeedFree(&file_paths[i]));
      CeedCallBackend(CeedFree(&file_paths));
      CeedCallBackend(CeedDestroy(&ceed_delegate));
    }
  }

  // Apply basis operation
  if (e_mode != CEED_EVAL_WEIGHT) {
    const CeedScalar *d_b = NULL;
    CeedInt           q_comp, NB, M, K;
    CeedMagmaFunction Kernel;

    // Select the device basis matrix for this eval mode
    switch (e_mode) {
      case CEED_EVAL_INTERP:
        d_b = impl->d_interp;
        break;
      case CEED_EVAL_GRAD:
        d_b = impl->d_grad;
        break;
      case CEED_EVAL_DIV:
        d_b = impl->d_div;
        break;
      case CEED_EVAL_CURL:
        d_b = impl->d_curl;
        break;
      // LCOV_EXCL_START
      case CEED_EVAL_WEIGHT:
      case CEED_EVAL_NONE:
        return CeedError(ceed, CEED_ERROR_BACKEND, "%s does not make sense in this context", CeedEvalModes[e_mode]);
        // LCOV_EXCL_STOP
    }
    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, e_mode, &q_comp));
    // GEMM shape: M rows of output, K columns of the basis matrix consumed
    M = (t_mode == CEED_TRANSPOSE) ? P : Q, K = (t_mode == CEED_TRANSPOSE) ? Q : P;

    if (P <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
      // Custom JiT kernel path (compiled above for bucket iN)
      if (e_mode == CEED_EVAL_INTERP) {
        if (t_mode == CEED_TRANSPOSE) {
          Kernel = apply_add ? impl->InterpTransposeAdd[iN] : impl->InterpTranspose[iN];
          NB     = impl->NB_interp_t[iN];
        } else {
          Kernel = impl->Interp[iN];
          NB     = impl->NB_interp[iN];
        }
      } else {
        if (t_mode == CEED_TRANSPOSE) {
          Kernel = apply_add ? impl->DerivTransposeAdd[iN] : impl->DerivTranspose[iN];
          NB     = impl->NB_deriv_t[iN];
        } else {
          Kernel = impl->Deriv[iN];
          NB     = impl->NB_deriv[iN];
        }
      }
      CeedInt num_t_col    = MAGMA_BASIS_NTCOL(M, MAGMA_MAXTHREADS_1D);
      CeedInt grid         = CeedDivUpInt(N, num_t_col * NB);
      CeedInt shared_mem_A = P * Q * sizeof(CeedScalar);              // basis matrix
      CeedInt shared_mem_B = num_t_col * K * NB * sizeof(CeedScalar);  // input tile
      CeedInt shared_mem   = (t_mode != CEED_TRANSPOSE && q_comp > 1) ? (shared_mem_A + shared_mem_B) : CeedIntMax(shared_mem_A, shared_mem_B);
      void   *args[]       = {&N, &d_b, &d_u, &d_v};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, Kernel, NULL, grid, M, num_t_col, 1, shared_mem, args));
    } else {
      // GEMM fallback, one multiply per quadrature component
      for (CeedInt d = 0; d < q_comp; d++) {
        if (t_mode == CEED_TRANSPOSE) {
          // beta = 1 accumulates: across components (d > 0) and in apply_add mode
          const CeedScalar beta = (apply_add || (d > 0)) ? 1.0 : 0.0;
          magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, d_b + d * P * Q, P, d_u + d * N * Q, Q, beta, d_v, P, data->queue);
        } else {
          magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, d_b + d * P * Q, P, d_u, P, 0.0, d_v + d * N * Q, Q, data->queue);
        }
      }
    }
  } else {
    CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
    CeedCheck(impl->d_q_weight, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight not set", CeedEvalModes[e_mode]);
    CeedInt num_t_col  = MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D);
    CeedInt grid       = CeedDivUpInt(num_elem, num_t_col);
    CeedInt shared_mem = Q * sizeof(CeedScalar) + num_t_col * Q * sizeof(CeedScalar);
    void   *args[]     = {&num_elem, &impl->d_q_weight, &d_v};

    CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, NULL, grid, Q, num_t_col, 1, shared_mem, args));
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors (WEIGHT took no input vector, so d_u was never acquired)
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
465 
466 static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
467                                          CeedVector v) {
468   CeedCallBackend(CeedBasisApplyNonTensorCore_Magma(basis, false, num_elem, t_mode, e_mode, u, v));
469   return CEED_ERROR_SUCCESS;
470 }
471 
472 static int CeedBasisApplyAddNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
473                                             CeedVector v) {
474   CeedCallBackend(CeedBasisApplyNonTensorCore_Magma(basis, true, num_elem, t_mode, e_mode, u, v));
475   return CEED_ERROR_SUCCESS;
476 }
477 
478 //------------------------------------------------------------------------------
479 // Destroy tensor basis
480 //------------------------------------------------------------------------------
481 static int CeedBasisDestroy_Magma(CeedBasis basis) {
482   Ceed             ceed;
483   CeedBasis_Magma *impl;
484 
485   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
486   CeedCallBackend(CeedBasisGetData(basis, &impl));
487 #ifdef CEED_MAGMA_USE_HIP
488   CeedCallHip(ceed, hipModuleUnload(impl->module));
489 #else
490   CeedCallCuda(ceed, cuModuleUnload(impl->module));
491 #endif
492   CeedCallBackend(magma_free(impl->d_interp_1d));
493   CeedCallBackend(magma_free(impl->d_grad_1d));
494   if (impl->d_q_weight_1d) CeedCallBackend(magma_free(impl->d_q_weight_1d));
495   CeedCallBackend(CeedFree(&impl));
496   CeedCallBackend(CeedDestroy(&ceed));
497   return CEED_ERROR_SUCCESS;
498 }
499 
500 //------------------------------------------------------------------------------
501 // Destroy non-tensor basis
502 //------------------------------------------------------------------------------
503 static int CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
504   Ceed                      ceed;
505   CeedBasisNonTensor_Magma *impl;
506 
507   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
508   CeedCallBackend(CeedBasisGetData(basis, &impl));
509   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
510     if (impl->module[in]) {
511 #ifdef CEED_MAGMA_USE_HIP
512       CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
513 #else
514       CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
515 #endif
516     }
517   }
518   CeedCallBackend(magma_free(impl->d_interp));
519   CeedCallBackend(magma_free(impl->d_grad));
520   CeedCallBackend(magma_free(impl->d_div));
521   CeedCallBackend(magma_free(impl->d_curl));
522   if (impl->d_q_weight) CeedCallBackend(magma_free(impl->d_q_weight));
523   CeedCallBackend(CeedFree(&impl));
524   CeedCallBackend(CeedDestroy(&ceed));
525   return CEED_ERROR_SUCCESS;
526 }
527 
528 //------------------------------------------------------------------------------
529 // Create tensor
530 //------------------------------------------------------------------------------
531 int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
532                                   const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
533   Ceed             ceed, ceed_delegate;
534   Ceed_Magma      *data;
535   char            *basis_kernel_source;
536   const char      *interp_kernel_path, *grad_kernel_path, *weight_kernel_path;
537   char           **file_paths     = NULL;
538   CeedInt          num_file_paths = 0;
539   CeedInt          num_comp;
540   CeedBasis_Magma *impl;
541 
542   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
543   CeedCallBackend(CeedGetData(ceed, &data));
544   CeedCallBackend(CeedCalloc(1, &impl));
545 
546   // Copy basis data to GPU
547   if (q_weight_1d) {
548     CeedCallBackend(magma_malloc((void **)&impl->d_q_weight_1d, Q_1d * sizeof(q_weight_1d[0])));
549     magma_setvector(Q_1d, sizeof(q_weight_1d[0]), q_weight_1d, 1, impl->d_q_weight_1d, 1, data->queue);
550   }
551   CeedCallBackend(magma_malloc((void **)&impl->d_interp_1d, Q_1d * P_1d * sizeof(interp_1d[0])));
552   magma_setvector(Q_1d * P_1d, sizeof(interp_1d[0]), interp_1d, 1, impl->d_interp_1d, 1, data->queue);
553   CeedCallBackend(magma_malloc((void **)&impl->d_grad_1d, Q_1d * P_1d * sizeof(grad_1d[0])));
554   magma_setvector(Q_1d * P_1d, sizeof(grad_1d[0]), grad_1d, 1, impl->d_grad_1d, 1, data->queue);
555 
556   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
557   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
558 
559   // Compile kernels
560   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
561   {
562     char   *interp_kernel_name_base = "ceed/jit-source/magma/magma-basis-interp";
563     CeedInt interp_kernel_name_len  = strlen(interp_kernel_name_base) + 6;
564     char    interp_kernel_name[interp_kernel_name_len];
565 
566     snprintf(interp_kernel_name, interp_kernel_name_len, "%s-%" CeedInt_FMT "d.h", interp_kernel_name_base, dim);
567     CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_kernel_name, &interp_kernel_path));
568   }
569   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
570   CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, interp_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
571   {
572     char   *grad_kernel_name_base = "ceed/jit-source/magma/magma-basis-grad";
573     CeedInt grad_kernel_name_len  = strlen(grad_kernel_name_base) + 6;
574     char    grad_kernel_name[grad_kernel_name_len];
575 
576     snprintf(grad_kernel_name, grad_kernel_name_len, "%s-%" CeedInt_FMT "d.h", grad_kernel_name_base, dim);
577     CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_kernel_name, &grad_kernel_path));
578   }
579   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
580   {
581     char   *weight_kernel_name_base = "ceed/jit-source/magma/magma-basis-weight";
582     CeedInt weight_kernel_name_len  = strlen(weight_kernel_name_base) + 6;
583     char    weight_kernel_name[weight_kernel_name_len];
584 
585     snprintf(weight_kernel_name, weight_kernel_name_len, "%s-%" CeedInt_FMT "d.h", weight_kernel_name_base, dim);
586     CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_kernel_name, &weight_kernel_path));
587   }
588   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
589   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
590   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module, 5, "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_P",
591                                    P_1d, "BASIS_Q", Q_1d, "BASIS_MAX_P_Q", CeedIntMax(P_1d, Q_1d)));
592   switch (dim) {
593     case 1:
594       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->Interp));
595       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->InterpTranspose));
596       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_1d_kernel", &impl->InterpTransposeAdd));
597       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->Grad));
598       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->GradTranspose));
599       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_1d_kernel", &impl->GradTransposeAdd));
600       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->Weight));
601       break;
602     case 2:
603       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->Interp));
604       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->InterpTranspose));
605       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_2d_kernel", &impl->InterpTransposeAdd));
606       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->Grad));
607       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->GradTranspose));
608       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_2d_kernel", &impl->GradTransposeAdd));
609       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->Weight));
610       break;
611     case 3:
612       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->Interp));
613       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->InterpTranspose));
614       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_3d_kernel", &impl->InterpTransposeAdd));
615       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->Grad));
616       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->GradTranspose));
617       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_3d_kernel", &impl->GradTransposeAdd));
618       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->Weight));
619       break;
620   }
621   CeedCallBackend(CeedFree(&interp_kernel_path));
622   CeedCallBackend(CeedFree(&grad_kernel_path));
623   CeedCallBackend(CeedFree(&weight_kernel_path));
624   CeedCallBackend(CeedFree(&basis_kernel_source));
625   for (CeedInt i = 0; i < num_file_paths; i++) CeedCallBackend(CeedFree(&file_paths[i]));
626   CeedCallBackend(CeedFree(&file_paths));
627 
628   CeedCallBackend(CeedBasisSetData(basis, impl));
629 
630   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
631   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Magma));
632   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Magma));
633   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
634   CeedCallBackend(CeedDestroy(&ceed));
635   CeedCallBackend(CeedDestroy(&ceed_delegate));
636   return CEED_ERROR_SUCCESS;
637 }
638 
639 //------------------------------------------------------------------------------
640 // Create non-tensor H^1
641 //------------------------------------------------------------------------------
642 int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
643                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
644   Ceed                      ceed;
645   Ceed_Magma               *data;
646   CeedBasisNonTensor_Magma *impl;
647 
648   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
649   CeedCallBackend(CeedGetData(ceed, &data));
650   CeedCallBackend(CeedCalloc(1, &impl));
651 
652   // Copy basis data to GPU
653   if (q_weight) {
654     CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
655     magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
656   }
657   if (interp) {
658     CeedInt q_comp_interp;
659 
660     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
661     CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
662     magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
663   }
664   if (grad) {
665     CeedInt q_comp_grad;
666 
667     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad));
668     CeedCallBackend(magma_malloc((void **)&impl->d_grad, num_qpts * num_nodes * q_comp_grad * sizeof(grad[0])));
669     magma_setvector(num_qpts * num_nodes * q_comp_grad, sizeof(grad[0]), grad, 1, impl->d_grad, 1, data->queue);
670   }
671 
672   // Compile the weight kernel if it won't be compiled later on
673   if (num_nodes > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P || num_qpts > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
674     Ceed        ceed_delegate;
675     char       *basis_kernel_source;
676     const char *weight_kernel_path;
677 
678     // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
679     CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
680 
681     // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
682     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
683     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
684     CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
685     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
686     CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[0], 1, "BASIS_Q", num_qpts));
687     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[0], "magma_weight_nontensor", &impl->Weight));
688     CeedCallBackend(CeedFree(&weight_kernel_path));
689     CeedCallBackend(CeedFree(&basis_kernel_source));
690     CeedCallBackend(CeedDestroy(&ceed_delegate));
691   }
692 
693   CeedCallBackend(CeedBasisSetData(basis, impl));
694 
695   // Register backend functions
696   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
697   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
698   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
699   CeedCallBackend(CeedDestroy(&ceed));
700   return CEED_ERROR_SUCCESS;
701 }
702 
703 //------------------------------------------------------------------------------
704 // Create non-tensor H(div)
705 //------------------------------------------------------------------------------
706 int CeedBasisCreateHdiv_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
707                               const CeedScalar *div, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
708   Ceed                      ceed;
709   Ceed_Magma               *data;
710   CeedBasisNonTensor_Magma *impl;
711 
712   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
713   CeedCallBackend(CeedGetData(ceed, &data));
714   CeedCallBackend(CeedCalloc(1, &impl));
715 
716   // Copy basis data to GPU
717   if (q_weight) {
718     CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
719     magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
720   }
721   if (interp) {
722     CeedInt q_comp_interp;
723 
724     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
725     CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
726     magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
727   }
728   if (div) {
729     CeedInt q_comp_div;
730 
731     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div));
732     CeedCallBackend(magma_malloc((void **)&impl->d_div, num_qpts * num_nodes * q_comp_div * sizeof(div[0])));
733     magma_setvector(num_qpts * num_nodes * q_comp_div, sizeof(div[0]), div, 1, impl->d_div, 1, data->queue);
734   }
735 
736   // Compile the weight kernel if it won't be compiled later on
737   if (num_nodes > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P || num_qpts > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
738     Ceed        ceed_delegate;
739     char       *basis_kernel_source;
740     const char *weight_kernel_path;
741 
742     // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
743     CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
744 
745     // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
746     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
747     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
748     CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
749     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
750     CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[0], 1, "BASIS_Q", num_qpts));
751     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[0], "magma_weight_nontensor", &impl->Weight));
752     CeedCallBackend(CeedFree(&weight_kernel_path));
753     CeedCallBackend(CeedFree(&basis_kernel_source));
754     CeedCallBackend(CeedDestroy(&ceed_delegate));
755   }
756 
757   CeedCallBackend(CeedBasisSetData(basis, impl));
758 
759   // Register backend functions
760   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
761   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
762   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
763   CeedCallBackend(CeedDestroy(&ceed));
764   return CEED_ERROR_SUCCESS;
765 }
766 
767 //------------------------------------------------------------------------------
768 // Create non-tensor H(curl)
769 //------------------------------------------------------------------------------
770 int CeedBasisCreateHcurl_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
771                                const CeedScalar *curl, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
772   Ceed                      ceed;
773   Ceed_Magma               *data;
774   CeedBasisNonTensor_Magma *impl;
775 
776   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
777   CeedCallBackend(CeedGetData(ceed, &data));
778   CeedCallBackend(CeedCalloc(1, &impl));
779 
780   // Copy basis data to GPU
781   if (q_weight) {
782     CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
783     magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
784   }
785   if (interp) {
786     CeedInt q_comp_interp;
787 
788     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
789     CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
790     magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
791   }
792   if (curl) {
793     CeedInt q_comp_curl;
794 
795     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl));
796     CeedCallBackend(magma_malloc((void **)&impl->d_curl, num_qpts * num_nodes * q_comp_curl * sizeof(curl[0])));
797     magma_setvector(num_qpts * num_nodes * q_comp_curl, sizeof(curl[0]), curl, 1, impl->d_curl, 1, data->queue);
798   }
799 
800   // Compile the weight kernel if it won't be compiled later on
801   if (num_nodes > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P || num_qpts > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
802     Ceed        ceed_delegate;
803     char       *basis_kernel_source;
804     const char *weight_kernel_path;
805 
806     // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
807     CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
808 
809     // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
810     CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
811     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
812     CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
813     CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
814     CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[0], 1, "BASIS_Q", num_qpts));
815     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[0], "magma_weight_nontensor", &impl->Weight));
816     CeedCallBackend(CeedFree(&weight_kernel_path));
817     CeedCallBackend(CeedFree(&basis_kernel_source));
818     CeedCallBackend(CeedDestroy(&ceed_delegate));
819   }
820 
821   CeedCallBackend(CeedBasisSetData(basis, impl));
822 
823   // Register backend functions
824   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
825   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
826   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
827   CeedCallBackend(CeedDestroy(&ceed));
828   return CEED_ERROR_SUCCESS;
829 }
830 
831 //------------------------------------------------------------------------------
832