xref: /libCEED/backends/magma/ceed-magma-basis.c (revision 5cd6c1fb67d52eb6a42b887bb79c183682dd86ca)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #ifdef CEED_MAGMA_USE_HIP
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #else
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #endif
20 #include "ceed-magma-common.h"
21 #include "ceed-magma.h"
22 
23 #ifdef __cplusplus
24 CEED_INTERN "C"
25 #endif
26     int
27     CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector U, CeedVector V) {
28   Ceed              ceed;
29   Ceed_Magma       *data;
30   CeedInt           dim, num_comp, num_dof, P_1d, Q_1d;
31   const CeedScalar *du;
32   CeedScalar       *dv;
33   CeedBasis_Magma  *impl;
34 
35   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
36   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
37   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
38   CeedCallBackend(CeedBasisGetNumNodes(basis, &num_dof));
39 
40   CeedCallBackend(CeedGetData(ceed, &data));
41 
42   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
43   else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
44   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
45 
46   CeedCallBackend(CeedBasisGetData(basis, &impl));
47 
48   CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
49   CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
50 
51   CeedDebug256(ceed, 4, "[CeedBasisApply_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, num_comp * CeedIntPow(P_1d, dim), num_comp);
52 
53   if (t_mode == CEED_TRANSPOSE) {
54     CeedSize length;
55 
56     CeedCallBackend(CeedVectorGetLength(V, &length));
57     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
58       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
59     } else {
60       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
61     }
62     ceed_magma_queue_sync(data->queue);
63   }
64 
65   switch (e_mode) {
66     case CEED_EVAL_INTERP: {
67       CeedInt P = P_1d, Q = Q_1d;
68 
69       if (t_mode == CEED_TRANSPOSE) {
70         P = Q_1d;
71         Q = P_1d;
72       }
73 
74       // Define element sizes for dofs/quad
75       CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
76       CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);
77 
78       // E-vector ordering -------------- Q-vector ordering
79       //  component                        component
80       //    elem                             elem
81       //       node                            node
82 
83       // ---  Define strides for NOTRANSPOSE mode: ---
84       // Input (du) is E-vector, output (dv) is Q-vector
85 
86       // Element strides
87       CeedInt u_elem_stride = elem_dofs_size;
88       CeedInt v_elem_stride = elem_qpts_size;
89       // Component strides
90       CeedInt u_comp_stride = num_elem * elem_dofs_size;
91       CeedInt v_comp_stride = num_elem * elem_qpts_size;
92 
93       // ---  Swap strides for TRANSPOSE mode: ---
94       if (t_mode == CEED_TRANSPOSE) {
95         // Input (du) is Q-vector, output (dv) is E-vector
96         // Element strides
97         v_elem_stride = elem_dofs_size;
98         u_elem_stride = elem_qpts_size;
99         // Component strides
100         v_comp_stride = num_elem * elem_dofs_size;
101         u_comp_stride = num_elem * elem_qpts_size;
102       }
103       CeedInt num_threads = 1;
104       CeedInt num_t_col   = 1;
105       CeedInt shared_mem  = 0;
106       CeedInt max_P_Q     = CeedIntMax(P, Q);
107 
108       switch (dim) {
109         case 1:
110           num_threads = max_P_Q;
111           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
112           shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
113           shared_mem += sizeof(CeedScalar) * (P * Q);
114           break;
115         case 2:
116           num_threads = max_P_Q;
117           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
118           shared_mem += P * Q * sizeof(CeedScalar);                      // for sT
119           shared_mem += num_t_col * (P * max_P_Q * sizeof(CeedScalar));  // for reforming rU we need PxP, and for the intermediate output we need PxQ
120           break;
121         case 3:
122           num_threads = max_P_Q * max_P_Q;
123           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
124           shared_mem += sizeof(CeedScalar) * (P * Q);  // for sT
125           shared_mem += sizeof(CeedScalar) * num_t_col *
126                         (CeedIntMax(P * P * max_P_Q,
127                                     P * Q * Q));  // rU needs P^2xP, the intermediate output needs max(P^2xQ,PQ^2)
128       }
129       CeedInt grid   = (num_elem + num_t_col - 1) / num_t_col;
130       void   *args[] = {&impl->d_interp_1d, &du, &u_elem_stride, &u_comp_stride, &dv, &v_elem_stride, &v_comp_stride, &num_elem};
131 
132       if (t_mode == CEED_TRANSPOSE) {
133         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp_tr, grid, num_threads, num_t_col, 1, shared_mem, args));
134       } else {
135         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_interp, grid, num_threads, num_t_col, 1, shared_mem, args));
136       }
137     } break;
138     case CEED_EVAL_GRAD: {
139       CeedInt P = P_1d, Q = Q_1d;
140 
141       // In CEED_NOTRANSPOSE mode:
142       // du is (P^dim x nc), column-major layout (nc = num_comp)
143       // dv is (Q^dim x nc x dim), column-major layout (nc = num_comp)
144       // In CEED_TRANSPOSE mode, the sizes of du and dv are switched.
145       if (t_mode == CEED_TRANSPOSE) {
146         P = Q_1d;
147         Q = P_1d;
148       }
149 
150       // Define element sizes for dofs/quad
151       CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
152       CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);
153 
154       // E-vector ordering -------------- Q-vector ordering
155       //                                  dim
156       //  component                        component
157       //    elem                              elem
158       //       node                            node
159 
160       // ---  Define strides for NOTRANSPOSE mode: ---
161       // Input (du) is E-vector, output (dv) is Q-vector
162 
163       // Element strides
164       CeedInt u_elem_stride = elem_dofs_size;
165       CeedInt v_elem_stride = elem_qpts_size;
166       // Component strides
167       CeedInt u_comp_stride = num_elem * elem_dofs_size;
168       CeedInt v_comp_stride = num_elem * elem_qpts_size;
169       // Dimension strides
170       CeedInt u_dim_stride = 0;
171       CeedInt v_dim_stride = num_elem * elem_qpts_size * num_comp;
172 
173       // ---  Swap strides for TRANSPOSE mode: ---
174       if (t_mode == CEED_TRANSPOSE) {
175         // Input (du) is Q-vector, output (dv) is E-vector
176         // Element strides
177         v_elem_stride = elem_dofs_size;
178         u_elem_stride = elem_qpts_size;
179         // Component strides
180         v_comp_stride = num_elem * elem_dofs_size;
181         u_comp_stride = num_elem * elem_qpts_size;
182         // Dimension strides
183         v_dim_stride = 0;
184         u_dim_stride = num_elem * elem_qpts_size * num_comp;
185       }
186       CeedInt num_threads = 1;
187       CeedInt num_t_col   = 1;
188       CeedInt shared_mem  = 0;
189       CeedInt max_P_Q     = CeedIntMax(P, Q);
190 
191       switch (dim) {
192         case 1:
193           num_threads = max_P_Q;
194           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
195           shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
196           shared_mem += sizeof(CeedScalar) * (P * Q);
197           break;
198         case 2:
199           num_threads = max_P_Q;
200           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
201           shared_mem += sizeof(CeedScalar) * 2 * P * Q;                  // for sTinterp and sTgrad
202           shared_mem += sizeof(CeedScalar) * num_t_col * (P * max_P_Q);  // for reforming rU we need PxP, and for the intermediate output we need PxQ
203           break;
204         case 3:
205           num_threads = max_P_Q * max_P_Q;
206           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
207           shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
208           shared_mem += sizeof(CeedScalar) * num_t_col *
209                         CeedIntMax(P * P * P,
210                                    (P * P * Q) + (P * Q * Q));  // rU needs P^2xP, the intermediate outputs need (P^2.Q + P.Q^2)
211       }
212       CeedInt grid   = (num_elem + num_t_col - 1) / num_t_col;
213       void   *args[] = {&impl->d_interp_1d, &impl->d_grad_1d, &du,           &u_elem_stride, &u_comp_stride, &u_dim_stride, &dv,
214                         &v_elem_stride,     &v_comp_stride,   &v_dim_stride, &num_elem};
215 
216       if (t_mode == CEED_TRANSPOSE) {
217         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad_tr, grid, num_threads, num_t_col, 1, shared_mem, args));
218       } else {
219         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_grad, grid, num_threads, num_t_col, 1, shared_mem, args));
220       }
221     } break;
222     case CEED_EVAL_WEIGHT: {
223       CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT inum_compatible with CEED_TRANSPOSE");
224       CeedInt Q              = Q_1d;
225       CeedInt elem_dofs_size = CeedIntPow(Q, dim);
226       CeedInt num_threads    = 1;
227       CeedInt num_t_col      = 1;
228       CeedInt shared_mem     = 0;
229 
230       switch (dim) {
231         case 1:
232           num_threads = Q;
233           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
234           shared_mem += sizeof(CeedScalar) * Q;              // for d_q_weight_1d
235           shared_mem += sizeof(CeedScalar) * num_t_col * Q;  // for output
236           break;
237         case 2:
238           num_threads = Q;
239           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
240           shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
241           break;
242         case 3:
243           num_threads = Q * Q;
244           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
245           shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
246       }
247       CeedInt grid   = (num_elem + num_t_col - 1) / num_t_col;
248       void   *args[] = {&impl->d_q_weight_1d, &dv, &elem_dofs_size, &num_elem};
249 
250       CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->magma_weight, grid, num_threads, num_t_col, 1, shared_mem, args));
251     } break;
252     // LCOV_EXCL_START
253     case CEED_EVAL_DIV:
254       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
255     case CEED_EVAL_CURL:
256       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
257     case CEED_EVAL_NONE:
258       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
259       // LCOV_EXCL_STOP
260   }
261 
262   // must sync to ensure completeness
263   ceed_magma_queue_sync(data->queue);
264 
265   if (e_mode != CEED_EVAL_WEIGHT) {
266     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
267   }
268   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
269   return CEED_ERROR_SUCCESS;
270 }
271 
272 #ifdef __cplusplus
273 CEED_INTERN "C"
274 #endif
275     int
276     CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector U, CeedVector V) {
277   Ceed                      ceed;
278   Ceed_Magma               *data;
279   CeedInt                   dim, num_comp, num_dof, num_qpts, NB = 1;
280   const CeedScalar         *du;
281   CeedScalar               *dv;
282   CeedBasisNonTensor_Magma *impl;
283   CeedMagmaFunction        *interp, *grad;
284 
285   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
286   CeedCallBackend(CeedGetData(ceed, &data));
287   magma_int_t arch = magma_getdevice_arch();
288 
289   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
290   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
291   CeedCallBackend(CeedBasisGetNumNodes(basis, &num_dof));
292   CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
293   CeedInt P = num_dof, Q = num_qpts, N = num_elem * num_comp;
294 
295   if (U != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(U, CEED_MEM_DEVICE, &du));
296   else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
297   CeedCallBackend(CeedVectorGetArrayWrite(V, CEED_MEM_DEVICE, &dv));
298 
299   CeedCallBackend(CeedBasisGetData(basis, &impl));
300 
301   CeedDebug256(ceed, 4, "[CeedBasisApplyNonTensor_Magma] vsize=%" CeedInt_FMT ", comp = %" CeedInt_FMT, num_comp * num_dof, num_comp);
302 
303   if (t_mode == CEED_TRANSPOSE) {
304     CeedSize length;
305 
306     CeedCallBackend(CeedVectorGetLength(V, &length));
307     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
308       magmablas_slaset(MagmaFull, length, 1, 0., 0., (float *)dv, length, data->queue);
309     } else {
310       magmablas_dlaset(MagmaFull, length, 1, 0., 0., (double *)dv, length, data->queue);
311     }
312     ceed_magma_queue_sync(data->queue);
313   }
314 
315   CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
316   CeedInt iN                                        = 0;
317   CeedInt diff                                      = abs(n_array[iN] - N);
318 
319   for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
320     CeedInt idiff = abs(n_array[in] - N);
321     if (idiff < diff) {
322       iN   = in;
323       diff = idiff;
324     }
325   }
326 
327   NB     = nontensor_rtc_get_nb(arch, 'd', e_mode, t_mode, P, n_array[iN], Q);
328   interp = (t_mode == CEED_TRANSPOSE) ? &impl->magma_interp_tr_nontensor[iN] : &impl->magma_interp_nontensor[iN];
329   grad   = (t_mode == CEED_TRANSPOSE) ? &impl->magma_grad_tr_nontensor[iN] : &impl->magma_grad_nontensor[iN];
330 
331   switch (e_mode) {
332     case CEED_EVAL_INTERP: {
333       CeedInt P = num_dof, Q = num_qpts;
334       if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
335         CeedInt M          = (t_mode == CEED_TRANSPOSE) ? P : Q;
336         CeedInt K          = (t_mode == CEED_TRANSPOSE) ? Q : P;
337         CeedInt num_t_col  = MAGMA_NONTENSOR_BASIS_NTCOL(M);
338         CeedInt shared_mem = 0, shared_mem_A = 0, shared_mem_B = 0;
339         shared_mem_B += num_t_col * K * NB * sizeof(CeedScalar);
340         shared_mem_A += (t_mode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
341         shared_mem = (t_mode == CEED_TRANSPOSE) ? (shared_mem_A + shared_mem_B) : CeedIntMax(shared_mem_A, shared_mem_B);
342 
343         CeedInt       grid    = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), num_t_col);
344         magma_trans_t trans_A = (t_mode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
345         magma_trans_t trans_B = MagmaNoTrans;
346         CeedScalar    alpha = 1.0, beta = 0.0;
347 
348         void *args[] = {&trans_A, &trans_B, &N, &alpha, &impl->d_interp, &P, &du, &K, &beta, &dv, &M};
349         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *interp, grid, M, num_t_col, 1, shared_mem, args));
350       } else {
351         if (t_mode == CEED_TRANSPOSE) {
352           magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, num_elem * num_comp, Q, 1.0, impl->d_interp, P, du, Q, 0.0, dv, P, data->queue);
353         } else {
354           magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, num_elem * num_comp, P, 1.0, impl->d_interp, P, du, P, 0.0, dv, Q, data->queue);
355         }
356       }
357     } break;
358 
359     case CEED_EVAL_GRAD: {
360       CeedInt P = num_dof, Q = num_qpts;
361       if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
362         CeedInt M          = (t_mode == CEED_TRANSPOSE) ? P : Q;
363         CeedInt K          = (t_mode == CEED_TRANSPOSE) ? Q : P;
364         CeedInt num_t_col  = MAGMA_NONTENSOR_BASIS_NTCOL(M);
365         CeedInt shared_mem = 0, shared_mem_A = 0, shared_mem_B = 0;
366         shared_mem_B += num_t_col * K * NB * sizeof(CeedScalar);
367         shared_mem_A += (t_mode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
368         shared_mem = shared_mem_A + shared_mem_B;
369 
370         CeedInt       grid    = MAGMA_CEILDIV(MAGMA_CEILDIV(N, NB), num_t_col);
371         magma_trans_t trans_A = (t_mode == CEED_TRANSPOSE) ? MagmaNoTrans : MagmaTrans;
372         magma_trans_t trans_B = MagmaNoTrans;
373 
374         void *args[] = {&trans_A, &trans_B, &N, &impl->d_grad, &P, &du, &K, &dv, &M};
375         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, *grad, grid, M, num_t_col, 1, shared_mem, args));
376       } else {
377         if (t_mode == CEED_TRANSPOSE) {
378           CeedScalar beta = 0.0;
379           for (int d = 0; d < dim; d++) {
380             if (d > 0) beta = 1.0;
381             magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, num_elem * num_comp, Q, 1.0, impl->d_grad + d * P * Q, P,
382                                  du + d * num_elem * num_comp * Q, Q, beta, dv, P, data->queue);
383           }
384         } else {
385           for (int d = 0; d < dim; d++)
386             magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, num_elem * num_comp, P, 1.0, impl->d_grad + d * P * Q, P, du, P, 0.0,
387                                  dv + d * num_elem * num_comp * Q, Q, data->queue);
388         }
389       }
390     } break;
391 
392     case CEED_EVAL_WEIGHT: {
393       CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT inum_compatible with CEED_TRANSPOSE");
394 
395       int elemsPerBlock = 1;  // basis->Q_1d < 7 ? optElems[basis->Q_1d] : 1;
396       int grid          = num_elem / elemsPerBlock + ((num_elem / elemsPerBlock * elemsPerBlock < num_elem) ? 1 : 0);
397 
398       magma_weight_nontensor(grid, num_qpts, num_elem, num_qpts, impl->d_q_weight, dv, data->queue);
399     } break;
400 
401     // LCOV_EXCL_START
402     case CEED_EVAL_DIV:
403       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
404     case CEED_EVAL_CURL:
405       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
406     case CEED_EVAL_NONE:
407       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
408       // LCOV_EXCL_STOP
409   }
410 
411   // must sync to ensure completeness
412   ceed_magma_queue_sync(data->queue);
413 
414   if (e_mode != CEED_EVAL_WEIGHT) {
415     CeedCallBackend(CeedVectorRestoreArrayRead(U, &du));
416   }
417   CeedCallBackend(CeedVectorRestoreArray(V, &dv));
418   return CEED_ERROR_SUCCESS;
419 }
420 
421 #ifdef __cplusplus
422 CEED_INTERN "C"
423 #endif
424     int
425     CeedBasisDestroy_Magma(CeedBasis basis) {
426   Ceed             ceed;
427   CeedBasis_Magma *impl;
428 
429   CeedCallBackend(CeedBasisGetData(basis, &impl));
430   CeedCallBackend(magma_free(impl->d_q_ref_1d));
431   CeedCallBackend(magma_free(impl->d_interp_1d));
432   CeedCallBackend(magma_free(impl->d_grad_1d));
433   CeedCallBackend(magma_free(impl->d_q_weight_1d));
434   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
435 #ifdef CEED_MAGMA_USE_HIP
436   CeedCallHip(ceed, hipModuleUnload(impl->module));
437 #else
438   CeedCallCuda(ceed, cuModuleUnload(impl->module));
439 #endif
440   CeedCallBackend(CeedFree(&impl));
441   return CEED_ERROR_SUCCESS;
442 }
443 
444 #ifdef __cplusplus
445 CEED_INTERN "C"
446 #endif
447     int
448     CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
449   Ceed                      ceed;
450   CeedBasisNonTensor_Magma *impl;
451 
452   CeedCallBackend(CeedBasisGetData(basis, &impl));
453   CeedCallBackend(magma_free(impl->d_q_ref));
454   CeedCallBackend(magma_free(impl->d_interp));
455   CeedCallBackend(magma_free(impl->d_grad));
456   CeedCallBackend(magma_free(impl->d_q_weight));
457   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
458 #ifdef CEED_MAGMA_USE_HIP
459   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
460     CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
461   }
462 #else
463   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
464     CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
465   }
466 #endif
467   CeedCallBackend(CeedFree(&impl));
468   return CEED_ERROR_SUCCESS;
469 }
470 
471 #ifdef __cplusplus
472 CEED_INTERN "C"
473 #endif
474     int
475     CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
476                                   const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
477   Ceed             ceed, ceed_delegate;
478   Ceed_Magma      *data;
479   char            *magma_common_path, *interp_path, *grad_path, *weight_path, *basis_kernel_source;
480   CeedInt          num_comp = 0;
481   CeedBasis_Magma *impl;
482 
483   CeedCallBackend(CeedCalloc(1, &impl));
484   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
485 
486   // Check for supported parameters
487   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
488   CeedCallBackend(CeedGetData(ceed, &data));
489 
490   // Compile kernels
491   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
492   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
493   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
494   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_tensor.h", &magma_common_path));
495   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
496   char   *interp_name_base = "ceed/jit-source/magma/interp";
497   CeedInt interp_name_len  = strlen(interp_name_base) + 6;
498   char    interp_name[interp_name_len];
499 
500   snprintf(interp_name, interp_name_len, "%s-%" CeedInt_FMT "d.h", interp_name_base, dim);
501   CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_name, &interp_path));
502   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
503   char   *grad_name_base = "ceed/jit-source/magma/grad";
504   CeedInt grad_name_len  = strlen(grad_name_base) + 6;
505   char    grad_name[grad_name_len];
506 
507   snprintf(grad_name, grad_name_len, "%s-%" CeedInt_FMT "d.h", grad_name_base, dim);
508   CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_name, &grad_path));
509   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
510   char   *weight_name_base = "ceed/jit-source/magma/weight";
511   CeedInt weight_name_len  = strlen(weight_name_base) + 6;
512   char    weight_name[weight_name_len];
513 
514   snprintf(weight_name, weight_name_len, "%s-%" CeedInt_FMT "d.h", weight_name_base, dim);
515   CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_name, &weight_path));
516   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_path, &basis_kernel_source));
517   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
518   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip
519   // data
520   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
521   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module, 5, "DIM", dim, "NCOMP", num_comp, "P", P_1d, "Q", Q_1d, "MAXPQ",
522                                    CeedIntMax(P_1d, Q_1d)));
523 
524   // Kernel setup
525   switch (dim) {
526     case 1:
527       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->magma_interp));
528       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->magma_interp_tr));
529       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->magma_grad));
530       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->magma_grad_tr));
531       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->magma_weight));
532       break;
533     case 2:
534       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->magma_interp));
535       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->magma_interp_tr));
536       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->magma_grad));
537       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->magma_grad_tr));
538       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->magma_weight));
539       break;
540     case 3:
541       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->magma_interp));
542       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->magma_interp_tr));
543       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->magma_grad));
544       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->magma_grad_tr));
545       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->magma_weight));
546   }
547 
548   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
549   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
550 
551   // Copy q_ref_1d to the GPU
552   CeedCallBackend(magma_malloc((void **)&impl->d_q_ref_1d, Q_1d * sizeof(q_ref_1d[0])));
553   magma_setvector(Q_1d, sizeof(q_ref_1d[0]), q_ref_1d, 1, impl->d_q_ref_1d, 1, data->queue);
554 
555   // Copy interp_1d to the GPU
556   CeedCallBackend(magma_malloc((void **)&impl->d_interp_1d, Q_1d * P_1d * sizeof(interp_1d[0])));
557   magma_setvector(Q_1d * P_1d, sizeof(interp_1d[0]), interp_1d, 1, impl->d_interp_1d, 1, data->queue);
558 
559   // Copy grad_1d to the GPU
560   CeedCallBackend(magma_malloc((void **)&impl->d_grad_1d, Q_1d * P_1d * sizeof(grad_1d[0])));
561   magma_setvector(Q_1d * P_1d, sizeof(grad_1d[0]), grad_1d, 1, impl->d_grad_1d, 1, data->queue);
562 
563   // Copy q_weight_1d to the GPU
564   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight_1d, Q_1d * sizeof(q_weight_1d[0])));
565   magma_setvector(Q_1d, sizeof(q_weight_1d[0]), q_weight_1d, 1, impl->d_q_weight_1d, 1, data->queue);
566 
567   CeedCallBackend(CeedBasisSetData(basis, impl));
568   CeedCallBackend(CeedFree(&magma_common_path));
569   CeedCallBackend(CeedFree(&interp_path));
570   CeedCallBackend(CeedFree(&grad_path));
571   CeedCallBackend(CeedFree(&weight_path));
572   CeedCallBackend(CeedFree(&basis_kernel_source));
573   return CEED_ERROR_SUCCESS;
574 }
575 
576 #ifdef __cplusplus
577 CEED_INTERN "C"
578 #endif
579     int
580     CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_dof, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
581                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
582   Ceed                      ceed, ceed_delegate;
583   Ceed_Magma               *data;
584   char                     *magma_common_path, *interp_path, *grad_path, *basis_kernel_source;
585   CeedBasisNonTensor_Magma *impl;
586 
587   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
588   CeedCallBackend(CeedGetData(ceed, &data));
589   magma_int_t arch = magma_getdevice_arch();
590 
591   CeedCallBackend(CeedCalloc(1, &impl));
592   // Compile kernels
593   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_defs.h", &magma_common_path));
594   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
595   CeedCallBackend(CeedLoadSourceToBuffer(ceed, magma_common_path, &basis_kernel_source));
596   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma_common_nontensor.h", &magma_common_path));
597   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, magma_common_path, &basis_kernel_source));
598   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/interp-nontensor.h", &interp_path));
599   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, interp_path, &basis_kernel_source));
600   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/grad-nontensor.h", &grad_path));
601   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_path, &basis_kernel_source));
602 
603   // tuning parameters for nb
604   CeedInt nb_interp_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
605   CeedInt nb_interp_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
606   CeedInt nb_grad_n[MAGMA_NONTENSOR_KERNEL_INSTANCES];
607   CeedInt nb_grad_t[MAGMA_NONTENSOR_KERNEL_INSTANCES];
608   CeedInt P = num_dof, Q = num_qpts;
609   CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_N_VALUES};
610 
611   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
612     nb_interp_n[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_NOTRANSPOSE, P, n_array[in], Q);
613     nb_interp_t[in] = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_INTERP, CEED_TRANSPOSE, P, n_array[in], Q);
614     nb_grad_n[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_NOTRANSPOSE, P, n_array[in], Q);
615     nb_grad_t[in]   = nontensor_rtc_get_nb(arch, 'd', CEED_EVAL_GRAD, CEED_TRANSPOSE, P, n_array[in], Q);
616   }
617 
618   // compile
619   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
620   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
621     CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[in], 7, "DIM", dim, "P", P, "Q", Q, "NB_INTERP_N",
622                                      nb_interp_n[in], "NB_INTERP_T", nb_interp_t[in], "NB_GRAD_N", nb_grad_n[in], "NB_GRAD_T", nb_grad_t[in]));
623   }
624 
625   // get kernels
626   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
627     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_n", &impl->magma_interp_nontensor[in]));
628     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_interp_nontensor_t", &impl->magma_interp_tr_nontensor[in]));
629     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_n", &impl->magma_grad_nontensor[in]));
630     CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[in], "magma_grad_nontensor_t", &impl->magma_grad_tr_nontensor[in]));
631   }
632 
633   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
634   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
635 
636   // Copy q_ref to the GPU
637   CeedCallBackend(magma_malloc((void **)&impl->d_q_ref, num_qpts * sizeof(q_ref[0])));
638   magma_setvector(num_qpts, sizeof(q_ref[0]), q_ref, 1, impl->d_q_ref, 1, data->queue);
639 
640   // Copy interp to the GPU
641   CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_dof * sizeof(interp[0])));
642   magma_setvector(num_qpts * num_dof, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
643 
644   // Copy grad to the GPU
645   CeedCallBackend(magma_malloc((void **)&impl->d_grad, num_qpts * num_dof * dim * sizeof(grad[0])));
646   magma_setvector(num_qpts * num_dof * dim, sizeof(grad[0]), grad, 1, impl->d_grad, 1, data->queue);
647 
648   // Copy q_weight to the GPU
649   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
650   magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
651 
652   CeedCallBackend(CeedBasisSetData(basis, impl));
653   CeedCallBackend(CeedFree(&magma_common_path));
654   CeedCallBackend(CeedFree(&interp_path));
655   CeedCallBackend(CeedFree(&grad_path));
656   CeedCallBackend(CeedFree(&basis_kernel_source));
657   return CEED_ERROR_SUCCESS;
658 }
659