xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-basis.c (revision 833aa127b2360f2d0ee487784605de022811dcd8)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #ifdef CEED_MAGMA_USE_HIP
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #else
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #endif
20 #include "ceed-magma-common.h"
21 #include "ceed-magma.h"
22 
23 #include "ceed-magma-gemm-nontensor.h"
24 #include "ceed-magma-gemm-selector.h"
25 
26 //------------------------------------------------------------------------------
27 // Basis apply - tensor
28 //------------------------------------------------------------------------------
// Apply a tensor-product basis operation (interpolation, gradient, or
// quadrature weights) for num_elem elements by launching MAGMA JiT kernels.
//   basis    - tensor basis to apply
//   num_elem - number of elements
//   t_mode   - CEED_NOTRANSPOSE (E-vector -> Q-vector) or CEED_TRANSPOSE
//   e_mode   - CEED_EVAL_INTERP, CEED_EVAL_GRAD, or CEED_EVAL_WEIGHT
//   u        - input device vector (may be CEED_VECTOR_NONE for CEED_EVAL_WEIGHT)
//   v        - output device vector
// Returns CEED_ERROR_SUCCESS or a backend error code.
static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
  Ceed              ceed;
  Ceed_Magma       *data;
  CeedInt           dim, num_comp, num_nodes, P_1d, Q_1d, P, Q;
  const CeedScalar *d_u;
  CeedScalar       *d_v;
  CeedBasis_Magma  *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
  // P/Q are the effective 1D input/output sizes; they swap in transpose mode
  P = P_1d;
  Q = Q_1d;
  if (t_mode == CEED_TRANSPOSE) {
    P = Q_1d;
    Q = P_1d;
  }

  // Read vectors
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Clear v for transpose operation
  if (t_mode == CEED_TRANSPOSE) {
    CeedSize length;

    CeedCallBackend(CeedVectorGetLength(v, &length));
    // NOTE(review): length is CeedSize but magmablas_*laset takes magma_int_t
    // dimensions; assumes the vector length fits that type -- TODO confirm
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  // Apply basis operation
  switch (e_mode) {
    case CEED_EVAL_INTERP: {
      // Define element sizes for dofs/quad
      CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
      CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //  component                        component
      //    elem                             elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (d_u) is E-vector, output (d_v) is Q-vector

      // Element strides
      CeedInt u_elem_stride = elem_dofs_size;
      CeedInt v_elem_stride = elem_qpts_size;
      // Component strides
      CeedInt u_comp_stride = num_elem * elem_dofs_size;
      CeedInt v_comp_stride = num_elem * elem_qpts_size;
      if (t_mode == CEED_TRANSPOSE) {
        // Input (d_u) is Q-vector, output (d_v) is E-vector
        // Element strides
        v_elem_stride = elem_dofs_size;
        u_elem_stride = elem_qpts_size;
        // Component strides
        v_comp_stride = num_elem * elem_dofs_size;
        u_comp_stride = num_elem * elem_qpts_size;
      }
      // Launch geometry: each thread block handles num_t_col elements side by
      // side; shared_mem is the per-block dynamic shared memory in bytes
      CeedInt num_threads = 1;
      CeedInt num_t_col   = 1;
      CeedInt shared_mem  = 0;
      CeedInt max_P_Q     = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
          shared_mem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += P * Q * sizeof(CeedScalar);  // for sT
          // for reforming rU we need P x P, and for the intermediate output we need P x Q
          shared_mem += num_t_col * (P * max_P_Q * sizeof(CeedScalar));
          break;
        case 3:
          num_threads = max_P_Q * max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * (P * Q);  // for sT
          // rU needs P^2 x P, the intermediate output needs max(P^2 x Q, P x Q^2)
          shared_mem += sizeof(CeedScalar) * num_t_col * (CeedIntMax(P * P * max_P_Q, P * Q * Q));
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_interp_1d, &d_u, &u_elem_stride, &u_comp_stride, &d_v, &v_elem_stride, &v_comp_stride, &num_elem};

      if (t_mode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->InterpTranspose, grid, num_threads, num_t_col, 1, shared_mem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp, grid, num_threads, num_t_col, 1, shared_mem, args));
      }
    } break;
    case CEED_EVAL_GRAD: {
      // Define element sizes for dofs/quad
      CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
      CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);

      // In CEED_NOTRANSPOSE mode:
      // d_u is (P^dim x nc), column-major layout (nc = num_comp)
      // d_v is (Q^dim x nc x dim), column-major layout (nc = num_comp)
      // In CEED_TRANSPOSE mode, the sizes of d_u and d_v are switched.

      // E-vector ordering -------------- Q-vector ordering
      //                                  dim
      //  component                        component
      //    elem                              elem
      //       node                            node

      // ---  Define strides for NOTRANSPOSE mode: ---
      // Input (d_u) is E-vector, output (d_v) is Q-vector

      // Element strides
      CeedInt u_elem_stride = elem_dofs_size;
      CeedInt v_elem_stride = elem_qpts_size;
      // Component strides
      CeedInt u_comp_stride = num_elem * elem_dofs_size;
      CeedInt v_comp_stride = num_elem * elem_qpts_size;
      // Dimension strides (input E-vector has no dimension axis, hence 0)
      CeedInt u_dim_stride = 0;
      CeedInt v_dim_stride = num_elem * elem_qpts_size * num_comp;
      if (t_mode == CEED_TRANSPOSE) {
        // Input (d_u) is Q-vector, output (d_v) is E-vector
        // Element strides
        v_elem_stride = elem_dofs_size;
        u_elem_stride = elem_qpts_size;
        // Component strides
        v_comp_stride = num_elem * elem_dofs_size;
        u_comp_stride = num_elem * elem_qpts_size;
        // Dimension strides
        v_dim_stride = 0;
        u_dim_stride = num_elem * elem_qpts_size * num_comp;
      }
      // Launch geometry (see CEED_EVAL_INTERP above for the same pattern)
      CeedInt num_threads = 1;
      CeedInt num_t_col   = 1;
      CeedInt shared_mem  = 0;
      CeedInt max_P_Q     = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
          shared_mem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          // for reforming rU we need P x P, and for the intermediate output we need P x Q
          shared_mem += sizeof(CeedScalar) * num_t_col * (P * max_P_Q);
          break;
        case 3:
          num_threads = max_P_Q * max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          // rU needs P^2 x P, the intermediate outputs need (P^2 x Q + P x Q^2)
          shared_mem += sizeof(CeedScalar) * num_t_col * CeedIntMax(P * P * P, (P * P * Q) + (P * Q * Q));
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_interp_1d, &impl->d_grad_1d, &d_u,          &u_elem_stride, &u_comp_stride, &u_dim_stride, &d_v,
                        &v_elem_stride,     &v_comp_stride,   &v_dim_stride, &num_elem};

      if (t_mode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->GradTranspose, grid, num_threads, num_t_col, 1, shared_mem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad, grid, num_threads, num_t_col, 1, shared_mem, args));
      }
    } break;
    case CEED_EVAL_WEIGHT: {
      CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
      // Per-element output size; Q^dim quadrature points (variable named after
      // the kernel argument it is passed as)
      CeedInt elem_dofs_size = CeedIntPow(Q, dim);
      CeedInt num_threads    = 1;
      CeedInt num_t_col      = 1;
      CeedInt shared_mem     = 0;

      switch (dim) {
        case 1:
          num_threads = Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * Q;              // for d_q_weight_1d
          shared_mem += sizeof(CeedScalar) * num_t_col * Q;  // for output
          break;
        case 2:
          num_threads = Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
          break;
        case 3:
          num_threads = Q * Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_q_weight_1d, &d_v, &elem_dofs_size, &num_elem};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, num_threads, num_t_col, 1, shared_mem, args));
    } break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
      // LCOV_EXCL_STOP
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  return CEED_ERROR_SUCCESS;
}
263 
264 //------------------------------------------------------------------------------
265 // Basis apply - non-tensor
266 //------------------------------------------------------------------------------
// Apply a non-tensor basis operation (interp/grad/div/curl/weight) for
// num_elem elements.  Small P/Q problems use custom JiT kernels that are
// compiled lazily per problem-size instance; larger problems fall back to
// MAGMA GEMM calls.  See CeedBasisApply_Magma for the parameter contract.
static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
                                         CeedVector v) {
  Ceed                      ceed;
  Ceed_Magma               *data;
  CeedInt                   num_comp, q_comp, num_nodes, num_qpts, P, Q, N;
  const CeedScalar         *d_u;
  CeedScalar               *d_v;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, e_mode, &q_comp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
  P = num_nodes;
  Q = num_qpts;
  // Total number of columns in the (batched) GEMM: one per element/component pair
  N = num_elem * num_comp;

  // Read vectors
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Clear v for transpose operation
  if (t_mode == CEED_TRANSPOSE) {
    CeedSize length;

    CeedCallBackend(CeedVectorGetLength(v, &length));
    // NOTE(review): length is CeedSize but magmablas_*laset takes magma_int_t
    // dimensions; assumes the vector length fits that type -- TODO confirm
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  // Apply basis operation
  if (e_mode != CEED_EVAL_WEIGHT) {
    // Select the device basis matrix corresponding to the eval mode
    const CeedScalar *d_b = NULL;
    switch (e_mode) {
      case CEED_EVAL_INTERP:
        d_b = impl->d_interp;
        break;
      case CEED_EVAL_GRAD:
        d_b = impl->d_grad;
        break;
      case CEED_EVAL_DIV:
        d_b = impl->d_div;
        break;
      case CEED_EVAL_CURL:
        d_b = impl->d_curl;
        break;
      // LCOV_EXCL_START
      case CEED_EVAL_WEIGHT:
        return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT does not make sense in this context");
      case CEED_EVAL_NONE:
        return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
        // LCOV_EXCL_STOP
    }

    // Apply basis operation
    if (P <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
      // Custom kernels are tuned for a fixed set of N values; pick the
      // instance iN whose tuning value is closest to this problem's N
      CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_KERNEL_N_VALUES};
      CeedInt iN = 0, diff = abs(n_array[iN] - N), idiff;
      CeedInt M = (t_mode == CEED_TRANSPOSE) ? P : Q, K = (t_mode == CEED_TRANSPOSE) ? Q : P;

      for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
        idiff = abs(n_array[in] - N);
        if (idiff < diff) {
          iN   = in;
          diff = idiff;
        }
      }

      // Compile kernels for N as needed (lazy, once per instance iN)
      if (!impl->NB_interp[iN]) {
        CeedFESpace fe_space;
        CeedInt     q_comp_interp, q_comp_deriv;
        Ceed        ceed_delegate;
        char       *basis_kernel_path, *basis_kernel_source;
        magma_int_t arch = magma_getdevice_arch();

        // Tuning parameters for NB
        CeedCallBackend(CeedBasisGetFESpace(basis, &fe_space));
        CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
        switch (fe_space) {
          case CEED_FE_SPACE_H1:
            CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_deriv));
            break;
          case CEED_FE_SPACE_HDIV:
            CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_deriv));
            break;
          case CEED_FE_SPACE_HCURL:
            CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_deriv));
            break;
        }
        // NOTE(review): q_comp_deriv is only assigned in the three cases above;
        // assumes fe_space is always H1/H(div)/H(curl) -- TODO confirm
        impl->NB_interp[iN]   = nontensor_rtc_get_nb(arch, 'n', q_comp_interp, P, Q, n_array[iN]);
        impl->NB_interp_t[iN] = nontensor_rtc_get_nb(arch, 't', q_comp_interp, P, Q, n_array[iN]);
        impl->NB_deriv[iN]    = nontensor_rtc_get_nb(arch, 'n', q_comp_deriv, P, Q, n_array[iN]);
        impl->NB_deriv_t[iN]  = nontensor_rtc_get_nb(arch, 't', q_comp_deriv, P, Q, n_array[iN]);

        // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
        CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));

        // Compile kernels
        CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h", &basis_kernel_path));
        CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
        CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &basis_kernel_source));
        CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
        CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_interp[iN], 8, "BASIS_Q_COMP_INTERP", q_comp_interp,
                                         "BASIS_Q_COMP_DERIV", q_comp_deriv, "BASIS_P", P, "BASIS_Q", Q, "BASIS_NB_INTERP_N", impl->NB_interp[iN],
                                         "BASIS_NB_INTERP_T", impl->NB_interp_t[iN], "BASIS_NB_DERIV_N", impl->NB_deriv[iN], "BASIS_NB_DERIV_T",
                                         impl->NB_deriv_t[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_interp_nontensor_n", &impl->Interp[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_interp_nontensor_t", &impl->InterpTranspose[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_deriv_nontensor_n", &impl->Deriv[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_deriv_nontensor_t", &impl->DerivTranspose[iN]));
        CeedCallBackend(CeedFree(&basis_kernel_path));
        CeedCallBackend(CeedFree(&basis_kernel_source));
      }
      // Select the kernel and its blocking factor NB for this eval/transpose mode
      CeedMagmaFunction Kernel;
      CeedInt           NB;
      if (e_mode == CEED_EVAL_INTERP) {
        if (t_mode == CEED_TRANSPOSE) {
          Kernel = impl->InterpTranspose[iN];
          NB     = impl->NB_interp_t[iN];
        } else {
          Kernel = impl->Interp[iN];
          NB     = impl->NB_interp[iN];
        }
      } else {
        if (t_mode == CEED_TRANSPOSE) {
          Kernel = impl->DerivTranspose[iN];
          NB     = impl->NB_deriv_t[iN];
        } else {
          Kernel = impl->Deriv[iN];
          NB     = impl->NB_deriv[iN];
        }
      }
      CeedInt num_t_col    = MAGMA_BASIS_NTCOL(M, MAGMA_MAXTHREADS_1D);
      CeedInt grid         = CeedDivUpInt(N, num_t_col * NB);
      CeedInt shared_mem_A = P * Q * sizeof(CeedScalar);
      CeedInt shared_mem_B = num_t_col * K * NB * sizeof(CeedScalar);
      CeedInt shared_mem   = (t_mode != CEED_TRANSPOSE && q_comp > 1) ? (shared_mem_A + shared_mem_B) : CeedIntMax(shared_mem_A, shared_mem_B);
      void   *args[]       = {&N, &d_b, &d_u, &d_v};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, Kernel, grid, M, num_t_col, 1, shared_mem, args));
    } else {
      // Fallback: one GEMM per quadrature component
      for (CeedInt d = 0; d < q_comp; d++) {
        if (t_mode == CEED_TRANSPOSE) {
          // Accumulate into d_v (beta = 1) for components after the first
          const CeedScalar beta = (d > 0) ? 1.0 : 0.0;
          magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, d_b + d * P * Q, P, d_u + d * N * Q, Q, beta, d_v, P, data->queue);
        } else {
          magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, d_b + d * P * Q, P, d_u, P, 0.0, d_v + d * N * Q, Q, data->queue);
        }
      }
    }
  } else {
    CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
    CeedInt num_t_col  = MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D);
    CeedInt grid       = CeedDivUpInt(num_elem, num_t_col);
    CeedInt shared_mem = Q * sizeof(CeedScalar) + num_t_col * Q * sizeof(CeedScalar);
    void   *args[]     = {&num_elem, &impl->d_q_weight, &d_v};

    CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, Q, num_t_col, 1, shared_mem, args));
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  return CEED_ERROR_SUCCESS;
}
446 
447 //------------------------------------------------------------------------------
448 // Destroy tensor basis
449 //------------------------------------------------------------------------------
450 static int CeedBasisDestroy_Magma(CeedBasis basis) {
451   Ceed             ceed;
452   CeedBasis_Magma *impl;
453 
454   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
455   CeedCallBackend(CeedBasisGetData(basis, &impl));
456 #ifdef CEED_MAGMA_USE_HIP
457   CeedCallHip(ceed, hipModuleUnload(impl->module));
458 #else
459   CeedCallCuda(ceed, cuModuleUnload(impl->module));
460 #endif
461   CeedCallBackend(magma_free(impl->d_interp_1d));
462   CeedCallBackend(magma_free(impl->d_grad_1d));
463   CeedCallBackend(magma_free(impl->d_q_weight_1d));
464   CeedCallBackend(CeedFree(&impl));
465   return CEED_ERROR_SUCCESS;
466 }
467 
468 //------------------------------------------------------------------------------
469 // Destroy non-tensor basis
470 //------------------------------------------------------------------------------
471 static int CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
472   Ceed                      ceed;
473   CeedBasisNonTensor_Magma *impl;
474 
475   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
476   CeedCallBackend(CeedBasisGetData(basis, &impl));
477 #ifdef CEED_MAGMA_USE_HIP
478   CeedCallHip(ceed, hipModuleUnload(impl->module_weight));
479 #else
480   CeedCallCuda(ceed, cuModuleUnload(impl->module_weight));
481 #endif
482   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
483     if (impl->module_interp[in]) {
484 #ifdef CEED_MAGMA_USE_HIP
485       CeedCallHip(ceed, hipModuleUnload(impl->module_interp[in]));
486 #else
487       CeedCallCuda(ceed, cuModuleUnload(impl->module_interp[in]));
488 #endif
489     }
490   }
491   CeedCallBackend(magma_free(impl->d_interp));
492   CeedCallBackend(magma_free(impl->d_grad));
493   CeedCallBackend(magma_free(impl->d_div));
494   CeedCallBackend(magma_free(impl->d_curl));
495   CeedCallBackend(magma_free(impl->d_q_weight));
496   CeedCallBackend(CeedFree(&impl));
497   return CEED_ERROR_SUCCESS;
498 }
499 
500 //------------------------------------------------------------------------------
501 // Create tensor
502 //------------------------------------------------------------------------------
503 int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
504                                   const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
505   Ceed             ceed, ceed_delegate;
506   Ceed_Magma      *data;
507   char            *interp_kernel_path, *grad_kernel_path, *weight_kernel_path, *basis_kernel_source;
508   CeedInt          num_comp;
509   CeedBasis_Magma *impl;
510 
511   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
512   CeedCallBackend(CeedGetData(ceed, &data));
513   CeedCallBackend(CeedCalloc(1, &impl));
514 
515   // Copy basis data to GPU
516   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight_1d, Q_1d * sizeof(q_weight_1d[0])));
517   magma_setvector(Q_1d, sizeof(q_weight_1d[0]), q_weight_1d, 1, impl->d_q_weight_1d, 1, data->queue);
518   CeedCallBackend(magma_malloc((void **)&impl->d_interp_1d, Q_1d * P_1d * sizeof(interp_1d[0])));
519   magma_setvector(Q_1d * P_1d, sizeof(interp_1d[0]), interp_1d, 1, impl->d_interp_1d, 1, data->queue);
520   CeedCallBackend(magma_malloc((void **)&impl->d_grad_1d, Q_1d * P_1d * sizeof(grad_1d[0])));
521   magma_setvector(Q_1d * P_1d, sizeof(grad_1d[0]), grad_1d, 1, impl->d_grad_1d, 1, data->queue);
522 
523   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
524   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
525 
526   // Compile kernels
527   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
528   {
529     char   *interp_kernel_name_base = "ceed/jit-source/magma/magma-basis-interp";
530     CeedInt interp_kernel_name_len  = strlen(interp_kernel_name_base) + 6;
531     char    interp_kernel_name[interp_kernel_name_len];
532 
533     snprintf(interp_kernel_name, interp_kernel_name_len, "%s-%" CeedInt_FMT "d.h", interp_kernel_name_base, dim);
534     CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_kernel_name, &interp_kernel_path));
535   }
536   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
537   CeedCallBackend(CeedLoadSourceToBuffer(ceed, interp_kernel_path, &basis_kernel_source));
538   {
539     char   *grad_kernel_name_base = "ceed/jit-source/magma/magma-basis-grad";
540     CeedInt grad_kernel_name_len  = strlen(grad_kernel_name_base) + 6;
541     char    grad_kernel_name[grad_kernel_name_len];
542 
543     snprintf(grad_kernel_name, grad_kernel_name_len, "%s-%" CeedInt_FMT "d.h", grad_kernel_name_base, dim);
544     CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_kernel_name, &grad_kernel_path));
545   }
546   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &basis_kernel_source));
547   {
548     char   *weight_kernel_name_base = "ceed/jit-source/magma/magma-basis-weight";
549     CeedInt weight_kernel_name_len  = strlen(weight_kernel_name_base) + 6;
550     char    weight_kernel_name[weight_kernel_name_len];
551 
552     snprintf(weight_kernel_name, weight_kernel_name_len, "%s-%" CeedInt_FMT "d.h", weight_kernel_name_base, dim);
553     CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_kernel_name, &weight_kernel_path));
554   }
555   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &basis_kernel_source));
556   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
557   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module, 5, "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_P",
558                                    P_1d, "BASIS_Q", Q_1d, "BASIS_MAX_P_Q", CeedIntMax(P_1d, Q_1d)));
559   switch (dim) {
560     case 1:
561       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->Interp));
562       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->InterpTranspose));
563       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->Grad));
564       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->GradTranspose));
565       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->Weight));
566       break;
567     case 2:
568       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->Interp));
569       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->InterpTranspose));
570       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->Grad));
571       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->GradTranspose));
572       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->Weight));
573       break;
574     case 3:
575       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->Interp));
576       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->InterpTranspose));
577       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->Grad));
578       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->GradTranspose));
579       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->Weight));
580       break;
581   }
582   CeedCallBackend(CeedFree(&interp_kernel_path));
583   CeedCallBackend(CeedFree(&grad_kernel_path));
584   CeedCallBackend(CeedFree(&weight_kernel_path));
585   CeedCallBackend(CeedFree(&basis_kernel_source));
586 
587   CeedCallBackend(CeedBasisSetData(basis, impl));
588 
589   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
590   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
591   return CEED_ERROR_SUCCESS;
592 }
593 
594 //------------------------------------------------------------------------------
595 // Create non-tensor H^1
596 //------------------------------------------------------------------------------
597 int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
598                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
599   Ceed                      ceed, ceed_delegate;
600   Ceed_Magma               *data;
601   char                     *weight_kernel_path, *basis_kernel_source;
602   CeedBasisNonTensor_Magma *impl;
603 
604   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
605   CeedCallBackend(CeedGetData(ceed, &data));
606   CeedCallBackend(CeedCalloc(1, &impl));
607 
608   // Copy basis data to GPU
609   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
610   magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
611   if (interp) {
612     CeedInt q_comp_interp;
613 
614     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
615     CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
616     magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
617   }
618   if (grad) {
619     CeedInt q_comp_grad;
620 
621     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad));
622     CeedCallBackend(magma_malloc((void **)&impl->d_grad, num_qpts * num_nodes * q_comp_grad * sizeof(grad[0])));
623     magma_setvector(num_qpts * num_nodes * q_comp_grad, sizeof(grad[0]), grad, 1, impl->d_grad, 1, data->queue);
624   }
625 
626   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
627   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
628 
629   // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
630   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
631   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
632   CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
633   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
634   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_weight, 1, "BASIS_Q", num_qpts));
635   CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_weight, "magma_weight_nontensor", &impl->Weight));
636   CeedCallBackend(CeedFree(&weight_kernel_path));
637   CeedCallBackend(CeedFree(&basis_kernel_source));
638 
639   CeedCallBackend(CeedBasisSetData(basis, impl));
640 
641   // Register backend functions
642   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
643   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
644   return CEED_ERROR_SUCCESS;
645 }
646 
647 //------------------------------------------------------------------------------
648 // Create non-tensor H(div)
649 //------------------------------------------------------------------------------
650 int CeedBasisCreateHdiv_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
651                               const CeedScalar *div, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
652   Ceed                      ceed, ceed_delegate;
653   Ceed_Magma               *data;
654   char                     *weight_kernel_path, *basis_kernel_source;
655   CeedBasisNonTensor_Magma *impl;
656 
657   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
658   CeedCallBackend(CeedGetData(ceed, &data));
659   CeedCallBackend(CeedCalloc(1, &impl));
660 
661   // Copy basis data to GPU
662   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
663   magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
664   if (interp) {
665     CeedInt q_comp_interp;
666 
667     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
668     CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
669     magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
670   }
671   if (div) {
672     CeedInt q_comp_div;
673 
674     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div));
675     CeedCallBackend(magma_malloc((void **)&impl->d_div, num_qpts * num_nodes * q_comp_div * sizeof(div[0])));
676     magma_setvector(num_qpts * num_nodes * q_comp_div, sizeof(div[0]), div, 1, impl->d_div, 1, data->queue);
677   }
678 
679   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
680   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
681 
682   // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
683   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
684   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
685   CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
686   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
687   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_weight, 1, "BASIS_Q", num_qpts));
688   CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_weight, "magma_weight_nontensor", &impl->Weight));
689   CeedCallBackend(CeedFree(&weight_kernel_path));
690   CeedCallBackend(CeedFree(&basis_kernel_source));
691 
692   CeedCallBackend(CeedBasisSetData(basis, impl));
693 
694   // Register backend functions
695   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
696   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
697   return CEED_ERROR_SUCCESS;
698 }
699 
700 //------------------------------------------------------------------------------
701 // Create non-tensor H(curl)
702 //------------------------------------------------------------------------------
703 int CeedBasisCreateHcurl_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
704                                const CeedScalar *curl, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
705   Ceed                      ceed, ceed_delegate;
706   Ceed_Magma               *data;
707   char                     *weight_kernel_path, *basis_kernel_source;
708   CeedBasisNonTensor_Magma *impl;
709 
710   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
711   CeedCallBackend(CeedGetData(ceed, &data));
712   CeedCallBackend(CeedCalloc(1, &impl));
713 
714   // Copy basis data to GPU
715   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
716   magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
717   if (interp) {
718     CeedInt q_comp_interp;
719 
720     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
721     CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
722     magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
723   }
724   if (curl) {
725     CeedInt q_comp_curl;
726 
727     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl));
728     CeedCallBackend(magma_malloc((void **)&impl->d_curl, num_qpts * num_nodes * q_comp_curl * sizeof(curl[0])));
729     magma_setvector(num_qpts * num_nodes * q_comp_curl, sizeof(curl[0]), curl, 1, impl->d_curl, 1, data->queue);
730   }
731 
732   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
733   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
734 
735   // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
736   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
737   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
738   CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
739   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
740   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_weight, 1, "BASIS_Q", num_qpts));
741   CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_weight, "magma_weight_nontensor", &impl->Weight));
742   CeedCallBackend(CeedFree(&weight_kernel_path));
743   CeedCallBackend(CeedFree(&basis_kernel_source));
744 
745   CeedCallBackend(CeedBasisSetData(basis, impl));
746 
747   // Register backend functions
748   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
749   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
750   return CEED_ERROR_SUCCESS;
751 }
752 
753 //------------------------------------------------------------------------------
754