xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision d1931fc83dfaa61549375a0461a8efe2c16b442e)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <cuda.h>
12 #include <cuda_runtime.h>
13 #include <stdbool.h>
14 #include <stddef.h>
15 
16 #include "../cuda/ceed-cuda-common.h"
17 #include "../cuda/ceed-cuda-compile.h"
18 #include "ceed-cuda-shared.h"
19 
20 //------------------------------------------------------------------------------
// Device initialization
22 //------------------------------------------------------------------------------
// Load the 1D interpolation matrix for use by the basis kernels.
// NOTE(review): c_B/c_G appear to be cached device-side copies of the matrices
// (likely constant memory, given the c_ prefix) — confirm in the kernel source.
int CeedInit_CudaInterp(CeedScalar *d_B, CeedInt P_1d, CeedInt Q_1d, CeedScalar **c_B);
// Same as above, for the interpolation and gradient matrices together
int CeedInit_CudaGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P_1d, CeedInt Q_1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
// Same as above, for the interpolation and collocated gradient matrices
int CeedInit_CudaCollocatedGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P_1d, CeedInt Q_1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
26 
27 //------------------------------------------------------------------------------
28 // Apply basis
29 //------------------------------------------------------------------------------
30 static int CeedBasisApplyTensorCore_Cuda_shared(CeedBasis basis, bool apply_add, const CeedInt num_elem, CeedTransposeMode t_mode,
31                                                 CeedEvalMode eval_mode, CeedVector u, CeedVector v) {
32   Ceed                   ceed;
33   Ceed_Cuda             *ceed_Cuda;
34   CeedInt                dim, num_comp;
35   const CeedScalar      *d_u;
36   CeedScalar            *d_v;
37   CeedBasis_Cuda_shared *data;
38 
39   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
40   CeedCallBackend(CeedGetData(ceed, &ceed_Cuda));
41   CeedCallBackend(CeedBasisGetData(basis, &data));
42   CeedCallBackend(CeedBasisGetDimension(basis, &dim));
43   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
44 
45   // Get read/write access to u, v
46   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
47   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
48   if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
49   else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
50 
51   // Apply basis operation
52   switch (eval_mode) {
53     case CEED_EVAL_INTERP: {
54       CeedInt P_1d, Q_1d;
55 
56       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
57       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
58       CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
59 
60       CeedCallBackend(CeedInit_CudaInterp(data->d_interp_1d, P_1d, Q_1d, &data->c_B));
61       void *interp_args[] = {(void *)&num_elem, &data->c_B, &d_u, &d_v};
62 
63       if (dim == 1) {
64         CeedInt elems_per_block = CeedIntMin(ceed_Cuda->device_prop.maxThreadsDim[2], CeedIntMax(512 / thread_1d,
65                                                                                                  1));  // avoid >512 total threads
66         CeedInt grid            = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
67         CeedInt shared_mem      = elems_per_block * thread_1d * sizeof(CeedScalar);
68 
69         if (t_mode == CEED_TRANSPOSE) {
70           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread_1d, 1,
71                                                       elems_per_block, shared_mem, interp_args));
72         } else {
73           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, data->Interp, grid, thread_1d, 1, elems_per_block, shared_mem, interp_args));
74         }
75       } else if (dim == 2) {
76         const CeedInt opt_elems[7] = {0, 32, 8, 6, 4, 2, 8};
77         // elems_per_block must be at least 1
78         CeedInt elems_per_block = CeedIntMax(thread_1d < 7 ? opt_elems[thread_1d] / num_comp : 1, 1);
79         CeedInt grid            = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
80         CeedInt shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
81 
82         if (t_mode == CEED_TRANSPOSE) {
83           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread_1d, thread_1d,
84                                                       elems_per_block, shared_mem, interp_args));
85         } else {
86           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, data->Interp, grid, thread_1d, thread_1d, elems_per_block, shared_mem, interp_args));
87         }
88       } else if (dim == 3) {
89         CeedInt elems_per_block = 1;
90         CeedInt grid            = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
91         CeedInt shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
92 
93         if (t_mode == CEED_TRANSPOSE) {
94           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, apply_add ? data->InterpTransposeAdd : data->InterpTranspose, grid, thread_1d, thread_1d,
95                                                       elems_per_block, shared_mem, interp_args));
96         } else {
97           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, data->Interp, grid, thread_1d, thread_1d, elems_per_block, shared_mem, interp_args));
98         }
99       }
100     } break;
101     case CEED_EVAL_GRAD: {
102       CeedInt P_1d, Q_1d;
103 
104       CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
105       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
106       CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
107 
108       if (data->d_collo_grad_1d) {
109         CeedCallBackend(CeedInit_CudaCollocatedGrad(data->d_interp_1d, data->d_collo_grad_1d, P_1d, Q_1d, &data->c_B, &data->c_G));
110       } else {
111         CeedCallBackend(CeedInit_CudaGrad(data->d_interp_1d, data->d_grad_1d, P_1d, Q_1d, &data->c_B, &data->c_G));
112       }
113       void *grad_args[] = {(void *)&num_elem, &data->c_B, &data->c_G, &d_u, &d_v};
114       if (dim == 1) {
115         CeedInt elems_per_block = CeedIntMin(ceed_Cuda->device_prop.maxThreadsDim[2], CeedIntMax(512 / thread_1d,
116                                                                                                  1));  // avoid >512 total threads
117         CeedInt grid            = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
118         CeedInt shared_mem      = elems_per_block * thread_1d * sizeof(CeedScalar);
119 
120         if (t_mode == CEED_TRANSPOSE) {
121           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread_1d, 1,
122                                                       elems_per_block, shared_mem, grad_args));
123         } else {
124           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, data->Grad, grid, thread_1d, 1, elems_per_block, shared_mem, grad_args));
125         }
126       } else if (dim == 2) {
127         const CeedInt opt_elems[7] = {0, 32, 8, 6, 4, 2, 8};
128         // elems_per_block must be at least 1
129         CeedInt elems_per_block = CeedIntMax(thread_1d < 7 ? opt_elems[thread_1d] / num_comp : 1, 1);
130         CeedInt grid            = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
131         CeedInt shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
132 
133         if (t_mode == CEED_TRANSPOSE) {
134           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread_1d, thread_1d,
135                                                       elems_per_block, shared_mem, grad_args));
136         } else {
137           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, data->Grad, grid, thread_1d, thread_1d, elems_per_block, shared_mem, grad_args));
138         }
139       } else if (dim == 3) {
140         CeedInt elems_per_block = 1;
141         CeedInt grid            = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
142         CeedInt shared_mem      = elems_per_block * thread_1d * thread_1d * sizeof(CeedScalar);
143 
144         if (t_mode == CEED_TRANSPOSE) {
145           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, apply_add ? data->GradTransposeAdd : data->GradTranspose, grid, thread_1d, thread_1d,
146                                                       elems_per_block, shared_mem, grad_args));
147         } else {
148           CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, data->Grad, grid, thread_1d, thread_1d, elems_per_block, shared_mem, grad_args));
149         }
150       }
151     } break;
152     case CEED_EVAL_WEIGHT: {
153       CeedInt Q_1d;
154       CeedInt block_size = 32;
155 
156       CeedCheck(data->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weights_1d not set", CeedEvalModes[eval_mode]);
157       CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
158       void *weight_args[] = {(void *)&num_elem, (void *)&data->d_q_weight_1d, &d_v};
159       if (dim == 1) {
160         const CeedInt elems_per_block = block_size / Q_1d;
161         const CeedInt grid_size       = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
162 
163         CeedCallBackend(CeedRunKernelDim_Cuda(ceed, data->Weight, grid_size, Q_1d, elems_per_block, 1, weight_args));
164       } else if (dim == 2) {
165         const CeedInt opt_elems       = block_size / (Q_1d * Q_1d);
166         const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
167         const CeedInt grid_size       = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
168 
169         CeedCallBackend(CeedRunKernelDim_Cuda(ceed, data->Weight, grid_size, Q_1d, Q_1d, elems_per_block, weight_args));
170       } else if (dim == 3) {
171         const CeedInt opt_elems       = block_size / (Q_1d * Q_1d);
172         const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
173         const CeedInt grid_size       = num_elem / elems_per_block + ((num_elem / elems_per_block * elems_per_block < num_elem) ? 1 : 0);
174 
175         CeedCallBackend(CeedRunKernelDim_Cuda(ceed, data->Weight, grid_size, Q_1d, Q_1d, elems_per_block, weight_args));
176       }
177     } break;
178     case CEED_EVAL_NONE: /* handled separately below */
179       break;
180     // LCOV_EXCL_START
181     case CEED_EVAL_DIV:
182     case CEED_EVAL_CURL:
183       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
184       // LCOV_EXCL_STOP
185   }
186 
187   // Restore vectors, cover CEED_EVAL_NONE
188   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
189   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
190   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
191   return CEED_ERROR_SUCCESS;
192 }
193 
194 static int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
195                                             CeedVector v) {
196   CeedCallBackend(CeedBasisApplyTensorCore_Cuda_shared(basis, false, num_elem, t_mode, eval_mode, u, v));
197   return CEED_ERROR_SUCCESS;
198 }
199 
200 static int CeedBasisApplyAddTensor_Cuda_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
201                                                CeedVector u, CeedVector v) {
202   CeedCallBackend(CeedBasisApplyTensorCore_Cuda_shared(basis, true, num_elem, t_mode, eval_mode, u, v));
203   return CEED_ERROR_SUCCESS;
204 }
205 
206 //------------------------------------------------------------------------------
207 // Basis apply - tensor AtPoints
208 //------------------------------------------------------------------------------
// Apply a tensor basis at arbitrary reference-space points (shared core for
// Apply/ApplyAdd). Requires the same number of points in every element.
// Kernels are JIT-compiled lazily and rebuilt whenever the point count changes.
static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_add, const CeedInt num_elem, const CeedInt *num_points,
                                                  CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
  Ceed                   ceed;
  CeedInt                Q_1d, dim, max_num_points = num_points[0];
  const CeedInt          is_transpose   = t_mode == CEED_TRANSPOSE;
  const int              max_block_size = 32;
  const CeedScalar      *d_x, *d_u;
  CeedScalar            *d_v;
  CeedBasis_Cuda_shared *data;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedBasisGetData(basis, &data));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));

  // Check uniform number of points per elem
  for (CeedInt i = 1; i < num_elem; i++) {
    CeedCheck(max_num_points == num_points[i], ceed, CEED_ERROR_BACKEND,
              "BasisApplyAtPoints only supported for the same number of points in each element");
  }

  // Weight handled separately
  // (returns before any array access below, so no restore is needed on this path)
  if (eval_mode == CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorSetValue(v, 1.0));
    return CEED_ERROR_SUCCESS;
  }

  // Build kernels if needed
  // (num_points is baked into the compiled module, so a changed point count forces a rebuild)
  if (data->num_points != max_num_points) {
    CeedInt P_1d;

    CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
    data->num_points = max_num_points;

    // -- Create interp matrix to Chebyshev coefficients
    // (built once and cached on the device for all subsequent rebuilds)
    if (!data->d_chebyshev_interp_1d) {
      CeedSize    interp_bytes;
      CeedScalar *chebyshev_interp_1d;

      interp_bytes = P_1d * Q_1d * sizeof(CeedScalar);
      CeedCallBackend(CeedCalloc(P_1d * Q_1d, &chebyshev_interp_1d));
      CeedCallBackend(CeedBasisGetChebyshevInterp1D(basis, chebyshev_interp_1d));
      CeedCallCuda(ceed, cudaMalloc((void **)&data->d_chebyshev_interp_1d, interp_bytes));
      CeedCallCuda(ceed, cudaMemcpy(data->d_chebyshev_interp_1d, chebyshev_interp_1d, interp_bytes, cudaMemcpyHostToDevice));
      CeedCallBackend(CeedFree(&chebyshev_interp_1d));
    }

    // -- Compile kernels
    char       *basis_kernel_source;
    const char *basis_kernel_path;
    CeedInt     num_comp;

    // Drop the previously compiled module before replacing it
    if (data->moduleAtPoints) CeedCallCuda(ceed, cuModuleUnload(data->moduleAtPoints));
    CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
    CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-basis-tensor-at-points.h", &basis_kernel_path));
    CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
    CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &basis_kernel_source));
    CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
    CeedCallBackend(CeedCompile_Cuda(ceed, basis_kernel_source, &data->moduleAtPoints, 9, "BASIS_Q_1D", Q_1d, "BASIS_P_1D", P_1d, "BASIS_BUF_LEN",
                                     Q_1d * CeedIntPow(Q_1d > P_1d ? Q_1d : P_1d, dim - 1), "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp,
                                     "BASIS_NUM_NODES", CeedIntPow(P_1d, dim), "BASIS_NUM_QPTS", CeedIntPow(Q_1d, dim), "BASIS_NUM_PTS",
                                     max_num_points, "POINTS_BUFF_LEN", CeedIntPow(Q_1d, dim - 1)));
    CeedCallBackend(CeedGetKernel_Cuda(ceed, data->moduleAtPoints, "InterpAtPoints", &data->InterpAtPoints));
    CeedCallBackend(CeedGetKernel_Cuda(ceed, data->moduleAtPoints, "GradAtPoints", &data->GradAtPoints));
    CeedCallBackend(CeedFree(&basis_kernel_path));
    CeedCallBackend(CeedFree(&basis_kernel_source));
  }

  // Get read/write access to u, v
  CeedCallBackend(CeedVectorGetArrayRead(x_ref, CEED_MEM_DEVICE, &d_x));
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Clear v for transpose operation
  // (the transpose kernels accumulate, so plain Apply must start from zero)
  if (is_transpose && !apply_add) {
    CeedSize length;

    CeedCallBackend(CeedVectorGetLength(v, &length));
    CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar)));
  }

  // Basis action
  switch (eval_mode) {
    case CEED_EVAL_INTERP: {
      void         *interp_args[] = {(void *)&num_elem, (void *)&is_transpose, &data->d_chebyshev_interp_1d, &d_x, &d_u, &d_v};
      const CeedInt block_size    = CeedIntMin(CeedIntPow(Q_1d, dim), max_block_size);

      CeedCallBackend(CeedRunKernel_Cuda(ceed, data->InterpAtPoints, num_elem, block_size, interp_args));
    } break;
    case CEED_EVAL_GRAD: {
      void         *grad_args[] = {(void *)&num_elem, (void *)&is_transpose, &data->d_chebyshev_interp_1d, &d_x, &d_u, &d_v};
      const CeedInt block_size  = CeedIntMin(CeedIntPow(Q_1d, dim), max_block_size);

      CeedCallBackend(CeedRunKernel_Cuda(ceed, data->GradAtPoints, num_elem, block_size, grad_args));
    } break;
    case CEED_EVAL_WEIGHT:
    case CEED_EVAL_NONE: /* handled separately below */
      break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
      // LCOV_EXCL_STOP
  }

  // Restore vectors, cover CEED_EVAL_NONE
  CeedCallBackend(CeedVectorRestoreArrayRead(x_ref, &d_x));
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
  if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  return CEED_ERROR_SUCCESS;
}
323 
324 static int CeedBasisApplyAtPoints_Cuda_shared(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode,
325                                               CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
326   CeedCallBackend(CeedBasisApplyAtPointsCore_Cuda_shared(basis, false, num_elem, num_points, t_mode, eval_mode, x_ref, u, v));
327   return CEED_ERROR_SUCCESS;
328 }
329 
330 static int CeedBasisApplyAddAtPoints_Cuda_shared(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode,
331                                                  CeedEvalMode eval_mode, CeedVector x_ref, CeedVector u, CeedVector v) {
332   CeedCallBackend(CeedBasisApplyAtPointsCore_Cuda_shared(basis, true, num_elem, num_points, t_mode, eval_mode, x_ref, u, v));
333   return CEED_ERROR_SUCCESS;
334 }
335 
336 //------------------------------------------------------------------------------
337 // Destroy basis
338 //------------------------------------------------------------------------------
339 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
340   Ceed                   ceed;
341   CeedBasis_Cuda_shared *data;
342 
343   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
344   CeedCallBackend(CeedBasisGetData(basis, &data));
345   CeedCallCuda(ceed, cuModuleUnload(data->module));
346   if (data->moduleAtPoints) CeedCallCuda(ceed, cuModuleUnload(data->moduleAtPoints));
347   if (data->d_q_weight_1d) CeedCallCuda(ceed, cudaFree(data->d_q_weight_1d));
348   CeedCallCuda(ceed, cudaFree(data->d_interp_1d));
349   CeedCallCuda(ceed, cudaFree(data->d_grad_1d));
350   CeedCallCuda(ceed, cudaFree(data->d_collo_grad_1d));
351   CeedCallCuda(ceed, cudaFree(data->d_chebyshev_interp_1d));
352   CeedCallBackend(CeedFree(&data));
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
357 // Create tensor basis
358 //------------------------------------------------------------------------------
359 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
360                                         const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
361   Ceed                   ceed;
362   char                  *basis_kernel_source;
363   const char            *basis_kernel_path;
364   CeedInt                num_comp;
365   const CeedInt          q_bytes      = Q_1d * sizeof(CeedScalar);
366   const CeedInt          interp_bytes = q_bytes * P_1d;
367   CeedBasis_Cuda_shared *data;
368 
369   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
370   CeedCallBackend(CeedCalloc(1, &data));
371 
372   // Copy basis data to GPU
373   if (q_weight_1d) {
374     CeedCallCuda(ceed, cudaMalloc((void **)&data->d_q_weight_1d, q_bytes));
375     CeedCallCuda(ceed, cudaMemcpy(data->d_q_weight_1d, q_weight_1d, q_bytes, cudaMemcpyHostToDevice));
376   }
377   CeedCallCuda(ceed, cudaMalloc((void **)&data->d_interp_1d, interp_bytes));
378   CeedCallCuda(ceed, cudaMemcpy(data->d_interp_1d, interp_1d, interp_bytes, cudaMemcpyHostToDevice));
379   CeedCallCuda(ceed, cudaMalloc((void **)&data->d_grad_1d, interp_bytes));
380   CeedCallCuda(ceed, cudaMemcpy(data->d_grad_1d, grad_1d, interp_bytes, cudaMemcpyHostToDevice));
381 
382   // Compute collocated gradient and copy to GPU
383   data->d_collo_grad_1d    = NULL;
384   bool has_collocated_grad = dim == 3 && Q_1d >= P_1d;
385 
386   if (has_collocated_grad) {
387     CeedScalar *collo_grad_1d;
388 
389     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d));
390     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d));
391     CeedCallCuda(ceed, cudaMalloc((void **)&data->d_collo_grad_1d, q_bytes * Q_1d));
392     CeedCallCuda(ceed, cudaMemcpy(data->d_collo_grad_1d, collo_grad_1d, q_bytes * Q_1d, cudaMemcpyHostToDevice));
393     CeedCallBackend(CeedFree(&collo_grad_1d));
394   }
395 
396   // Compile basis kernels
397   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
398   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-shared-basis-tensor.h", &basis_kernel_path));
399   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
400   CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &basis_kernel_source));
401   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete -----\n");
402   CeedCallBackend(CeedCompile_Cuda(ceed, basis_kernel_source, &data->module, 8, "BASIS_Q_1D", Q_1d, "BASIS_P_1D", P_1d, "T_1D",
403                                    CeedIntMax(Q_1d, P_1d), "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_NUM_NODES", CeedIntPow(P_1d, dim),
404                                    "BASIS_NUM_QPTS", CeedIntPow(Q_1d, dim), "BASIS_HAS_COLLOCATED_GRAD", has_collocated_grad));
405   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "Interp", &data->Interp));
406   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "InterpTranspose", &data->InterpTranspose));
407   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "InterpTransposeAdd", &data->InterpTransposeAdd));
408   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "Grad", &data->Grad));
409   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "GradTranspose", &data->GradTranspose));
410   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "GradTransposeAdd", &data->GradTransposeAdd));
411   CeedCallBackend(CeedGetKernel_Cuda(ceed, data->module, "Weight", &data->Weight));
412   CeedCallBackend(CeedFree(&basis_kernel_path));
413   CeedCallBackend(CeedFree(&basis_kernel_source));
414 
415   CeedCallBackend(CeedBasisSetData(basis, data));
416 
417   // Register backend functions
418   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Cuda_shared));
419   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddTensor_Cuda_shared));
420   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Cuda_shared));
421   CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Cuda_shared));
422   return CEED_ERROR_SUCCESS;
423 }
424 
425 //------------------------------------------------------------------------------
426