xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-basis.c (revision d92fedf5b7546cf2fc50391dbcfb657a2e1f0a3b)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <cuda.h>
20 #include <cuda_runtime.h>
21 #include "ceed-cuda-ref.h"
22 #include "../cuda/ceed-cuda-compile.h"
23 
24 //------------------------------------------------------------------------------
25 // Tensor Basis Kernels
26 //------------------------------------------------------------------------------
27 // *INDENT-OFF*
static const char *basiskernels = QUOTE(

//------------------------------------------------------------------------------
// Interp
//
// Tensor-product interpolation.  Elements are distributed over blocks
// (grid-stride on blockIdx.x); the blockDim.x threads of a block cooperate on
// one element at a time.  The BASIS_DIM one-dimensional contractions ping-pong
// between two shared-memory scratch buffers, with the last contraction writing
// directly to global memory.
//
//   transpose == 0: apply B   (nodes -> quadrature points)
//   transpose != 0: apply B^T (quadrature points -> nodes), realized by
//                   swapping P/Q and transposing the matrix access strides
//------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *__restrict__ interp1d,
                                  const CeedScalar *__restrict__ u,
                                  CeedScalar *__restrict__ v) {
  const CeedInt i = threadIdx.x;

  // Shared layout: [ interp1d (Q1D*P1D) | buf1 (BUF_LEN) | buf2 (BUF_LEN) ]
  __shared__ CeedScalar s_mem[BASIS_Q1D * BASIS_P1D + 2 * BASIS_BUF_LEN];
  CeedScalar *s_interp1d = s_mem;
  CeedScalar *s_buf1 = s_mem + BASIS_Q1D * BASIS_P1D;
  CeedScalar *s_buf2 = s_buf1 + BASIS_BUF_LEN;
  // Cooperatively stage the 1D basis matrix in shared memory
  for (CeedInt k = i; k < BASIS_Q1D * BASIS_P1D; k += blockDim.x) {
    s_interp1d[k] = interp1d[k];
  }

  // Contraction sizes and strides; transpose swaps the input/output roles
  const CeedInt P = transpose ? BASIS_Q1D : BASIS_P1D;
  const CeedInt Q = transpose ? BASIS_P1D : BASIS_Q1D;
  const CeedInt stride0 = transpose ? 1 : BASIS_P1D;
  const CeedInt stride1 = transpose ? BASIS_P1D : 1;
  const CeedInt u_stride = transpose ? BASIS_NQPT : BASIS_ELEMSIZE;
  const CeedInt v_stride = transpose ? BASIS_ELEMSIZE : BASIS_NQPT;
  const CeedInt u_comp_stride = nelem * (transpose ? BASIS_NQPT : BASIS_ELEMSIZE);
  const CeedInt v_comp_stride = nelem * (transpose ? BASIS_ELEMSIZE : BASIS_NQPT);
  const CeedInt u_size = transpose ? BASIS_NQPT : BASIS_ELEMSIZE;

  // Apply basis element by element
  for (CeedInt elem = blockIdx.x; elem < nelem; elem += gridDim.x) {
    for (CeedInt comp = 0; comp < BASIS_NCOMP; ++comp) {
      const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
      CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride;
      // Stage this element/component's input in buf1.
      // NOTE(review): no __syncthreads() separates this write from the final
      // contraction of the previous comp/elem iteration, which reads s_buf1
      // when BASIS_DIM is odd -- presumably safe only because the host
      // launches this kernel with blocksize <= 32 (a single warp); verify.
      for (CeedInt k = i; k < u_size; k += blockDim.x) {
        s_buf1[k] = cur_u[k];
      }
      CeedInt pre = u_size;
      CeedInt post = 1;
      for (CeedInt d = 0; d < BASIS_DIM; d++) {
        __syncthreads();
        // Update buffers used: shrink the leading extent, alternate the
        // scratch buffers, and direct the last pass to global memory
        pre /= P;
        const CeedScalar *in = d % 2 ? s_buf2 : s_buf1;
        CeedScalar *out = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buf1 : s_buf2);

        // Contract along middle index: out[a,j,c] = sum_b B[j,b] * in[a,b,c]
        const CeedInt writeLen = pre * post * Q;
        for (CeedInt k = i; k < writeLen; k += blockDim.x) {
          const CeedInt c = k % post;
          const CeedInt j = (k / post) % Q;
          const CeedInt a = k / (post * Q);

          CeedScalar vk = 0;
          for (CeedInt b = 0; b < P; b++)
            vk += s_interp1d[j*stride0 + b*stride1] * in[(a*P + b)*post + c];

          out[k] = vk;
        }

        post *= Q;
      }
    }
  }
}

