xref: /libCEED/backends/magma/ceed-magma-basis.c (revision e6aaec067db7781c9cc5c3b72972e2661403e34f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <string.h>
12 
13 #ifdef CEED_MAGMA_USE_HIP
14 #include "../hip/ceed-hip-common.h"
15 #include "../hip/ceed-hip-compile.h"
16 #else
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #endif
20 #include "ceed-magma-common.h"
21 #include "ceed-magma.h"
22 
23 #include "ceed-magma-gemm-nontensor.h"
24 #include "ceed-magma-gemm-selector.h"
25 
26 //------------------------------------------------------------------------------
27 // Basis apply - tensor
28 //------------------------------------------------------------------------------
29 static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
30   Ceed              ceed;
31   Ceed_Magma       *data;
32   CeedInt           dim, num_comp, num_nodes, P_1d, Q_1d, P, Q;
33   const CeedScalar *d_u;
34   CeedScalar       *d_v;
35   CeedBasis_Magma  *impl;
36 
37   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
38   CeedCallBackend(CeedGetData(ceed, &data));
39   CeedCallBackend(CeedBasisGetData(basis, &impl));
40   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
41   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
42   CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
43   CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
44   CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
45   P = P_1d;
46   Q = Q_1d;
47   if (t_mode == CEED_TRANSPOSE) {
48     P = Q_1d;
49     Q = P_1d;
50   }
51 
52   // Read vectors
53   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
54   else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
55   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
56 
57   // Clear v for transpose operation
58   if (t_mode == CEED_TRANSPOSE) {
59     CeedSize length;
60 
61     CeedCallBackend(CeedVectorGetLength(v, &length));
62     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
63       magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue);
64     } else {
65       magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue);
66     }
67     ceed_magma_queue_sync(data->queue);
68   }
69 
70   // Apply basis operation
71   switch (e_mode) {
72     case CEED_EVAL_INTERP: {
73       // Define element sizes for dofs/quad
74       CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
75       CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);
76 
77       // E-vector ordering -------------- Q-vector ordering
78       //  component                        component
79       //    elem                             elem
80       //       node                            node
81 
82       // ---  Define strides for NOTRANSPOSE mode: ---
83       // Input (d_u) is E-vector, output (d_v) is Q-vector
84 
85       // Element strides
86       CeedInt u_elem_stride = elem_dofs_size;
87       CeedInt v_elem_stride = elem_qpts_size;
88       // Component strides
89       CeedInt u_comp_stride = num_elem * elem_dofs_size;
90       CeedInt v_comp_stride = num_elem * elem_qpts_size;
91       if (t_mode == CEED_TRANSPOSE) {
92         // Input (d_u) is Q-vector, output (d_v) is E-vector
93         // Element strides
94         v_elem_stride = elem_dofs_size;
95         u_elem_stride = elem_qpts_size;
96         // Component strides
97         v_comp_stride = num_elem * elem_dofs_size;
98         u_comp_stride = num_elem * elem_qpts_size;
99       }
100       CeedInt num_threads = 1;
101       CeedInt num_t_col   = 1;
102       CeedInt shared_mem  = 0;
103       CeedInt max_P_Q     = CeedIntMax(P, Q);
104 
105       switch (dim) {
106         case 1:
107           num_threads = max_P_Q;
108           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
109           shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
110           shared_mem += sizeof(CeedScalar) * (P * Q);
111           break;
112         case 2:
113           num_threads = max_P_Q;
114           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
115           shared_mem += P * Q * sizeof(CeedScalar);  // for sT
116           // for reforming rU we need P x P, and for the intermediate output we need P x Q
117           shared_mem += num_t_col * (P * max_P_Q * sizeof(CeedScalar));
118           break;
119         case 3:
120           num_threads = max_P_Q * max_P_Q;
121           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
122           shared_mem += sizeof(CeedScalar) * (P * Q);  // for sT
123           // rU needs P^2 x P, the intermediate output needs max(P^2 x Q, P x Q^2)
124           shared_mem += sizeof(CeedScalar) * num_t_col * (CeedIntMax(P * P * max_P_Q, P * Q * Q));
125           break;
126       }
127       CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
128       void   *args[] = {&impl->d_interp_1d, &d_u, &u_elem_stride, &u_comp_stride, &d_v, &v_elem_stride, &v_comp_stride, &num_elem};
129 
130       if (t_mode == CEED_TRANSPOSE) {
131         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->InterpTranspose, grid, num_threads, num_t_col, 1, shared_mem, args));
132       } else {
133         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp, grid, num_threads, num_t_col, 1, shared_mem, args));
134       }
135     } break;
136     case CEED_EVAL_GRAD: {
137       // Define element sizes for dofs/quad
138       CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
139       CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);
140 
141       // In CEED_NOTRANSPOSE mode:
142       // d_u is (P^dim x nc), column-major layout (nc = num_comp)
143       // d_v is (Q^dim x nc x dim), column-major layout (nc = num_comp)
144       // In CEED_TRANSPOSE mode, the sizes of d_u and d_v are switched.
145 
146       // E-vector ordering -------------- Q-vector ordering
147       //                                  dim
148       //  component                        component
149       //    elem                              elem
150       //       node                            node
151 
152       // ---  Define strides for NOTRANSPOSE mode: ---
153       // Input (d_u) is E-vector, output (d_v) is Q-vector
154 
155       // Element strides
156       CeedInt u_elem_stride = elem_dofs_size;
157       CeedInt v_elem_stride = elem_qpts_size;
158       // Component strides
159       CeedInt u_comp_stride = num_elem * elem_dofs_size;
160       CeedInt v_comp_stride = num_elem * elem_qpts_size;
161       // Dimension strides
162       CeedInt u_dim_stride = 0;
163       CeedInt v_dim_stride = num_elem * elem_qpts_size * num_comp;
164       if (t_mode == CEED_TRANSPOSE) {
165         // Input (d_u) is Q-vector, output (d_v) is E-vector
166         // Element strides
167         v_elem_stride = elem_dofs_size;
168         u_elem_stride = elem_qpts_size;
169         // Component strides
170         v_comp_stride = num_elem * elem_dofs_size;
171         u_comp_stride = num_elem * elem_qpts_size;
172         // Dimension strides
173         v_dim_stride = 0;
174         u_dim_stride = num_elem * elem_qpts_size * num_comp;
175       }
176       CeedInt num_threads = 1;
177       CeedInt num_t_col   = 1;
178       CeedInt shared_mem  = 0;
179       CeedInt max_P_Q     = CeedIntMax(P, Q);
180 
181       switch (dim) {
182         case 1:
183           num_threads = max_P_Q;
184           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
185           shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
186           shared_mem += sizeof(CeedScalar) * (P * Q);
187           break;
188         case 2:
189           num_threads = max_P_Q;
190           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
191           shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
192           // for reforming rU we need P x P, and for the intermediate output we need P x Q
193           shared_mem += sizeof(CeedScalar) * num_t_col * (P * max_P_Q);
194           break;
195         case 3:
196           num_threads = max_P_Q * max_P_Q;
197           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
198           shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
199           // rU needs P^2 x P, the intermediate outputs need (P^2 x Q + P x Q^2)
200           shared_mem += sizeof(CeedScalar) * num_t_col * CeedIntMax(P * P * P, (P * P * Q) + (P * Q * Q));
201           break;
202       }
203       CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
204       void   *args[] = {&impl->d_interp_1d, &impl->d_grad_1d, &d_u,          &u_elem_stride, &u_comp_stride, &u_dim_stride, &d_v,
205                         &v_elem_stride,     &v_comp_stride,   &v_dim_stride, &num_elem};
206 
207       if (t_mode == CEED_TRANSPOSE) {
208         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->GradTranspose, grid, num_threads, num_t_col, 1, shared_mem, args));
209       } else {
210         CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad, grid, num_threads, num_t_col, 1, shared_mem, args));
211       }
212     } break;
213     case CEED_EVAL_WEIGHT: {
214       CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
215       CeedInt elem_dofs_size = CeedIntPow(Q, dim);
216       CeedInt num_threads    = 1;
217       CeedInt num_t_col      = 1;
218       CeedInt shared_mem     = 0;
219 
220       switch (dim) {
221         case 1:
222           num_threads = Q;
223           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
224           shared_mem += sizeof(CeedScalar) * Q;              // for d_q_weight_1d
225           shared_mem += sizeof(CeedScalar) * num_t_col * Q;  // for output
226           break;
227         case 2:
228           num_threads = Q;
229           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
230           shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
231           break;
232         case 3:
233           num_threads = Q * Q;
234           num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
235           shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
236           break;
237       }
238       CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
239       void   *args[] = {&impl->d_q_weight_1d, &d_v, &elem_dofs_size, &num_elem};
240 
241       CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, num_threads, num_t_col, 1, shared_mem, args));
242     } break;
243     // LCOV_EXCL_START
244     case CEED_EVAL_DIV:
245       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
246     case CEED_EVAL_CURL:
247       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
248     case CEED_EVAL_NONE:
249       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
250       // LCOV_EXCL_STOP
251   }
252 
253   // Must sync to ensure completeness
254   ceed_magma_queue_sync(data->queue);
255 
256   // Restore vectors
257   if (e_mode != CEED_EVAL_WEIGHT) {
258     CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
259   }
260   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
261   return CEED_ERROR_SUCCESS;
262 }
263 
264 //------------------------------------------------------------------------------
265 // Basis apply - non-tensor
266 //------------------------------------------------------------------------------
// Apply a non-tensor basis operation (interp / grad / weight) on the GPU.
// For small P and Q, dimension-specialized RTC kernels are JiT-compiled on
// first use and cached per problem-size bucket (indexed by the tuning value
// in n_array closest to N); otherwise a MAGMA GEMM fallback is used.
//
// basis    - non-tensor basis holding device data and (lazily compiled) kernels
// num_elem - number of elements to process
// t_mode   - CEED_NOTRANSPOSE (nodes -> qpts) or CEED_TRANSPOSE (reverse)
// e_mode   - evaluation mode (INTERP, GRAD, or WEIGHT; DIV/CURL error out)
// u        - input vector (CEED_VECTOR_NONE only for CEED_EVAL_WEIGHT)
// v        - output vector
static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
                                         CeedVector v) {
  Ceed                      ceed;
  Ceed_Magma               *data;
  CeedInt                   dim, num_comp, num_nodes, num_qpts, P, Q, N;
  const CeedScalar         *d_u;
  CeedScalar               *d_v;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
  P = num_nodes;
  Q = num_qpts;
  // Viewing the operation as a GEMM, N is the total number of right-hand-side
  // columns: one per (element, component) pair
  N = num_elem * num_comp;

  // Read vectors (CEED_EVAL_WEIGHT is the only mode with no input vector)
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Clear v for transpose operation
  if (t_mode == CEED_TRANSPOSE) {
    CeedSize length;

    CeedCallBackend(CeedVectorGetLength(v, &length));
    // laset with alpha = beta = 0 zeroes the full length x 1 device array
    if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
      magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue);
    } else {
      magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue);
    }
    ceed_magma_queue_sync(data->queue);
  }

  // Apply basis operation
  if (e_mode != CEED_EVAL_WEIGHT) {
    // Small P/Q: use the custom RTC kernels; otherwise fall back to GEMM below
    if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
      CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_KERNEL_N_VALUES};
      CeedInt iN = 0, diff = abs(n_array[iN] - N), idiff;
      // GEMM-style dimensions: M output rows, K inner (contraction) dimension
      CeedInt M = (t_mode == CEED_TRANSPOSE) ? P : Q, K = (t_mode == CEED_TRANSPOSE) ? Q : P;

      // Select the kernel instance whose tuning size n_array[iN] is closest to N
      for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
        idiff = abs(n_array[in] - N);
        if (idiff < diff) {
          iN   = in;
          diff = idiff;
        }
      }

      // Compile kernels for N as needed
      // (first use of this size bucket: pick blocking sizes, then JiT-compile)
      if (!impl->NB_interp[iN]) {
        Ceed        ceed_delegate;
        char       *interp_kernel_path, *grad_kernel_path, *basis_kernel_source;
        magma_int_t arch = magma_getdevice_arch();

        // Tuning parameters for NB
        impl->NB_interp[iN]   = nontensor_rtc_get_nb(arch, 'n', 1, P, Q, n_array[iN]);
        impl->NB_interp_t[iN] = nontensor_rtc_get_nb(arch, 't', 1, P, Q, n_array[iN]);
        impl->NB_grad[iN]     = nontensor_rtc_get_nb(arch, 'n', dim, P, Q, n_array[iN]);
        impl->NB_grad_t[iN]   = nontensor_rtc_get_nb(arch, 't', dim, P, Q, n_array[iN]);

        // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
        CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));

        // Compile kernels (interp and grad sources are concatenated into one module)
        CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-interp-nontensor.h", &interp_kernel_path));
        CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
        CeedCallBackend(CeedLoadSourceToBuffer(ceed, interp_kernel_path, &basis_kernel_source));
        CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-grad-nontensor.h", &grad_kernel_path));
        CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &basis_kernel_source));
        CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
        CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_interp[iN], 7, "BASIS_DIM", dim, "BASIS_P", P, "BASIS_Q",
                                         Q, "BASIS_NB_INTERP_N", impl->NB_interp[iN], "BASIS_NB_INTERP_T", impl->NB_interp_t[iN], "BASIS_NB_GRAD_N",
                                         impl->NB_grad[iN], "BASIS_NB_GRAD_T", impl->NB_grad_t[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_interp_nontensor_n", &impl->Interp[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_interp_nontensor_t", &impl->InterpTranspose[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_grad_nontensor_n", &impl->Grad[iN]));
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_grad_nontensor_t", &impl->GradTranspose[iN]));
        CeedCallBackend(CeedFree(&interp_kernel_path));
        CeedCallBackend(CeedFree(&grad_kernel_path));
        CeedCallBackend(CeedFree(&basis_kernel_source));
      }

      // Apply basis operation
      CeedInt num_t_col = MAGMA_BASIS_NTCOL(M, MAGMA_MAXTHREADS_1D);
      if (e_mode == CEED_EVAL_INTERP) {
        CeedInt NB           = (t_mode == CEED_TRANSPOSE) ? impl->NB_interp_t[iN] : impl->NB_interp[iN];
        CeedInt grid         = CeedDivUpInt(N, NB * num_t_col);
        // Shared-memory staging sizes mirror the buffers declared in
        // magma-basis-interp-nontensor.h (not visible here) -- keep in sync
        CeedInt shared_mem_A = (t_mode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar);
        CeedInt shared_mem_B = num_t_col * K * NB * sizeof(CeedScalar);
        CeedInt shared_mem   = (t_mode == CEED_TRANSPOSE) ? (shared_mem_A + shared_mem_B) : CeedIntMax(shared_mem_A, shared_mem_B);
        void   *args[]       = {&N, &impl->d_interp, &P, &d_u, &K, &d_v, &M};

        if (t_mode == CEED_TRANSPOSE) {
          CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->InterpTranspose[iN], grid, M, num_t_col, 1, shared_mem, args));
        } else {
          CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp[iN], grid, M, num_t_col, 1, shared_mem, args));
        }
      } else if (e_mode == CEED_EVAL_GRAD) {
        CeedInt NB         = (t_mode == CEED_TRANSPOSE) ? impl->NB_grad_t[iN] : impl->NB_grad[iN];
        CeedInt grid       = CeedDivUpInt(N, NB * num_t_col);
        // Shared-memory staging sizes mirror the buffers declared in
        // magma-basis-grad-nontensor.h (not visible here) -- keep in sync
        CeedInt shared_mem = num_t_col * K * NB * sizeof(CeedScalar) + ((t_mode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar));
        void   *args[]     = {&N, &impl->d_grad, &P, &d_u, &K, &d_v, &M};

        if (t_mode == CEED_TRANSPOSE) {
          CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->GradTranspose[iN], grid, M, num_t_col, 1, shared_mem, args));
        } else {
          CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad[iN], grid, M, num_t_col, 1, shared_mem, args));
        }
      } else {
        // LCOV_EXCL_START
        return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV, CEED_EVAL_CURL not supported");
        // LCOV_EXCL_STOP
      }
    } else {
      // Large P/Q: apply the basis matrices with MAGMA GEMM
      if (e_mode == CEED_EVAL_INTERP) {
        if (t_mode == CEED_TRANSPOSE) {
          // d_v (P x N) = d_interp (P x Q) * d_u (Q x N)
          magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, impl->d_interp, P, d_u, Q, 0.0, d_v, P, data->queue);
        } else {
          // d_v (Q x N) = d_interp^T (Q x P) * d_u (P x N)
          magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, impl->d_interp, P, d_u, P, 0.0, d_v, Q, data->queue);
        }
      } else if (e_mode == CEED_EVAL_GRAD) {
        if (t_mode == CEED_TRANSPOSE) {
          // Sum contributions from all dim directions into d_v (beta = 1 after the first)
          for (int d = 0; d < dim; d++) {
            const CeedScalar beta = (d > 0) ? 1.0 : 0.0;
            magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, impl->d_grad + d * P * Q, P, d_u + d * N * Q, Q, beta, d_v, P,
                                 data->queue);
          }
        } else {
          // One GEMM per direction, each writing its own d_v slab
          for (int d = 0; d < dim; d++) {
            magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, impl->d_grad + d * P * Q, P, d_u, P, 0.0, d_v + d * N * Q, Q, data->queue);
          }
        }
      } else {
        // LCOV_EXCL_START
        return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV, CEED_EVAL_CURL not supported");
        // LCOV_EXCL_STOP
      }
    }
  } else {
    CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
    CeedInt num_t_col  = MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D);
    CeedInt grid       = CeedDivUpInt(num_elem, num_t_col);
    // Shared memory: Q weights plus a Q-sized output stage per thread column
    CeedInt shared_mem = Q * sizeof(CeedScalar) + num_t_col * Q * sizeof(CeedScalar);
    void   *args[]     = {&num_elem, &impl->d_q_weight, &d_v, &Q};

    CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, Q, num_t_col, 1, shared_mem, args));
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  return CEED_ERROR_SUCCESS;
}
430 
431 //------------------------------------------------------------------------------
432 // Destroy tensor basis
433 //------------------------------------------------------------------------------
434 static int CeedBasisDestroy_Magma(CeedBasis basis) {
435   Ceed             ceed;
436   CeedBasis_Magma *impl;
437 
438   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
439   CeedCallBackend(CeedBasisGetData(basis, &impl));
440 #ifdef CEED_MAGMA_USE_HIP
441   CeedCallHip(ceed, hipModuleUnload(impl->module));
442 #else
443   CeedCallCuda(ceed, cuModuleUnload(impl->module));
444 #endif
445   CeedCallBackend(magma_free(impl->d_interp_1d));
446   CeedCallBackend(magma_free(impl->d_grad_1d));
447   CeedCallBackend(magma_free(impl->d_q_weight_1d));
448   CeedCallBackend(CeedFree(&impl));
449   return CEED_ERROR_SUCCESS;
450 }
451 
452 //------------------------------------------------------------------------------
453 // Destroy non-tensor basis
454 //------------------------------------------------------------------------------
455 static int CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
456   Ceed                      ceed;
457   CeedBasisNonTensor_Magma *impl;
458 
459   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
460   CeedCallBackend(CeedBasisGetData(basis, &impl));
461 #ifdef CEED_MAGMA_USE_HIP
462   CeedCallHip(ceed, hipModuleUnload(impl->module_weight));
463 #else
464   CeedCallCuda(ceed, cuModuleUnload(impl->module_weight));
465 #endif
466   for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
467     if (impl->module_interp[in]) {
468 #ifdef CEED_MAGMA_USE_HIP
469       CeedCallHip(ceed, hipModuleUnload(impl->module_interp[in]));
470 #else
471       CeedCallCuda(ceed, cuModuleUnload(impl->module_interp[in]));
472 #endif
473     }
474   }
475   CeedCallBackend(magma_free(impl->d_interp));
476   CeedCallBackend(magma_free(impl->d_grad));
477   CeedCallBackend(magma_free(impl->d_q_weight));
478   CeedCallBackend(CeedFree(&impl));
479   return CEED_ERROR_SUCCESS;
480 }
481 
482 //------------------------------------------------------------------------------
483 // Create tensor
484 //------------------------------------------------------------------------------
485 int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
486                                   const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
487   Ceed             ceed, ceed_delegate;
488   Ceed_Magma      *data;
489   char            *interp_kernel_path, *grad_kernel_path, *weight_kernel_path, *basis_kernel_source;
490   CeedInt          num_comp;
491   CeedBasis_Magma *impl;
492 
493   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
494   CeedCallBackend(CeedGetData(ceed, &data));
495   CeedCallBackend(CeedCalloc(1, &impl));
496 
497   // Copy basis data to GPU
498   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight_1d, Q_1d * sizeof(q_weight_1d[0])));
499   magma_setvector(Q_1d, sizeof(q_weight_1d[0]), q_weight_1d, 1, impl->d_q_weight_1d, 1, data->queue);
500   CeedCallBackend(magma_malloc((void **)&impl->d_interp_1d, Q_1d * P_1d * sizeof(interp_1d[0])));
501   magma_setvector(Q_1d * P_1d, sizeof(interp_1d[0]), interp_1d, 1, impl->d_interp_1d, 1, data->queue);
502   CeedCallBackend(magma_malloc((void **)&impl->d_grad_1d, Q_1d * P_1d * sizeof(grad_1d[0])));
503   magma_setvector(Q_1d * P_1d, sizeof(grad_1d[0]), grad_1d, 1, impl->d_grad_1d, 1, data->queue);
504 
505   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
506   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
507 
508   // Compile kernels
509   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
510   {
511     char   *interp_kernel_name_base = "ceed/jit-source/magma/magma-basis-interp";
512     CeedInt interp_kernel_name_len  = strlen(interp_kernel_name_base) + 6;
513     char    interp_kernel_name[interp_kernel_name_len];
514 
515     snprintf(interp_kernel_name, interp_kernel_name_len, "%s-%" CeedInt_FMT "d.h", interp_kernel_name_base, dim);
516     CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_kernel_name, &interp_kernel_path));
517   }
518   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
519   CeedCallBackend(CeedLoadSourceToBuffer(ceed, interp_kernel_path, &basis_kernel_source));
520   {
521     char   *grad_kernel_name_base = "ceed/jit-source/magma/magma-basis-grad";
522     CeedInt grad_kernel_name_len  = strlen(grad_kernel_name_base) + 6;
523     char    grad_kernel_name[grad_kernel_name_len];
524 
525     snprintf(grad_kernel_name, grad_kernel_name_len, "%s-%" CeedInt_FMT "d.h", grad_kernel_name_base, dim);
526     CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_kernel_name, &grad_kernel_path));
527   }
528   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &basis_kernel_source));
529   {
530     char   *weight_kernel_name_base = "ceed/jit-source/magma/magma-basis-weight";
531     CeedInt weight_kernel_name_len  = strlen(weight_kernel_name_base) + 6;
532     char    weight_kernel_name[weight_kernel_name_len];
533 
534     snprintf(weight_kernel_name, weight_kernel_name_len, "%s-%" CeedInt_FMT "d.h", weight_kernel_name_base, dim);
535     CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_kernel_name, &weight_kernel_path));
536   }
537   CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &basis_kernel_source));
538   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
539   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module, 5, "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_P",
540                                    P_1d, "BASIS_Q", Q_1d, "BASIS_MAX_P_Q", CeedIntMax(P_1d, Q_1d)));
541   switch (dim) {
542     case 1:
543       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->Interp));
544       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->InterpTranspose));
545       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->Grad));
546       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->GradTranspose));
547       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->Weight));
548       break;
549     case 2:
550       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->Interp));
551       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->InterpTranspose));
552       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->Grad));
553       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->GradTranspose));
554       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->Weight));
555       break;
556     case 3:
557       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->Interp));
558       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->InterpTranspose));
559       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->Grad));
560       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->GradTranspose));
561       CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->Weight));
562       break;
563   }
564   CeedCallBackend(CeedFree(&interp_kernel_path));
565   CeedCallBackend(CeedFree(&grad_kernel_path));
566   CeedCallBackend(CeedFree(&weight_kernel_path));
567   CeedCallBackend(CeedFree(&basis_kernel_source));
568 
569   CeedCallBackend(CeedBasisSetData(basis, impl));
570 
571   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
572   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
573   return CEED_ERROR_SUCCESS;
574 }
575 
576 //------------------------------------------------------------------------------
577 // Create non-tensor H^1
578 //------------------------------------------------------------------------------
579 int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
580                             const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
581   Ceed                      ceed, ceed_delegate;
582   Ceed_Magma               *data;
583   char                     *weight_kernel_path, *basis_kernel_source;
584   CeedBasisNonTensor_Magma *impl;
585 
586   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
587   CeedCallBackend(CeedGetData(ceed, &data));
588   CeedCallBackend(CeedCalloc(1, &impl));
589 
590   // Copy basis data to GPU
591   CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
592   magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
593   CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * sizeof(interp[0])));
594   magma_setvector(num_qpts * num_nodes, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
595   CeedCallBackend(magma_malloc((void **)&impl->d_grad, num_qpts * num_nodes * dim * sizeof(grad[0])));
596   magma_setvector(num_qpts * num_nodes * dim, sizeof(grad[0]), grad, 1, impl->d_grad, 1, data->queue);
597 
598   // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
599   CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));
600 
601   // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
602   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
603   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
604   CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
605   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
606   CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_weight, 1, "BASIS_Q", num_qpts));
607   CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_weight, "magma_weight_nontensor", &impl->Weight));
608   CeedCallBackend(CeedFree(&weight_kernel_path));
609   CeedCallBackend(CeedFree(&basis_kernel_source));
610 
611   CeedCallBackend(CeedBasisSetData(basis, impl));
612 
613   // Register backend functions
614   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
615   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
616   return CEED_ERROR_SUCCESS;
617 }
618 
619 //------------------------------------------------------------------------------
620