xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 9e201c85545dd39529c090846df629a32c15659b)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3c532df63SYohann //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
5c532df63SYohann //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
7c532df63SYohann 
8ec3da8bcSJed Brown #include <ceed/ceed.h>
9ec3da8bcSJed Brown #include <ceed/backend.h>
10437930d1SJeremy L Thompson #include <ceed/jit-tools.h>
113d576824SJeremy L Thompson #include <cuda.h>
123d576824SJeremy L Thompson #include <cuda_runtime.h>
133d576824SJeremy L Thompson #include <stddef.h>
14c532df63SYohann #include "ceed-cuda-shared.h"
156d69246aSJeremy L Thompson #include "../cuda/ceed-cuda-compile.h"
16c532df63SYohann 
17c532df63SYohann 
18ab213215SJeremy L Thompson //------------------------------------------------------------------------------
19ab213215SJeremy L Thompson // Device initialization
20ab213215SJeremy L Thompson //------------------------------------------------------------------------------
21437930d1SJeremy L Thompson int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P_1d, CeedInt Q_1d,
22c532df63SYohann                        CeedScalar **c_B);
23*9e201c85SYohann int CeedCudaInitGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P_1d,
24*9e201c85SYohann                      CeedInt Q_1d, CeedScalar **c_B_ptr,
25*9e201c85SYohann                      CeedScalar **c_G_ptr);
26*9e201c85SYohann int CeedCudaInitCollocatedGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P_1d,
27437930d1SJeremy L Thompson                                CeedInt Q_1d, CeedScalar **c_B_ptr,
287f823360Sjeremylt                                CeedScalar **c_G_ptr);
29c532df63SYohann 
30ab213215SJeremy L Thompson //------------------------------------------------------------------------------
31ab213215SJeremy L Thompson // Apply basis
32ab213215SJeremy L Thompson //------------------------------------------------------------------------------
33437930d1SJeremy L Thompson int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt num_elem,
34437930d1SJeremy L Thompson                                      CeedTransposeMode t_mode,
35437930d1SJeremy L Thompson                                      CeedEvalMode eval_mode, CeedVector u,
367f823360Sjeremylt                                      CeedVector v) {
37c532df63SYohann   int ierr;
38c532df63SYohann   Ceed ceed;
39e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
406dbfb411Snbeams   Ceed_Cuda *ceed_Cuda;
41e15f9bd0SJeremy L Thompson   CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
42c532df63SYohann   CeedBasis_Cuda_shared *data;
43e15f9bd0SJeremy L Thompson   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
44437930d1SJeremy L Thompson   CeedInt dim, num_comp;
45e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
46437930d1SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr);
47c532df63SYohann 
48ab213215SJeremy L Thompson   // Read vectors
49c532df63SYohann   const CeedScalar *d_u;
50c532df63SYohann   CeedScalar *d_v;
51437930d1SJeremy L Thompson   if (eval_mode != CEED_EVAL_WEIGHT) {
52e15f9bd0SJeremy L Thompson     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
53c532df63SYohann   }
549c774eddSJeremy L Thompson   ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
55c532df63SYohann 
56ab213215SJeremy L Thompson   // Apply basis operation
57437930d1SJeremy L Thompson   switch (eval_mode) {
58ab213215SJeremy L Thompson   case CEED_EVAL_INTERP: {
59437930d1SJeremy L Thompson     CeedInt P_1d, Q_1d;
60437930d1SJeremy L Thompson     ierr = CeedBasisGetNumNodes1D(basis, &P_1d); CeedChkBackend(ierr);
61437930d1SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d); CeedChkBackend(ierr);
62437930d1SJeremy L Thompson     CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
63437930d1SJeremy L Thompson     ierr = CeedCudaInitInterp(data->d_interp_1d, P_1d, Q_1d, &data->c_B);
64e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
65*9e201c85SYohann     void *interp_args[] = {(void *) &num_elem, &data->c_B,
66ccf0fe6fSjeremylt                            &d_u, &d_v
67ccf0fe6fSjeremylt                           };
684d537eeaSYohann     if (dim == 1) {
69e6f67ff7SJed Brown       CeedInt elems_per_block = CeedIntMin(ceed_Cuda->device_prop.maxThreadsDim[2],
70e6f67ff7SJed Brown                                            CeedIntMax(512 / thread_1d,
71e6f67ff7SJed Brown                                                1)); // avoid >512 total threads
72437930d1SJeremy L Thompson       CeedInt grid = num_elem/elems_per_block +
73437930d1SJeremy L Thompson                      ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
74437930d1SJeremy L Thompson       CeedInt shared_mem = elems_per_block*thread_1d*sizeof(CeedScalar);
75*9e201c85SYohann       if (t_mode == CEED_TRANSPOSE) {
76*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->InterpTranspose, grid, thread_1d,
77*9e201c85SYohann                                           1,
78*9e201c85SYohann                                           elems_per_block, shared_mem,
79*9e201c85SYohann                                           interp_args); CeedChkBackend(ierr);
80*9e201c85SYohann       } else {
81437930d1SJeremy L Thompson         ierr = CeedRunKernelDimSharedCuda(ceed, data->Interp, grid, thread_1d, 1,
82437930d1SJeremy L Thompson                                           elems_per_block, shared_mem,
83437930d1SJeremy L Thompson                                           interp_args); CeedChkBackend(ierr);
84*9e201c85SYohann       }
85074be161SYohann Dudouit     } else if (dim == 2) {
86437930d1SJeremy L Thompson       const CeedInt opt_elems[7] = {0, 32, 8, 6, 4, 2, 8};
87437930d1SJeremy L Thompson       // elems_per_block must be at least 1
88437930d1SJeremy L Thompson       CeedInt elems_per_block = CeedIntMax(thread_1d < 7 ? opt_elems[thread_1d] /
89437930d1SJeremy L Thompson                                            num_comp : 1, 1);
90437930d1SJeremy L Thompson       CeedInt grid = num_elem / elems_per_block +
91437930d1SJeremy L Thompson                      ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
92*9e201c85SYohann       CeedInt shared_mem = elems_per_block*thread_1d*thread_1d*sizeof(
93437930d1SJeremy L Thompson                              CeedScalar);
94*9e201c85SYohann       if (t_mode == CEED_TRANSPOSE) {
95*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->InterpTranspose, grid, thread_1d,
96*9e201c85SYohann                                           thread_1d,
97*9e201c85SYohann                                           elems_per_block, shared_mem,
98*9e201c85SYohann                                           interp_args); CeedChkBackend(ierr);
99*9e201c85SYohann       } else {
100437930d1SJeremy L Thompson         ierr = CeedRunKernelDimSharedCuda(ceed, data->Interp, grid, thread_1d,
101437930d1SJeremy L Thompson                                           thread_1d,
102*9e201c85SYohann                                           elems_per_block, shared_mem,
103437930d1SJeremy L Thompson                                           interp_args); CeedChkBackend(ierr);
104*9e201c85SYohann       }
105074be161SYohann Dudouit     } else if (dim == 3) {
106437930d1SJeremy L Thompson       CeedInt elems_per_block = 1;
107437930d1SJeremy L Thompson       CeedInt grid = num_elem / elems_per_block +
108437930d1SJeremy L Thompson                      ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
109*9e201c85SYohann       CeedInt shared_mem = elems_per_block*thread_1d*thread_1d*sizeof(
110437930d1SJeremy L Thompson                              CeedScalar);
111*9e201c85SYohann       if (t_mode == CEED_TRANSPOSE) {
112*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->InterpTranspose, grid, thread_1d,
113*9e201c85SYohann                                           thread_1d,
114*9e201c85SYohann                                           elems_per_block, shared_mem,
115*9e201c85SYohann                                           interp_args); CeedChkBackend(ierr);
116*9e201c85SYohann       } else {
117437930d1SJeremy L Thompson         ierr = CeedRunKernelDimSharedCuda(ceed, data->Interp, grid, thread_1d,
118437930d1SJeremy L Thompson                                           thread_1d,
119*9e201c85SYohann                                           elems_per_block, shared_mem,
120437930d1SJeremy L Thompson                                           interp_args); CeedChkBackend(ierr);
121074be161SYohann Dudouit       }
122*9e201c85SYohann     }
123ab213215SJeremy L Thompson   } break;
124ab213215SJeremy L Thompson   case CEED_EVAL_GRAD: {
125437930d1SJeremy L Thompson     CeedInt P_1d, Q_1d;
126437930d1SJeremy L Thompson     ierr = CeedBasisGetNumNodes1D(basis, &P_1d); CeedChkBackend(ierr);
127437930d1SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d); CeedChkBackend(ierr);
128437930d1SJeremy L Thompson     CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
129*9e201c85SYohann     if (data->d_collo_grad_1d) {
130*9e201c85SYohann       ierr = CeedCudaInitCollocatedGrad(data->d_interp_1d, data->d_collo_grad_1d,
131*9e201c85SYohann                                         P_1d,
132437930d1SJeremy L Thompson                                         Q_1d, &data->c_B, &data->c_G);
133e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
134*9e201c85SYohann     } else {
135*9e201c85SYohann       ierr = CeedCudaInitGrad(data->d_interp_1d, data->d_grad_1d, P_1d,
136*9e201c85SYohann                               Q_1d, &data->c_B, &data->c_G);
137*9e201c85SYohann       CeedChkBackend(ierr);
138*9e201c85SYohann     }
139*9e201c85SYohann     void *grad_args[] = {(void *) &num_elem, &data->c_B,
140ccf0fe6fSjeremylt                          &data->c_G, &d_u, &d_v
141ccf0fe6fSjeremylt                         };
1424d537eeaSYohann     if (dim == 1) {
143e6f67ff7SJed Brown       CeedInt elems_per_block = CeedIntMin(ceed_Cuda->device_prop.maxThreadsDim[2],
144e6f67ff7SJed Brown                                            CeedIntMax(512 / thread_1d,
145e6f67ff7SJed Brown                                                1)); // avoid >512 total threads
146437930d1SJeremy L Thompson       CeedInt grid = num_elem / elems_per_block +
147437930d1SJeremy L Thompson                      ((num_elem / elems_per_block*elems_per_block<num_elem) ? 1 : 0 );
148437930d1SJeremy L Thompson       CeedInt shared_mem = elems_per_block*thread_1d*sizeof(CeedScalar);
149*9e201c85SYohann       if (t_mode == CEED_TRANSPOSE) {
150*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->GradTranspose, grid, thread_1d, 1,
151*9e201c85SYohann                                           elems_per_block, shared_mem, grad_args);
152*9e201c85SYohann         CeedChkBackend(ierr);
153*9e201c85SYohann       } else {
154437930d1SJeremy L Thompson         ierr = CeedRunKernelDimSharedCuda(ceed, data->Grad, grid, thread_1d, 1,
155437930d1SJeremy L Thompson                                           elems_per_block, shared_mem, grad_args);
156e15f9bd0SJeremy L Thompson         CeedChkBackend(ierr);
157*9e201c85SYohann       }
158074be161SYohann Dudouit     } else if (dim == 2) {
159437930d1SJeremy L Thompson       const CeedInt opt_elems[7] = {0, 32, 8, 6, 4, 2, 8};
160437930d1SJeremy L Thompson       // elems_per_block must be at least 1
161437930d1SJeremy L Thompson       CeedInt elems_per_block = CeedIntMax(thread_1d < 7 ? opt_elems[thread_1d] /
162437930d1SJeremy L Thompson                                            num_comp : 1, 1);
163437930d1SJeremy L Thompson       CeedInt grid = num_elem / elems_per_block +
164437930d1SJeremy L Thompson                      ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
165*9e201c85SYohann       CeedInt shared_mem = elems_per_block*thread_1d*thread_1d*sizeof(
166437930d1SJeremy L Thompson                              CeedScalar);
167*9e201c85SYohann       if (t_mode == CEED_TRANSPOSE) {
168*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->GradTranspose, grid, thread_1d,
169*9e201c85SYohann                                           thread_1d,
170*9e201c85SYohann                                           elems_per_block, shared_mem,
171437930d1SJeremy L Thompson                                           grad_args); CeedChkBackend(ierr);
172*9e201c85SYohann       } else {
173*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->Grad, grid, thread_1d, thread_1d,
174*9e201c85SYohann                                           elems_per_block, shared_mem,
175*9e201c85SYohann                                           grad_args); CeedChkBackend(ierr);
176*9e201c85SYohann       }
177074be161SYohann Dudouit     } else if (dim == 3) {
178437930d1SJeremy L Thompson       CeedInt elems_per_block = 1;
179437930d1SJeremy L Thompson       CeedInt grid = num_elem / elems_per_block +
180437930d1SJeremy L Thompson                      ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
181*9e201c85SYohann       CeedInt shared_mem = elems_per_block*thread_1d*thread_1d*sizeof(
182437930d1SJeremy L Thompson                              CeedScalar);
183*9e201c85SYohann       if (t_mode == CEED_TRANSPOSE) {
184*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->GradTranspose, grid, thread_1d,
185*9e201c85SYohann                                           thread_1d,
186*9e201c85SYohann                                           elems_per_block, shared_mem,
187437930d1SJeremy L Thompson                                           grad_args); CeedChkBackend(ierr);
188*9e201c85SYohann       } else {
189*9e201c85SYohann         ierr = CeedRunKernelDimSharedCuda(ceed, data->Grad, grid, thread_1d, thread_1d,
190*9e201c85SYohann                                           elems_per_block, shared_mem,
191*9e201c85SYohann                                           grad_args); CeedChkBackend(ierr);
192*9e201c85SYohann       }
193074be161SYohann Dudouit     }
194ab213215SJeremy L Thompson   } break;
195ab213215SJeremy L Thompson   case CEED_EVAL_WEIGHT: {
196437930d1SJeremy L Thompson     CeedInt Q_1d;
197437930d1SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d); CeedChkBackend(ierr);
198437930d1SJeremy L Thompson     void *weight_args[] = {(void *) &num_elem, (void *) &data->d_q_weight_1d, &d_v};
199074be161SYohann Dudouit     if (dim == 1) {
200437930d1SJeremy L Thompson       const CeedInt elems_per_block = 32 / Q_1d;
201437930d1SJeremy L Thompson       const CeedInt gridsize = num_elem / elems_per_block +
202437930d1SJeremy L Thompson                                ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
203437930d1SJeremy L Thompson       ierr = CeedRunKernelDimCuda(ceed, data->Weight, gridsize, Q_1d,
204437930d1SJeremy L Thompson                                   elems_per_block, 1, weight_args);
205e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
206074be161SYohann Dudouit     } else if (dim == 2) {
207437930d1SJeremy L Thompson       const CeedInt opt_elems = 32 / (Q_1d * Q_1d);
208437930d1SJeremy L Thompson       const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
209437930d1SJeremy L Thompson       const CeedInt gridsize = num_elem / elems_per_block +
210437930d1SJeremy L Thompson                                ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
211437930d1SJeremy L Thompson       ierr = CeedRunKernelDimCuda(ceed, data->Weight, gridsize, Q_1d, Q_1d,
212437930d1SJeremy L Thompson                                   elems_per_block, weight_args);
213e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
214074be161SYohann Dudouit     } else if (dim == 3) {
215*9e201c85SYohann       const CeedInt opt_elems = 32 / (Q_1d * Q_1d);
216*9e201c85SYohann       const CeedInt elems_per_block = opt_elems > 0 ? opt_elems : 1;
217*9e201c85SYohann       const CeedInt gridsize = num_elem / elems_per_block +
218*9e201c85SYohann                                ((num_elem / elems_per_block*elems_per_block < num_elem) ? 1 : 0 );
219*9e201c85SYohann       ierr = CeedRunKernelDimCuda(ceed, data->Weight, gridsize, Q_1d, Q_1d,
220*9e201c85SYohann                                   elems_per_block, weight_args);
221e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
222074be161SYohann Dudouit     }
223ab213215SJeremy L Thompson   } break;
224ab213215SJeremy L Thompson   // LCOV_EXCL_START
225ab213215SJeremy L Thompson   // Evaluate the divergence to/from the quadrature points
226ab213215SJeremy L Thompson   case CEED_EVAL_DIV:
227e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
228ab213215SJeremy L Thompson   // Evaluate the curl to/from the quadrature points
229ab213215SJeremy L Thompson   case CEED_EVAL_CURL:
230e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
231ab213215SJeremy L Thompson   // Take no action, BasisApply should not have been called
232ab213215SJeremy L Thompson   case CEED_EVAL_NONE:
233e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND,
234ab213215SJeremy L Thompson                      "CEED_EVAL_NONE does not make sense in this context");
235ab213215SJeremy L Thompson     // LCOV_EXCL_STOP
236c532df63SYohann   }
237c532df63SYohann 
238ab213215SJeremy L Thompson   // Restore vectors
239437930d1SJeremy L Thompson   if (eval_mode != CEED_EVAL_WEIGHT) {
240e15f9bd0SJeremy L Thompson     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
241c532df63SYohann   }
242e15f9bd0SJeremy L Thompson   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
243e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
244c532df63SYohann }
245c532df63SYohann 
246ab213215SJeremy L Thompson //------------------------------------------------------------------------------
247ab213215SJeremy L Thompson // Destroy basis
248ab213215SJeremy L Thompson //------------------------------------------------------------------------------
249c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
250c532df63SYohann   int ierr;
251c532df63SYohann   Ceed ceed;
252e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
253c532df63SYohann 
254c532df63SYohann   CeedBasis_Cuda_shared *data;
255e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
256c532df63SYohann 
257c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
258c532df63SYohann 
259437930d1SJeremy L Thompson   ierr = cudaFree(data->d_q_weight_1d); CeedChk_Cu(ceed, ierr);
260437930d1SJeremy L Thompson   ierr = cudaFree(data->d_interp_1d); CeedChk_Cu(ceed, ierr);
261437930d1SJeremy L Thompson   ierr = cudaFree(data->d_grad_1d); CeedChk_Cu(ceed, ierr);
262437930d1SJeremy L Thompson   ierr = cudaFree(data->d_collo_grad_1d); CeedChk_Cu(ceed, ierr);
263c532df63SYohann 
264e15f9bd0SJeremy L Thompson   ierr = CeedFree(&data); CeedChkBackend(ierr);
265c532df63SYohann 
266e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
267c532df63SYohann }
268c532df63SYohann 
269ab213215SJeremy L Thompson //------------------------------------------------------------------------------
270ab213215SJeremy L Thompson // Create tensor basis
271ab213215SJeremy L Thompson //------------------------------------------------------------------------------
272437930d1SJeremy L Thompson int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d,
273437930d1SJeremy L Thompson                                         const CeedScalar *interp_1d,
274437930d1SJeremy L Thompson                                         const CeedScalar *grad_1d,
275437930d1SJeremy L Thompson                                         const CeedScalar *q_ref_1d,
276437930d1SJeremy L Thompson                                         const CeedScalar *q_weight_1d,
277c532df63SYohann                                         CeedBasis basis) {
278c532df63SYohann   int ierr;
279c532df63SYohann   Ceed ceed;
280e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
281c532df63SYohann   CeedBasis_Cuda_shared *data;
282e15f9bd0SJeremy L Thompson   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
283c532df63SYohann 
284ab213215SJeremy L Thompson   // Copy basis data to GPU
285437930d1SJeremy L Thompson   const CeedInt q_bytes = Q_1d * sizeof(CeedScalar);
286437930d1SJeremy L Thompson   ierr = cudaMalloc((void **)&data->d_q_weight_1d, q_bytes);
287437930d1SJeremy L Thompson   CeedChk_Cu(ceed, ierr);
288437930d1SJeremy L Thompson   ierr = cudaMemcpy(data->d_q_weight_1d, q_weight_1d, q_bytes,
289c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
290c532df63SYohann 
291437930d1SJeremy L Thompson   const CeedInt interp_bytes = q_bytes * P_1d;
292437930d1SJeremy L Thompson   ierr = cudaMalloc((void **)&data->d_interp_1d, interp_bytes);
293437930d1SJeremy L Thompson   CeedChk_Cu(ceed, ierr);
294437930d1SJeremy L Thompson   ierr = cudaMemcpy(data->d_interp_1d, interp_1d, interp_bytes,
295c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
296c532df63SYohann 
297437930d1SJeremy L Thompson   ierr = cudaMalloc((void **)&data->d_grad_1d, interp_bytes);
298437930d1SJeremy L Thompson   CeedChk_Cu(ceed, ierr);
299437930d1SJeremy L Thompson   ierr = cudaMemcpy(data->d_grad_1d, grad_1d, interp_bytes,
300c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
301c532df63SYohann 
302ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
303437930d1SJeremy L Thompson   data->d_collo_grad_1d = NULL;
304*9e201c85SYohann   bool has_collocated_grad = dim == 3 && Q_1d >= P_1d;
305*9e201c85SYohann   if (has_collocated_grad) {
306437930d1SJeremy L Thompson     CeedScalar *collo_grad_1d;
307437930d1SJeremy L Thompson     ierr = CeedMalloc(Q_1d*Q_1d, &collo_grad_1d); CeedChkBackend(ierr);
308437930d1SJeremy L Thompson     ierr = CeedBasisGetCollocatedGrad(basis, collo_grad_1d); CeedChkBackend(ierr);
309437930d1SJeremy L Thompson     ierr = cudaMalloc((void **)&data->d_collo_grad_1d, q_bytes * Q_1d);
310ac421f39SYohann     CeedChk_Cu(ceed, ierr);
311437930d1SJeremy L Thompson     ierr = cudaMemcpy(data->d_collo_grad_1d, collo_grad_1d, q_bytes * Q_1d,
312ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
313437930d1SJeremy L Thompson     ierr = CeedFree(&collo_grad_1d); CeedChkBackend(ierr);
314ac421f39SYohann   }
315ac421f39SYohann 
316ab213215SJeremy L Thompson   // Compile basis kernels
317437930d1SJeremy L Thompson   CeedInt num_comp;
318437930d1SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &num_comp); CeedChkBackend(ierr);
319437930d1SJeremy L Thompson   char *basis_kernel_path, *basis_kernel_source;
320ee5a26f2SJeremy L Thompson   ierr = CeedGetJitAbsolutePath(ceed,
321*9e201c85SYohann                                 "ceed/jit-source/cuda/cuda-shared-basis-tensor.h",
322437930d1SJeremy L Thompson                                 &basis_kernel_path); CeedChkBackend(ierr);
32346dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source -----\n");
324437930d1SJeremy L Thompson   ierr = CeedLoadSourceToBuffer(ceed, basis_kernel_path, &basis_kernel_source);
325437930d1SJeremy L Thompson   CeedChkBackend(ierr);
32646dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source Complete -----\n");
327437930d1SJeremy L Thompson   ierr = CeedCompileCuda(ceed, basis_kernel_source, &data->module, 8,
328d7d111ecSJeremy L Thompson                          "BASIS_Q_1D", Q_1d,
329d7d111ecSJeremy L Thompson                          "BASIS_P_1D", P_1d,
330*9e201c85SYohann                          "T_1D", CeedIntMax(Q_1d, P_1d),
331c532df63SYohann                          "BASIS_DIM", dim,
332d7d111ecSJeremy L Thompson                          "BASIS_NUM_COMP", num_comp,
333d7d111ecSJeremy L Thompson                          "BASIS_NUM_NODES", CeedIntPow(P_1d, dim),
334*9e201c85SYohann                          "BASIS_NUM_QPTS", CeedIntPow(Q_1d, dim),
335*9e201c85SYohann                          "BASIS_HAS_COLLOCATED_GRAD", has_collocated_grad
336e15f9bd0SJeremy L Thompson                         ); CeedChkBackend(ierr);
337437930d1SJeremy L Thompson   ierr = CeedGetKernelCuda(ceed, data->module, "Interp", &data->Interp);
338e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
339*9e201c85SYohann   ierr = CeedGetKernelCuda(ceed, data->module, "InterpTranspose",
340*9e201c85SYohann                            &data->InterpTranspose);
341*9e201c85SYohann   CeedChkBackend(ierr);
342437930d1SJeremy L Thompson   ierr = CeedGetKernelCuda(ceed, data->module, "Grad", &data->Grad);
343e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
344*9e201c85SYohann   ierr = CeedGetKernelCuda(ceed, data->module, "GradTranspose",
345*9e201c85SYohann                            &data->GradTranspose);
346*9e201c85SYohann   CeedChkBackend(ierr);
347437930d1SJeremy L Thompson   ierr = CeedGetKernelCuda(ceed, data->module, "Weight", &data->Weight);
348e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
349437930d1SJeremy L Thompson   ierr = CeedFree(&basis_kernel_path); CeedChkBackend(ierr);
350437930d1SJeremy L Thompson   ierr = CeedFree(&basis_kernel_source); CeedChkBackend(ierr);
351c532df63SYohann 
352e15f9bd0SJeremy L Thompson   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
353ab213215SJeremy L Thompson 
354ab213215SJeremy L Thompson   // Register backend functions
355c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
356c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
357e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
358c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
359e15f9bd0SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
360e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
361c532df63SYohann }
362ab213215SJeremy L Thompson //------------------------------------------------------------------------------
363