//------------------------------------------------------------------------------
// Grad
//
// Tensor-product gradient.  Same thread/block layout as interp.  For each
// gradient direction dim1, BASIS_DIM one-dimensional contractions are applied;
// the contraction along direction dim1 uses grad1d, the others use interp1d.
// In transpose mode the final pass accumulates (+=) into v, summing the
// directional contributions (the host zeroes v before launch).
//------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *__restrict__ interp1d,
                                const CeedScalar *__restrict__ grad1d,
                                const CeedScalar *__restrict__ u,
                                CeedScalar *__restrict__ v) {
  const CeedInt i = threadIdx.x;

  // Shared layout: [ interp1d | grad1d | buf1 | buf2 ]
  __shared__ CeedScalar s_mem[2 * (BASIS_Q1D * BASIS_P1D + BASIS_BUF_LEN)];
  CeedScalar *s_interp1d = s_mem;
  CeedScalar *s_grad1d = s_interp1d + BASIS_Q1D * BASIS_P1D;
  CeedScalar *s_buf1 = s_grad1d + BASIS_Q1D * BASIS_P1D;
  CeedScalar *s_buf2 = s_buf1 + BASIS_BUF_LEN;
  // Cooperatively stage both 1D matrices in shared memory
  for (CeedInt k = i; k < BASIS_Q1D * BASIS_P1D; k += blockDim.x) {
    s_interp1d[k] = interp1d[k];
    s_grad1d[k] = grad1d[k];
  }

  // Contraction sizes and strides; transpose swaps the input/output roles.
  // The dim stride is nonzero only on the side carrying the BASIS_DIM
  // directional components.
  const CeedInt P = transpose ? BASIS_Q1D : BASIS_P1D;
  const CeedInt Q = transpose ? BASIS_P1D : BASIS_Q1D;
  const CeedInt stride0 = transpose ? 1 : BASIS_P1D;
  const CeedInt stride1 = transpose ? BASIS_P1D : 1;
  const CeedInt u_stride = transpose ? BASIS_NQPT : BASIS_ELEMSIZE;
  const CeedInt v_stride = transpose ? BASIS_ELEMSIZE : BASIS_NQPT;
  const CeedInt u_comp_stride = nelem * (transpose ? BASIS_NQPT : BASIS_ELEMSIZE);
  const CeedInt v_comp_stride = nelem * (transpose ? BASIS_ELEMSIZE : BASIS_NQPT);
  const CeedInt u_dim_stride = transpose ? nelem * BASIS_NQPT * BASIS_NCOMP : 0;
  const CeedInt v_dim_stride = transpose ? 0 : nelem * BASIS_NQPT * BASIS_NCOMP;

  // Apply basis element by element
  for (CeedInt elem = blockIdx.x; elem < nelem; elem += gridDim.x) {
    for (CeedInt comp = 0; comp < BASIS_NCOMP; ++comp) {

      // dim*dim contractions for grad
      for (CeedInt dim1 = 0; dim1 < BASIS_DIM; dim1++) {
        CeedInt pre = transpose ? BASIS_NQPT : BASIS_ELEMSIZE;
        CeedInt post = 1;
        const CeedScalar *cur_u = u + elem * u_stride + dim1 * u_dim_stride +
                                  comp * u_comp_stride;
        CeedScalar *cur_v = v + elem * v_stride + dim1 * v_dim_stride + comp *
                            v_comp_stride;
        for (CeedInt dim2 = 0; dim2 < BASIS_DIM; dim2++) {
          __syncthreads();
          // Update buffers used: grad1d on the matching direction; the first
          // pass reads straight from global u, the last writes to global v
          pre /= P;
          const CeedScalar *op = dim1 == dim2 ? s_grad1d : s_interp1d;
          const CeedScalar *in = dim2 == 0 ? cur_u : (dim2 % 2 ? s_buf2 : s_buf1);
          CeedScalar *out = dim2 == BASIS_DIM - 1 ? cur_v : (dim2 % 2 ? s_buf1 : s_buf2);

          // Contract along middle index
          const CeedInt writeLen = pre * post * Q;
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c = k % post;
            const CeedInt j = (k / post) % Q;
            const CeedInt a = k / (post * Q);
            CeedScalar vk = 0;
            for (CeedInt b = 0; b < P; b++)
              vk += op[j * stride0 + b * stride1] * in[(a * P + b) * post + c];

            // Transpose mode sums the directional contributions into v
            if (transpose && dim2 == BASIS_DIM - 1)
              out[k] += vk;
            else
              out[k] = vk;
          }

          post *= Q;
        }
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D quadrature weights
//
// One block per element (the host launches with gridsize == nelem); thread i
// copies qweight1d[i] into the element's weight slot.
//------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  if (i < BASIS_Q1D) {
    const size_t elem = blockIdx.x;
    if (elem < nelem)
      w[elem*BASIS_Q1D + i] = qweight1d[i];
  }
}

//------------------------------------------------------------------------------
// 2D quadrature weights
//
// Thread (i, j) writes the tensor-product weight qweight1d[i] * qweight1d[j]
// for its element.
//------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {

  const int i = threadIdx.x;
  const int j = threadIdx.y;
  if (i < BASIS_Q1D && j < BASIS_Q1D) {
    const size_t elem = blockIdx.x;
    if (elem < nelem) {
      const size_t ind = (elem * BASIS_Q1D + j) * BASIS_Q1D + i;
      w[ind] = qweight1d[i] * qweight1d[j];
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
//
// Thread (i, j) loops over the third index k, writing the triple tensor
// product of the 1D weights for its element.
//------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  if (i < BASIS_Q1D && j < BASIS_Q1D) {
    const size_t elem = blockIdx.x;
    if (elem < nelem) {
      for (int k=0; k<BASIS_Q1D; k++) {
        const size_t ind = ((elem * BASIS_Q1D + k) * BASIS_Q1D + j) * BASIS_Q1D + i;
        w[ind] = qweight1d[i] * qweight1d[j] * qweight1d[k];
      }
    }
  }
}

//------------------------------------------------------------------------------
// Quadrature weights
//
// Dispatch by dimension.  BASIS_DIM is a compile-time constant supplied at
// JIT time, so the dead branches are resolved by the compiler.
//------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM==1)
    weight1d(nelem, qweight1d, v);
  else if (BASIS_DIM==2)
    weight2d(nelem, qweight1d, v);
  else if (BASIS_DIM==3)
    weight3d(nelem, qweight1d, v);
}

);
231 
232 //------------------------------------------------------------------------------
233 // Non-Tensor Basis Kernels
234 //------------------------------------------------------------------------------
static const char *kernelsNonTensorRef = QUOTE(

//------------------------------------------------------------------------------
// Interp
//
// Non-tensor interpolation: one dense matrix-vector product per element and
// component.  Elements are distributed over blocks, and additionally over
// threadIdx.z when blockDim.z > 1; each threadIdx.x computes one output
// entry.  The host launches with Q threads in x for the forward apply and P
// threads for the transpose apply.
//------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *d_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  const int tid = threadIdx.x;

  const CeedScalar *U;
  CeedScalar V;
  //TODO load B in shared memory if blockDim.z > 1?

  // Grid-stride loop over elements
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) { // run with Q threads
        // V[tid] = sum_i d_B[i + tid*P] * U[i]
        U = d_U + elem*P + comp*nelem*P;
        V = 0.0;
        for (int i = 0; i < P; ++i)
          V += d_B[i + tid*P]*U[i];

        d_V[elem*Q + comp*nelem*Q + tid] = V;
      } else { // run with P threads
        // V[tid] = sum_i d_B[tid + i*P] * U[i] (transposed access)
        U = d_U + elem*Q + comp*nelem*Q;
        V = 0.0;
        for (int i = 0; i < Q; ++i)
          V += d_B[tid + i*P]*U[i];

        d_V[elem*P + comp*nelem*P + tid] = V;
      }
    }
  }
}

