xref: /libCEED/backends/hip-shared/ceed-hip-shared-basis.c (revision a24d84eaf50532bd6ddb3309c91171c35669c827)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <string.h>
14 #include <hip/hip_runtime.h>
15 
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-shared.h"
19 
20 //------------------------------------------------------------------------------
21 // Compute a block size based on required minimum threads
22 //------------------------------------------------------------------------------
23 static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
24   CeedInt maxSize     = 1024;  // Max total threads per block
25   CeedInt currentSize = 64;    // Start with one group
26 
27   while (currentSize < maxSize) {
28     if (currentSize > required) break;
29     else currentSize = currentSize * 2;
30   }
31   return currentSize;
32 }
33 
34 //------------------------------------------------------------------------------
35 // Compute required thread block sizes for basis kernels given P, Q, dim, and
36 // num_comp (num_comp not currently used, but may be again in other basis
37 // parallelization options)
38 //------------------------------------------------------------------------------
39 static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P_1d, const CeedInt Q_1d, const CeedInt num_comp, CeedInt *block_sizes) {
40   // Note that this will use the same block sizes for all dimensions when compiling,
41   // but as each basis object is defined for a particular dimension, we will never
42   // call any kernels except the ones for the dimension for which we have computed the
43   // block sizes.
44   const CeedInt thread_1d = CeedIntMax(P_1d, Q_1d);
45 
46   switch (dim) {
47     case 1: {
48       // Interp kernels:
49       block_sizes[0] = 256;
50 
51       // Grad kernels:
52       block_sizes[1] = 256;
53 
54       // Weight kernels:
55       block_sizes[2] = 256;
56     } break;
57     case 2: {
58       // Interp kernels:
59       CeedInt required = thread_1d * thread_1d;
60 
61       block_sizes[0] = CeedIntMax(256, ComputeBlockSizeFromRequirement(required));
62 
63       // Grad kernels: currently use same required minimum threads
64       block_sizes[1] = CeedIntMax(256, ComputeBlockSizeFromRequirement(required));
65 
66       // Weight kernels:
67       required       = CeedIntMax(64, Q_1d * Q_1d);
68       block_sizes[2] = CeedIntMax(256, ComputeBlockSizeFromRequirement(required));
69 
70     } break;
71     case 3: {
72       // Interp kernels:
73       CeedInt required = thread_1d * thread_1d;
74 
75       block_sizes[0] = CeedIntMax(256, ComputeBlockSizeFromRequirement(required));
76 
77       // Grad kernels: currently use same required minimum threads
78       block_sizes[1] = CeedIntMax(256, ComputeBlockSizeFromRequirement(required));
79 
80       // Weight kernels:
81       required       = Q_1d * Q_1d * Q_1d;
82       block_sizes[2] = CeedIntMax(256, ComputeBlockSizeFromRequirement(required));
83     }
84   }
85   return CEED_ERROR_SUCCESS;
86 }
87 
88 //------------------------------------------------------------------------------
89 // Apply basis
90 //------------------------------------------------------------------------------
91 static int CeedBasisApplyTensorCore_Hip_shared(CeedBasis basis, bool apply_add, const CeedInt num_elem, CeedTransposeMode t_mode,
92                                                CeedEvalMode eval_mode, CeedVector u, CeedVector v) {
93   Ceed                  ceed;
94   Ceed_Hip             *ceed_Hip;
95   CeedInt               dim, num_comp;
96   const CeedScalar     *d_u;
97   CeedScalar           *d_v;
98   CeedBasis_Hip_shared *data;
99 
100   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
101   CeedCallBackend(CeedGetData(ceed, &ceed_Hip));
102   CeedCallBackend(CeedBasisGetData(basis, &data));
103   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
104   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
105 
106   // Get read/write access to u, v
107   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
108   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
109   if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
110   else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
111 
112   // Apply basis operation
113   switch (eval_mode) {
114     case CEED_EVAL_INTERP: {
115       CeedInt P_1d, Q_1d;
116       CeedInt block_size = data->block_sizes[0];
117 
118       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
119       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
120       CeedInt thread_1d     = CeedIntMax(Q_1d, P_1d);
121       void   *interp_args[] = {(void *)&num_elem, &data->d_interp_1d, &d_u, &d_v};
122 
123       if (dim == 1) {
124         CeedInt elems_per_block = 64 * thread_1d > 256 ? 256 / thread_1d : 64;
125         elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
126         CeedInt grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
127         CeedInt shared_mem      = elems_per_block * thread_1d * sizeof(CeedScalar);
128 
129         if (t_mode == CEED_TRANSPOSE) {
130           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread_1d, 1,
131                                                      elems_per_block, shared_mem, interp_args));
132         } else {
133           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Interp, grid, thread_1d, 1, elems_per_block, shared_mem, interp_args));
134         }
135       } else if (dim == 2) {
136         // Check if required threads is small enough to do multiple elems
137         const CeedInt elems_per_block = CeedIntMax(block_size / (thread_1d * thread_1d), 1);
138         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
139         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
140 
141         if (t_mode == CEED_TRANSPOSE) {
142           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread_1d, thread_1d,
143                                                      elems_per_block, shared_mem, interp_args));
144         } else {
145           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Interp, grid, thread_1d, thread_1d, elems_per_block, shared_mem, interp_args));
146         }
147       } else if (dim == 3) {
148         const CeedInt elems_per_block = CeedIntMax(block_size / (thread_1d * thread_1d), 1);
149         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
150         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
151 
152         if (t_mode == CEED_TRANSPOSE) {
153           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread_1d, thread_1d,
154                                                      elems_per_block, shared_mem, interp_args));
155         } else {
156           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Interp, grid, thread_1d, thread_1d, elems_per_block, shared_mem, interp_args));
157         }
158       }
159     } break;
160     case CEED_EVAL_GRAD: {
161       CeedInt P_1d, Q_1d;
162       CeedInt block_size = data->block_sizes[1];
163 
164       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
165       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
166       CeedInt     thread_1d = CeedIntMax(Q_1d, P_1d);
167       CeedScalar *d_grad_1d = data->d_grad_1d;
168 
169       if (data->d_collo_grad_1d) {
170         d_grad_1d = data->d_collo_grad_1d;
171       }
172       void *grad_args[] = {(void *)&num_elem, &data->d_interp_1d, &d_grad_1d, &d_u, &d_v};
173 
174       if (dim == 1) {
175         CeedInt elems_per_block = 64 * thread_1d > 256 ? 256 / thread_1d : 64;
176         elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
177         CeedInt grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
178         CeedInt shared_mem      = elems_per_block * thread_1d * sizeof(CeedScalar);
179 
180         if (t_mode == CEED_TRANSPOSE) {
181           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread_1d, 1,
182                                                      elems_per_block, shared_mem, grad_args));
183         } else {
184           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Grad, grid, thread_1d, 1, elems_per_block, shared_mem, grad_args));
185         }
186       } else if (dim == 2) {
187         // Check if required threads is small enough to do multiple elems
188         const CeedInt elems_per_block = CeedIntMax(block_size / (thread_1d * thread_1d), 1);
189         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
190         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
191 
192         if (t_mode == CEED_TRANSPOSE) {
193           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread_1d, thread_1d,
194                                                      elems_per_block, shared_mem, grad_args));
195         } else {
196           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Grad, grid, thread_1d, thread_1d, elems_per_block, shared_mem, grad_args));
197         }
198       } else if (dim == 3) {
199         const CeedInt elems_per_block = CeedIntMax(block_size / (thread_1d * thread_1d), 1);
200         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
201         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
202 
203         if (t_mode == CEED_TRANSPOSE) {
204           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread_1d, thread_1d,
205                                                      elems_per_block, shared_mem, grad_args));
206         } else {
207           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Grad, grid, thread_1d, thread_1d, elems_per_block, shared_mem, grad_args));
208         }
209       }
210     } break;
211     case CEED_EVAL_WEIGHT: {
212       CeedInt Q_1d;
213       CeedInt block_size = data->block_sizes[2];
214 
215       CeedCheck(data->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weights_1d not set", CeedEvalModes[eval_mode]);
216       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
217       void *weight_args[] = {(void *)&num_elem, (void *)&data->d_q_weight_1d, &d_v};
218 
219       if (dim == 1) {
220         const CeedInt opt_elems       = block_size / Q_1d;
221         const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
222         const CeedInt grid_size       = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
223 
224         CeedCallBackend(CeedRunKernelDim_Hip(ceed, data->Weight, grid_size, Q_1d, elems_per_block, 1, weight_args));
225       } else if (dim == 2) {
226         const CeedInt opt_elems       = block_size / (Q_1d * Q_1d);
227         const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
228         const CeedInt grid_size       = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
229 
230         CeedCallBackend(CeedRunKernelDim_Hip(ceed, data->Weight, grid_size, Q_1d, Q_1d, elems_per_block, weight_args));
231       } else if (dim == 3) {
232         const CeedInt opt_elems       = block_size / (Q_1d * Q_1d);
233         const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
234         const CeedInt grid_size       = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
235 
236         CeedCallBackend(CeedRunKernelDim_Hip(ceed, data->Weight, grid_size, Q_1d, Q_1d, elems_per_block, weight_args));
237       }
238     } break;
239     case CEED_EVAL_NONE: /* handled separately below */
240       break;
241     // LCOV_EXCL_START
242     case CEED_EVAL_DIV:
243     case CEED_EVAL_CURL:
244       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
245       // LCOV_EXCL_STOP
246   }
247 
248   // Restore vectors, cover CEED_EVAL_NONE
249   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
250   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
251   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
252   CeedCallBackend(CeedDestroy(&ceed));
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
257                                     CeedVector v) {
258   CeedCallBackend(CeedBasisApplyTensorCore_Hip_shared(basis, false, num_elem, t_mode, eval_mode, u, v));
259   return CEED_ERROR_SUCCESS;
260 }
261 
262 int CeedBasisApplyAddTensor_Hip_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
263                                        CeedVector v) {
264   CeedCallBackend(CeedBasisApplyTensorCore_Hip_shared(basis, true, num_elem, t_mode, eval_mode, u, v));
265   return CEED_ERROR_SUCCESS;
266 }
267 
268 //------------------------------------------------------------------------------
269 // Basis apply - tensor AtPoints
270 //------------------------------------------------------------------------------
271 static int CeedBasisApplyAtPointsCore_Hip_shared(CeedBasis basis, bool apply_add, const CeedInt num_elem, const CeedInt *num_points,
272                                                  CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
273   Ceed                  ceed;
274   CeedInt               Q_1d, dim, max_num_points = num_points[0];
275   const CeedInt         is_transpose = t_mode == CEED_TRANSPOSE;
276   const CeedScalar     *d_x, *d_u;
277   CeedScalar           *d_v;
278   CeedBasis_Hip_shared *data;
279 
280   CeedCallBackend(CeedBasisGetData(basis, &data));
281   CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
282   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
283 
284   // Weight handled separately
285   if (eval_mode == CEED_EVAL_WEIGHT) {
286     CeedCallBackend(CeedVectorSetValue(v, 1.0));
287     return CEED_ERROR_SUCCESS;
288   }
289 
290   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
291 
292   // Check padded to uniform number of points per elem
293   for (CeedInt i = 1; i < num_elem; i++) max_num_points = CeedIntMax(max_num_points, num_points[i]);
294   {
295     CeedInt  num_comp, q_comp;
296     CeedSize len, len_required;
297 
298     CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
299     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
300     CeedCallBackend(CeedVectorGetLength(is_transpose ? u : v, &len));
301     len_required = (CeedSize)num_comp * (CeedSize)q_comp * (CeedSize)num_elem * (CeedSize)max_num_points;
302     CeedCheck(len >= len_required, ceed, CEED_ERROR_BACKEND,
303               "Vector at points must be padded to the same number of points in each element for BasisApplyAtPoints on GPU backends."
304               " Found %" CeedSize_FMT ", Required %" CeedSize_FMT,
305               len, len_required);
306   }
307 
308   // Move num_points array to device
309   if (is_transpose) {
310     const CeedInt num_bytes = num_elem * sizeof(CeedInt);
311 
312     if (num_elem != data->num_elem_at_points) {
313       data->num_elem_at_points = num_elem;
314 
315       if (data->d_points_per_elem) CeedCallHip(ceed, hipFree(data->d_points_per_elem));
316       CeedCallHip(ceed, hipMalloc((void **)&data->d_points_per_elem, num_bytes));
317       CeedCallBackend(CeedFree(&data->h_points_per_elem));
318       CeedCallBackend(CeedCalloc(num_elem, &data->h_points_per_elem));
319     }
320     if (memcmp(data->h_points_per_elem, num_points, num_bytes)) {
321       memcpy(data->h_points_per_elem, num_points, num_bytes);
322       CeedCallHip(ceed, hipMemcpy(data->d_points_per_elem, num_points, num_bytes, hipMemcpyHostToDevice));
323     }
324   }
325 
326   // Build kernels if needed
327   if (data->num_points != max_num_points) {
328     CeedInt P_1d;
329 
330     CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
331     data->num_points = max_num_points;
332 
333     // -- Create interp matrix to Chebyshev coefficients
334     if (!data->d_chebyshev_interp_1d) {
335       CeedSize    interp_bytes;
336       CeedScalar *chebyshev_interp_1d;
337 
338       interp_bytes = P_1d * Q_1d * sizeof(CeedScalar);
339       CeedCallBackend(CeedCalloc(P_1d * Q_1d, &chebyshev_interp_1d));
340       CeedCallBackend(CeedBasisGetChebyshevInterp1D(basis, chebyshev_interp_1d));
341       CeedCallHip(ceed, hipMalloc((void **)&data->d_chebyshev_interp_1d, interp_bytes));
342       CeedCallHip(ceed, hipMemcpy(data->d_chebyshev_interp_1d, chebyshev_interp_1d, interp_bytes, hipMemcpyHostToDevice));
343       CeedCallBackend(CeedFree(&chebyshev_interp_1d));
344     }
345 
346     // -- Compile kernels
347     const char basis_kernel_source[] = "// AtPoints basis source\n#include <ceed/jit-source/hip/hip-shared-basis-tensor-at-points.h>\n";
348     CeedInt    num_comp;
349 
350     if (data->moduleAtPoints) CeedCallHip(ceed, hipModuleUnload(data->moduleAtPoints));
351     CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
352     CeedCallBackend(CeedCompile_Hip(ceed, basis_kernel_source, &data->moduleAtPoints, 9, "BASIS_Q_1D", Q_1d, "BASIS_P_1D", P_1d, "T_1D",
353                                     CeedIntMax(Q_1d, P_1d), "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_NUM_NODES", CeedIntPow(P_1d, dim),
354                                     "BASIS_NUM_QPTS", CeedIntPow(Q_1d, dim), "BASIS_NUM_PTS", max_num_points, "BASIS_INTERP_BLOCK_SIZE",
355                                     data->block_sizes[0]));
356     CeedCallBackend(CeedGetKernel_Hip(ceed, data->moduleAtPoints, "InterpAtPoints", &data->InterpAtPoints));
357     CeedCallBackend(CeedGetKernel_Hip(ceed, data->moduleAtPoints, "InterpTransposeAtPoints", &data->InterpTransposeAtPoints));
358     CeedCallBackend(CeedGetKernel_Hip(ceed, data->moduleAtPoints, "GradAtPoints", &data->GradAtPoints));
359     CeedCallBackend(CeedGetKernel_Hip(ceed, data->moduleAtPoints, "GradTransposeAtPoints", &data->GradTransposeAtPoints));
360   }
361 
362   // Get read/write access to u, v
363   CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x));
364   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
365   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
366   if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
367   else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
368 
369   // Clear v for transpose operation
370   if (is_transpose && !apply_add) {
371     CeedInt  num_comp, q_comp, num_nodes;
372     CeedSize length;
373 
374     CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
375     CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp));
376     CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
377     length =
378         (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp));
379     CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar)));
380   }
381 
382   // Basis action
383   switch (eval_mode) {
384     case CEED_EVAL_INTERP: {
385       CeedInt P_1d, Q_1d;
386       CeedInt block_size = data->block_sizes[0];
387 
388       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
389       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
390       CeedInt thread_1d     = CeedIntMax(Q_1d, P_1d);
391       void   *interp_args[] = {(void *)&num_elem, &data->d_chebyshev_interp_1d, &data->d_points_per_elem, &d_x, &d_u, &d_v};
392 
393       if (dim == 1) {
394         CeedInt elems_per_block = 64 * thread_1d > 256 ? 256 / thread_1d : 64;
395         elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
396         CeedInt grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
397         CeedInt shared_mem      = elems_per_block * thread_1d * sizeof(CeedScalar);
398 
399         CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, is_transpose ? data->InterpTransposeAtPoints : data->InterpAtPoints, grid, thread_1d, 1,
400                                                    elems_per_block, shared_mem, interp_args));
401       } else if (dim == 2) {
402         // Check if required threads is small enough to do multiple elems
403         const CeedInt elems_per_block = CeedIntMax(block_size / (thread_1d * thread_1d), 1);
404         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
405         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
406 
407         CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, is_transpose ? data->InterpTransposeAtPoints : data->InterpAtPoints, grid, thread_1d,
408                                                    thread_1d, elems_per_block, shared_mem, interp_args));
409       } else if (dim == 3) {
410         const CeedInt elems_per_block = 1;
411         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
412         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
413 
414         CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, is_transpose ? data->InterpTransposeAtPoints : data->InterpAtPoints, grid, thread_1d,
415                                                    thread_1d, elems_per_block, shared_mem, interp_args));
416       }
417     } break;
418     case CEED_EVAL_GRAD: {
419       CeedInt P_1d, Q_1d;
420       CeedInt block_size = data->block_sizes[0];
421 
422       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
423       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
424       CeedInt thread_1d   = CeedIntMax(Q_1d, P_1d);
425       void   *grad_args[] = {(void *)&num_elem, &data->d_chebyshev_interp_1d, &data->d_points_per_elem, &d_x, &d_u, &d_v};
426 
427       if (dim == 1) {
428         CeedInt elems_per_block = 64 * thread_1d > 256 ? 256 / thread_1d : 64;
429         elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
430         CeedInt grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
431         CeedInt shared_mem      = elems_per_block * thread_1d * sizeof(CeedScalar);
432 
433         CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, is_transpose ? data->GradTransposeAtPoints : data->GradAtPoints, grid, thread_1d, 1,
434                                                    elems_per_block, shared_mem, grad_args));
435       } else if (dim == 2) {
436         // Check if required threads is small enough to do multiple elems
437         const CeedInt elems_per_block = CeedIntMax(block_size / (thread_1d * thread_1d), 1);
438         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
439         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
440 
441         CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, is_transpose ? data->GradTransposeAtPoints : data->GradAtPoints, grid, thread_1d, thread_1d,
442                                                    elems_per_block, shared_mem, grad_args));
443       } else if (dim == 3) {
444         const CeedInt elems_per_block = 1;
445         CeedInt       grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
446         CeedInt       shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
447 
448         CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, is_transpose ? data->GradTransposeAtPoints : data->GradAtPoints, grid, thread_1d, thread_1d,
449                                                    elems_per_block, shared_mem, grad_args));
450       }
451     } break;
452     case CEED_EVAL_WEIGHT:
453     case CEED_EVAL_NONE: /* handled separately below */
454       break;
455     // LCOV_EXCL_START
456     case CEED_EVAL_DIV:
457     case CEED_EVAL_CURL:
458       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
459       // LCOV_EXCL_STOP
460   }
461 
462   // Restore vectors, cover CEED_EVAL_NONE
463   CeedCallBackend(CeedVectorRestoreArrayRead(x_ref, &d_x));
464   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
465   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
466   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
467   return CEED_ERROR_SUCCESS;
468 }
469 
470 static int CeedBasisApplyAtPoints_Hip_shared(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode,
471                                              CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
472   CeedCallBackend(CeedBasisApplyAtPointsCore_Hip_shared(basis, false, num_elem, num_points, t_mode, eval_mode, x_ref, u, v));
473   return CEED_ERROR_SUCCESS;
474 }
475 
476 static int CeedBasisApplyAddAtPoints_Hip_shared(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode,
477                                                 CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
478   CeedCallBackend(CeedBasisApplyAtPointsCore_Hip_shared(basis, true, num_elem, num_points, t_mode, eval_mode, x_ref, u, v));
479   return CEED_ERROR_SUCCESS;
480 }
481 
482 //------------------------------------------------------------------------------
483 // Apply basis
484 //------------------------------------------------------------------------------
485 static int CeedBasisApplyNonTensorCore_Hip_shared(CeedBasis basis, bool apply_add, const CeedInt num_elem, CeedTransposeMode t_mode,
486                                                   CeedEvalMode eval_mode, CeedVector u, CeedVector v) {
487   Ceed                  ceed;
488   Ceed_Hip             *ceed_Hip;
489   CeedInt               dim, num_comp;
490   const CeedScalar     *d_u;
491   CeedScalar           *d_v;
492   CeedBasis_Hip_shared *data;
493 
494   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
495   CeedCallBackend(CeedGetData(ceed, &ceed_Hip));
496   CeedCallBackend(CeedBasisGetData(basis, &data));
497   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
498   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
499 
500   // Get read/write access to u, v
501   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
502   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
503   if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
504   else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
505 
506   // Apply basis operation
507   switch (eval_mode) {
508     case CEED_EVAL_INTERP: {
509       CeedInt P, Q;
510 
511       CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
512       CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &Q));
513       CeedInt thread        = CeedIntMax(Q, P);
514       void   *interp_args[] = {(void *)&num_elem, &data->d_interp_1d, &d_u, &d_v};
515 
516       {
517         CeedInt elems_per_block = 64 * thread > 256 ? 256 / thread : 64;
518         elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
519         CeedInt grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
520         CeedInt shared_mem      = elems_per_block * thread * sizeof(CeedScalar);
521 
522         if (t_mode == CEED_TRANSPOSE) {
523           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread, 1,
524                                                      elems_per_block, shared_mem, interp_args));
525         } else {
526           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Interp, grid, thread, 1, elems_per_block, shared_mem, interp_args));
527         }
528       }
529     } break;
530     case CEED_EVAL_GRAD: {
531       CeedInt P, Q;
532 
533       CeedCallBackend(CeedBasisGetNumNodes(basis, &P));
534       CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &Q));
535       CeedInt thread      = CeedIntMax(Q, P);
536       void   *grad_args[] = {(void *)&num_elem, &data->d_interp_1d, &data->d_grad_1d, &d_u, &d_v};
537 
538       {
539         CeedInt elems_per_block = 64 * thread > 256 ? 256 / thread : 64;
540         elems_per_block         = elems_per_block > 0 ? elems_per_block : 1;
541         CeedInt grid            = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
542         CeedInt shared_mem      = elems_per_block * thread * sizeof(CeedScalar);
543 
544         if (t_mode == CEED_TRANSPOSE) {
545           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread, 1, elems_per_block,
546                                                      shared_mem, grad_args));
547         } else {
548           CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, data->Grad, grid, thread, 1, elems_per_block, shared_mem, grad_args));
549         }
550       }
551     } break;
552     case CEED_EVAL_WEIGHT: {
553       CeedInt Q;
554       CeedInt block_size = data->block_sizes[2];
555 
556       CeedCheck(data->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weights_1d not set", CeedEvalModes[eval_mode]);
557       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q));
558       void *weight_args[] = {(void *)&num_elem, (void *)&data->d_q_weight_1d, &d_v};
559 
560       {
561         const CeedInt opt_elems       = block_size / Q;
562         const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
563         const CeedInt grid_size       = num_elem / elems_per_block + (num_elem % elems_per_block > 0);
564 
565         CeedCallBackend(CeedRunKernelDim_Hip(ceed, data->Weight, grid_size, Q, elems_per_block, 1, weight_args));
566       }
567     } break;
568     case CEED_EVAL_NONE: /* handled separately below */
569       break;
570     // LCOV_EXCL_START
571     case CEED_EVAL_DIV:
572     case CEED_EVAL_CURL:
573       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
574       // LCOV_EXCL_STOP
575   }
576 
577   // Restore vectors, cover CEED_EVAL_NONE
578   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
579   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
580   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
581   CeedCallBackend(CeedDestroy(&ceed));
582   return CEED_ERROR_SUCCESS;
583 }
584 
585 int CeedBasisApplyNonTensor_Hip_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
586                                        CeedVector v) {
587   CeedCallBackend(CeedBasisApplyNonTensorCore_Hip_shared(basis, false, num_elem, t_mode, eval_mode, u, v));
588   return CEED_ERROR_SUCCESS;
589 }
590 
591 int CeedBasisApplyAddNonTensor_Hip_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
592                                           CeedVector v) {
593   CeedCallBackend(CeedBasisApplyNonTensorCore_Hip_shared(basis, true, num_elem, t_mode, eval_mode, u, v));
594   return CEED_ERROR_SUCCESS;
595 }
596 
597 //------------------------------------------------------------------------------
598 // Destroy basis
599 //------------------------------------------------------------------------------
600 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
601   Ceed                  ceed;
602   CeedBasis_Hip_shared *data;
603 
604   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
605   CeedCallBackend(CeedBasisGetData(basis, &data));
606   CeedCallHip(ceed, hipModuleUnload(data->module));
607   if (data->moduleAtPoints) CeedCallHip(ceed, hipModuleUnload(data->moduleAtPoints));
608   if (data->d_q_weight_1d) CeedCallHip(ceed, hipFree(data->d_q_weight_1d));
609   CeedCallBackend(CeedFree(&data->h_points_per_elem));
610   if (data->d_points_per_elem) CeedCallHip(ceed, hipFree(data->d_points_per_elem));
611   CeedCallHip(ceed, hipFree(data->d_interp_1d));
612   CeedCallHip(ceed, hipFree(data->d_grad_1d));
613   CeedCallHip(ceed, hipFree(data->d_collo_grad_1d));
614   CeedCallHip(ceed, hipFree(data->d_chebyshev_interp_1d));
615   CeedCallBackend(CeedFree(&data));
616   return CEED_ERROR_SUCCESS;
617 }
618 
619 //------------------------------------------------------------------------------
620 // Create tensor basis
621 //------------------------------------------------------------------------------
622 int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
623                                        const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
624   Ceed                  ceed;
625   CeedInt               num_comp;
626   const CeedInt         q_bytes      = Q_1d * sizeof(CeedScalar);
627   const CeedInt         interp_bytes = q_bytes * P_1d;
628   CeedBasis_Hip_shared *data;
629 
630   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
631   CeedCallBackend(CeedCalloc(1, &data));
632 
633   // Copy basis data to GPU
634   if (q_weight_1d) {
635     CeedCallHip(ceed, hipMalloc((void **)&data->d_q_weight_1d, q_bytes));
636     CeedCallHip(ceed, hipMemcpy(data->d_q_weight_1d, q_weight_1d, q_bytes, hipMemcpyHostToDevice));
637   }
638   CeedCallHip(ceed, hipMalloc((void **)&data->d_interp_1d, interp_bytes));
639   CeedCallHip(ceed, hipMemcpy(data->d_interp_1d, interp_1d, interp_bytes, hipMemcpyHostToDevice));
640   CeedCallHip(ceed, hipMalloc((void **)&data->d_grad_1d, interp_bytes));
641   CeedCallHip(ceed, hipMemcpy(data->d_grad_1d, grad_1d, interp_bytes, hipMemcpyHostToDevice));
642 
643   // Compute collocated gradient and copy to GPU
644   data->d_collo_grad_1d    = NULL;
645   bool has_collocated_grad = dim == 3 && Q_1d >= P_1d;
646 
647   if (has_collocated_grad) {
648     CeedScalar *collo_grad_1d;
649 
650     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d));
651     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d));
652     CeedCallHip(ceed, hipMalloc((void **)&data->d_collo_grad_1d, q_bytes * Q_1d));
653     CeedCallHip(ceed, hipMemcpy(data->d_collo_grad_1d, collo_grad_1d, q_bytes * Q_1d, hipMemcpyHostToDevice));
654     CeedCallBackend(CeedFree(&collo_grad_1d));
655   }
656 
657   // Set number of threads per block for basis kernels
658   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
659   CeedCallBackend(ComputeBasisThreadBlockSizes(dim, P_1d, Q_1d, num_comp, data->block_sizes));
660 
661   // Compile basis kernels
662   const char basis_kernel_source[] = "// Tensor basis source\n#include <ceed/jit-source/hip/hip-shared-basis-tensor.h>\n";
663 
664   CeedCallBackend(CeedCompile_Hip(ceed, basis_kernel_source, &data->module, 11, "BASIS_Q_1D", Q_1d, "BASIS_P_1D", P_1d, "T_1D",
665                                   CeedIntMax(Q_1d, P_1d), "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_NUM_NODES", CeedIntPow(P_1d, dim),
666                                   "BASIS_NUM_QPTS", CeedIntPow(Q_1d, dim), "BASIS_INTERP_BLOCK_SIZE", data->block_sizes[0], "BASIS_GRAD_BLOCK_SIZE",
667                                   data->block_sizes[1], "BASIS_WEIGHT_BLOCK_SIZE", data->block_sizes[2], "BASIS_HAS_COLLOCATED_GRAD",
668                                   has_collocated_grad));
669   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "Interp", &data->Interp));
670   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "InterpTranspose", &data->InterpTranspose));
671   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "InterpTransposeAdd", &data->InterpTransposeAdd));
672   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "Grad", &data->Grad));
673   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "GradTranspose", &data->GradTranspose));
674   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "GradTransposeAdd", &data->GradTransposeAdd));
675   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "Weight", &data->Weight));
676 
677   CeedCallBackend(CeedBasisSetData(basis, data));
678 
679   // Register backend functions
680   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Hip_shared));
681   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddTensor_Hip_shared));
682   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Hip_shared));
683   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAddAtPoints", CeedBasisApplyAddAtPoints_Hip_shared));
684   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Hip_shared));
685   CeedCallBackend(CeedDestroy(&ceed));
686   return CEED_ERROR_SUCCESS;
687 }
688 
689 //------------------------------------------------------------------------------
690 // Create non-tensor basis
691 //------------------------------------------------------------------------------
692 int CeedBasisCreateH1_Hip_shared(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
693                                  const CeedScalar *grad, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
694   Ceed                  ceed;
695   CeedInt               num_comp, q_comp_interp, q_comp_grad;
696   const CeedInt         q_bytes = num_qpts * sizeof(CeedScalar);
697   CeedBasis_Hip_shared *data;
698 
699   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
700   CeedCallBackend(CeedCalloc(1, &data));
701 
702   // Copy basis data to GPU
703   CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
704   CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad));
705   if (q_weight) {
706     CeedCallHip(ceed, hipMalloc((void **)&data->d_q_weight_1d, q_bytes));
707     CeedCallHip(ceed, hipMemcpy(data->d_q_weight_1d, q_weight, q_bytes, hipMemcpyHostToDevice));
708   }
709   if (interp) {
710     const CeedInt interp_bytes = q_bytes * num_nodes * q_comp_interp;
711 
712     CeedCallHip(ceed, hipMalloc((void **)&data->d_interp_1d, interp_bytes));
713     CeedCallHip(ceed, hipMemcpy(data->d_interp_1d, interp, interp_bytes, hipMemcpyHostToDevice));
714   }
715   if (grad) {
716     const CeedInt grad_bytes = q_bytes * num_nodes * q_comp_grad;
717 
718     CeedCallHip(ceed, hipMalloc((void **)&data->d_grad_1d, grad_bytes));
719     CeedCallHip(ceed, hipMemcpy(data->d_grad_1d, grad, grad_bytes, hipMemcpyHostToDevice));
720   }
721 
722   // Compile basis kernels
723   const char basis_kernel_source[] = "// Non-tensor basis source\n#include <ceed/jit-source/hip/hip-shared-basis-nontensor.h>\n";
724 
725   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
726   CeedCallBackend(CeedCompile_Hip(ceed, basis_kernel_source, &data->module, 5, "BASIS_Q", num_qpts, "BASIS_P", num_nodes, "T_1D",
727                                   CeedIntMax(num_qpts, num_nodes), "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp));
728   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "Interp", &data->Interp));
729   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "InterpTranspose", &data->InterpTranspose));
730   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "InterpTransposeAdd", &data->InterpTransposeAdd));
731   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "Grad", &data->Grad));
732   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "GradTranspose", &data->GradTranspose));
733   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "GradTransposeAdd", &data->GradTransposeAdd));
734   CeedCallBackend(CeedGetKernel_Hip(ceed, data->module, "Weight", &data->Weight));
735 
736   CeedCallBackend(CeedBasisSetData(basis, data));
737 
738   // Register backend functions
739   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Hip_shared));
740   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Hip_shared));
741   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Hip_shared));
742   CeedCallBackend(CeedDestroy(&ceed));
743   return CEED_ERROR_SUCCESS;
744 }
745 
746 //------------------------------------------------------------------------------
747