//------------------------------------------------------------------------------
// Grad
//
// Non-tensor gradient: same layout as interp, but d_G stacks BASIS_DIM dense
// P*Q matrices (offset dim*P*Q).  The forward apply writes one output per
// direction; the transpose apply sums the directional contributions into a
// single nodal value.
//------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *d_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  const int tid = threadIdx.x;

  const CeedScalar *U;
  //TODO load G in shared memory if blockDim.z > 1?

  // Grid-stride loop over elements
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp=0; comp<BASIS_NCOMP; comp++) {
      if (!transpose) { // run with Q threads
        // One partial derivative per direction for this quadrature point
        CeedScalar V[BASIS_DIM];
        U = d_U + elem*P + comp*nelem*P;
        for (int dim = 0; dim < BASIS_DIM; dim++)
          V[dim] = 0.0;

        // Each input value contributes to all BASIS_DIM derivatives
        for (int i = 0; i < P; ++i) {
          const CeedScalar val = U[i];
          for(int dim = 0; dim < BASIS_DIM; dim++)
            V[dim] += d_G[i + tid*P + dim*P*Q]*val;
        }
        for (int dim = 0; dim < BASIS_DIM; dim++) {
          d_V[elem*Q + comp*nelem*Q + dim*BASIS_NCOMP*nelem*Q + tid] = V[dim];
        }
      } else { // run with P threads
        // Accumulate all directional contributions into one nodal value
        CeedScalar V = 0.0;
        for (int dim = 0; dim < BASIS_DIM; dim++) {
          U = d_U + elem*Q + comp*nelem*Q +dim*BASIS_NCOMP*nelem*Q;
          for (int i = 0; i < Q; ++i)
            V += d_G[tid + i*P + dim*P*Q]*U[i];
        }
        d_V[elem*P + comp*nelem*P + tid] = V;
      }
    }
  }
}

//------------------------------------------------------------------------------
// Weight
//
// Copies the quadrature weights into each element's slot of d_V; run with Q
// threads in x, elements distributed over blocks and threadIdx.z.
//------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight,
                                  CeedScalar *__restrict__ d_V) {
  const int tid = threadIdx.x;
  //TODO load qweight in shared memory if blockDim.z > 1?
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    d_V[elem*Q + tid] = qweight[tid];
  }
}

);
329 // *INDENT-ON*
330 
331 //------------------------------------------------------------------------------
332 // Basis apply - tensor
333 //------------------------------------------------------------------------------
334 int CeedBasisApply_Cuda(CeedBasis basis, const CeedInt nelem,
335                         CeedTransposeMode tmode,
336                         CeedEvalMode emode, CeedVector u, CeedVector v) {
337   int ierr;
338   Ceed ceed;
339   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
340   Ceed_Cuda *ceed_Cuda;
341   ierr = CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
342   CeedBasis_Cuda *data;
343   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
344   const CeedInt transpose = tmode == CEED_TRANSPOSE;
345   const int maxblocksize = 32;
346 
347   // Read vectors
348   const CeedScalar *d_u;
349   CeedScalar *d_v;
350   if (emode != CEED_EVAL_WEIGHT) {
351     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
352   }
353   ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
354 
355   // Clear v for transpose operation
356   if (tmode == CEED_TRANSPOSE) {
357     CeedInt length;
358     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
359     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar));
360     CeedChk_Cu(ceed,ierr);
361   }
362   CeedInt Q1d, dim;
363   ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
364   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
365 
366   // Basis action
367   switch (emode) {
368   case CEED_EVAL_INTERP: {
369     void *interpargs[] = {(void *) &nelem, (void *) &transpose,
370                           &data->d_interp1d, &d_u, &d_v
371                          };
372     CeedInt blocksize = CeedIntPow(Q1d, dim);
373     blocksize = blocksize > maxblocksize ? maxblocksize : blocksize;
374 
375     ierr = CeedRunKernelCuda(ceed, data->interp, nelem, blocksize, interpargs);
376     CeedChkBackend(ierr);
377   } break;
378   case CEED_EVAL_GRAD: {
379     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
380                         &data->d_grad1d, &d_u, &d_v
381                        };
382     CeedInt blocksize = maxblocksize;
383 
384     ierr = CeedRunKernelCuda(ceed, data->grad, nelem, blocksize, gradargs);
385     CeedChkBackend(ierr);
386   } break;
387   case CEED_EVAL_WEIGHT: {
388     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
389     const int gridsize = nelem;
390     ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize,
391                                 Q1d, dim >= 2 ? Q1d : 1, 1,
392                                 weightargs); CeedChkBackend(ierr);
393   } break;
394   // LCOV_EXCL_START
395   // Evaluate the divergence to/from the quadrature points
396   case CEED_EVAL_DIV:
397     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
398   // Evaluate the curl to/from the quadrature points
399   case CEED_EVAL_CURL:
400     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
401   // Take no action, BasisApply should not have been called
402   case CEED_EVAL_NONE:
403     return CeedError(ceed, CEED_ERROR_BACKEND,
404                      "CEED_EVAL_NONE does not make sense in this context");
405     // LCOV_EXCL_STOP
406   }
407 
408   // Restore vectors
409   if (emode != CEED_EVAL_WEIGHT) {
410     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
411   }
412   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
413   return CEED_ERROR_SUCCESS;
414 }
415 
416 //------------------------------------------------------------------------------
417 // Basis apply - non-tensor
418 //------------------------------------------------------------------------------
419 int CeedBasisApplyNonTensor_Cuda(CeedBasis basis, const CeedInt nelem,
420                                  CeedTransposeMode tmode, CeedEvalMode emode,
421                                  CeedVector u, CeedVector v) {
422   int ierr;
423   Ceed ceed;
424   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
425   Ceed_Cuda *ceed_Cuda;
426   ierr = CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
427   CeedBasisNonTensor_Cuda *data;
428   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
429   CeedInt nnodes, nqpt;
430   ierr = CeedBasisGetNumQuadraturePoints(basis, &nqpt); CeedChkBackend(ierr);
431   ierr = CeedBasisGetNumNodes(basis, &nnodes); CeedChkBackend(ierr);
432   const CeedInt transpose = tmode == CEED_TRANSPOSE;
433   int elemsPerBlock = 1;
434   int grid = nelem/elemsPerBlock+((nelem/elemsPerBlock*elemsPerBlock<nelem)?1:0);
435 
436   // Read vectors
437   const CeedScalar *d_u;
438   CeedScalar *d_v;
439   if (emode != CEED_EVAL_WEIGHT) {
440     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
441   }
442   ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
443 
444   // Clear v for transpose operation
445   if (tmode == CEED_TRANSPOSE) {
446     CeedInt length;
447     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
448     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar));
449     CeedChk_Cu(ceed, ierr);
450   }
451 
452   // Apply basis operation
453   switch (emode) {
454   case CEED_EVAL_INTERP: {
455     void *interpargs[] = {(void *) &nelem, (void *) &transpose,
456                           &data->d_interp, &d_u, &d_v
457                          };
458     if (!transpose) {
459       ierr = CeedRunKernelDimCuda(ceed, data->interp, grid, nqpt, 1,
460                                   elemsPerBlock, interpargs); CeedChkBackend(ierr);
461     } else {
462       ierr = CeedRunKernelDimCuda(ceed, data->interp, grid, nnodes, 1,
463                                   elemsPerBlock, interpargs); CeedChkBackend(ierr);
464     }
465   } break;
466   case CEED_EVAL_GRAD: {
467     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_grad,
468                         &d_u, &d_v
469                        };
470     if (!transpose) {
471       ierr = CeedRunKernelDimCuda(ceed, data->grad, grid, nqpt, 1,
472                                   elemsPerBlock, gradargs); CeedChkBackend(ierr);
473     } else {
474       ierr = CeedRunKernelDimCuda(ceed, data->grad, grid, nnodes, 1,
475                                   elemsPerBlock, gradargs); CeedChkBackend(ierr);
476     }
477   } break;
478   case CEED_EVAL_WEIGHT: {
479     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight, &d_v};
480     ierr = CeedRunKernelDimCuda(ceed, data->weight, grid, nqpt, 1,
481                                 elemsPerBlock, weightargs); CeedChkBackend(ierr);
482   } break;
483   // LCOV_EXCL_START
484   // Evaluate the divergence to/from the quadrature points
485   case CEED_EVAL_DIV:
486     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
487   // Evaluate the curl to/from the quadrature points
488   case CEED_EVAL_CURL:
489     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
490   // Take no action, BasisApply should not have been called
491   case CEED_EVAL_NONE:
492     return CeedError(ceed, CEED_ERROR_BACKEND,
493                      "CEED_EVAL_NONE does not make sense in this context");
494     // LCOV_EXCL_STOP
495   }
496 
497   // Restore vectors
498   if (emode != CEED_EVAL_WEIGHT) {
499     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
500   }
501   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
502   return CEED_ERROR_SUCCESS;
503 }
504 
505 //------------------------------------------------------------------------------
506 // Destroy tensor basis
507 //------------------------------------------------------------------------------
508 static int CeedBasisDestroy_Cuda(CeedBasis basis) {
509   int ierr;
510   Ceed ceed;
511   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
512 
513   CeedBasis_Cuda *data;
514   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
515 
516   CeedChk_Cu(ceed, cuModuleUnload(data->module));
517 
518   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed,ierr);
519   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed,ierr);
520   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed,ierr);
521 
522   ierr = CeedFree(&data); CeedChkBackend(ierr);
523   return CEED_ERROR_SUCCESS;
524 }
525 
526 //------------------------------------------------------------------------------
527 // Destroy non-tensor basis
528 //------------------------------------------------------------------------------
529 static int CeedBasisDestroyNonTensor_Cuda(CeedBasis basis) {
530   int ierr;
531   Ceed ceed;
532   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
533 
534   CeedBasisNonTensor_Cuda *data;
535   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
536 
537   CeedChk_Cu(ceed, cuModuleUnload(data->module));
538 
539   ierr = cudaFree(data->d_qweight); CeedChk_Cu(ceed, ierr);
540   ierr = cudaFree(data->d_interp); CeedChk_Cu(ceed, ierr);
541   ierr = cudaFree(data->d_grad); CeedChk_Cu(ceed, ierr);
542 
543   ierr = CeedFree(&data); CeedChkBackend(ierr);
544   return CEED_ERROR_SUCCESS;
545 }
546 
547 //------------------------------------------------------------------------------
548 // Create tensor
549 //------------------------------------------------------------------------------
550 int CeedBasisCreateTensorH1_Cuda(CeedInt dim, CeedInt P1d, CeedInt Q1d,
551                                  const CeedScalar *interp1d,
552                                  const CeedScalar *grad1d,
553                                  const CeedScalar *qref1d,
554                                  const CeedScalar *qweight1d,
555                                  CeedBasis basis) {
556   int ierr;
557   Ceed ceed;
558   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
559   CeedBasis_Cuda *data;
560   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
561 
562   // Copy data to GPU
563   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
564   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed,ierr);
565   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
566                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed,ierr);
567 
568   const CeedInt iBytes = qBytes * P1d;
569   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed,ierr);
570   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
571                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed,ierr);
572 
573   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed,ierr);
574   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
575                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed,ierr);
576 
577   // Complie basis kernels
578   CeedInt ncomp;
579   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
580   ierr = CeedCompileCuda(ceed, basiskernels, &data->module, 7,
581                          "BASIS_Q1D", Q1d,
582                          "BASIS_P1D", P1d,
583                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
584                              Q1d : P1d, dim),
585                          "BASIS_DIM", dim,
586                          "BASIS_NCOMP", ncomp,
587                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
588                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
589                         ); CeedChkBackend(ierr);
590   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
591   CeedChkBackend(ierr);
592   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
593   CeedChkBackend(ierr);
594   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
595   CeedChkBackend(ierr);
596   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
597 
598   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
599                                 CeedBasisApply_Cuda); CeedChkBackend(ierr);
600   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
601                                 CeedBasisDestroy_Cuda); CeedChkBackend(ierr);
602   return CEED_ERROR_SUCCESS;
603 }
604 
605 //------------------------------------------------------------------------------
606 // Create non-tensor
607 //------------------------------------------------------------------------------
608 int CeedBasisCreateH1_Cuda(CeedElemTopology topo, CeedInt dim, CeedInt nnodes,
609                            CeedInt nqpts, const CeedScalar *interp,
610                            const CeedScalar *grad, const CeedScalar *qref,
611                            const CeedScalar *qweight, CeedBasis basis) {
612   int ierr;
613   Ceed ceed;
614   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
615   CeedBasisNonTensor_Cuda *data;
616   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
617 
618   // Copy basis data to GPU
619   const CeedInt qBytes = nqpts * sizeof(CeedScalar);
620   ierr = cudaMalloc((void **)&data->d_qweight, qBytes); CeedChk_Cu(ceed, ierr);
621   ierr = cudaMemcpy(data->d_qweight, qweight, qBytes,
622                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
623 
624   const CeedInt iBytes = qBytes * nnodes;
625   ierr = cudaMalloc((void **)&data->d_interp, iBytes); CeedChk_Cu(ceed, ierr);
626   ierr = cudaMemcpy(data->d_interp, interp, iBytes,
627                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
628 
629   const CeedInt gBytes = qBytes * nnodes * dim;
630   ierr = cudaMalloc((void **)&data->d_grad, gBytes); CeedChk_Cu(ceed, ierr);
631   ierr = cudaMemcpy(data->d_grad, grad, gBytes,
632                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
633 
634   // Compile basis kernels
635   CeedInt ncomp;
636   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
637   ierr = CeedCompileCuda(ceed, kernelsNonTensorRef, &data->module, 4,
638                          "Q", nqpts,
639                          "P", nnodes,
640                          "BASIS_DIM", dim,
641                          "BASIS_NCOMP", ncomp
642                         ); CeedChk_Cu(ceed, ierr);
643   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
644   CeedChk_Cu(ceed, ierr);
645   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
646   CeedChk_Cu(ceed, ierr);
647   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
648   CeedChk_Cu(ceed, ierr);
649 
650   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
651 
652   // Register backend functions
653   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
654                                 CeedBasisApplyNonTensor_Cuda); CeedChkBackend(ierr);
655   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
656                                 CeedBasisDestroyNonTensor_Cuda); CeedChkBackend(ierr);
657   return CEED_ERROR_SUCCESS;
658 }
659 //------------------------------------------------------------------------------
660