xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-ref/ceed-cuda-ref-restriction.c (revision ff1e7120ad38c28723000cabebb8ede4ff31c408)
1*ff1e7120SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*ff1e7120SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*ff1e7120SSebastian Grimberg //
4*ff1e7120SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5*ff1e7120SSebastian Grimberg //
6*ff1e7120SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7*ff1e7120SSebastian Grimberg 
8*ff1e7120SSebastian Grimberg #include <ceed.h>
9*ff1e7120SSebastian Grimberg #include <ceed/backend.h>
10*ff1e7120SSebastian Grimberg #include <ceed/jit-tools.h>
11*ff1e7120SSebastian Grimberg #include <cuda.h>
12*ff1e7120SSebastian Grimberg #include <cuda_runtime.h>
13*ff1e7120SSebastian Grimberg #include <stdbool.h>
14*ff1e7120SSebastian Grimberg #include <stddef.h>
15*ff1e7120SSebastian Grimberg #include <string.h>
16*ff1e7120SSebastian Grimberg 
17*ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h"
18*ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-compile.h"
19*ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h"
20*ff1e7120SSebastian Grimberg 
21*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
22*ff1e7120SSebastian Grimberg // Apply restriction
23*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
24*ff1e7120SSebastian Grimberg static int CeedElemRestrictionApply_Cuda(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
25*ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
26*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
27*ff1e7120SSebastian Grimberg   Ceed ceed;
28*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
29*ff1e7120SSebastian Grimberg   Ceed_Cuda *data;
30*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetData(ceed, &data));
31*ff1e7120SSebastian Grimberg   const CeedInt warp_size  = 32;
32*ff1e7120SSebastian Grimberg   const CeedInt block_size = warp_size;
33*ff1e7120SSebastian Grimberg   const CeedInt num_nodes  = impl->num_nodes;
34*ff1e7120SSebastian Grimberg   CeedInt       num_elem, elem_size;
35*ff1e7120SSebastian Grimberg   CeedElemRestrictionGetNumElements(r, &num_elem);
36*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
37*ff1e7120SSebastian Grimberg   CUfunction kernel;
38*ff1e7120SSebastian Grimberg 
39*ff1e7120SSebastian Grimberg   // Get vectors
40*ff1e7120SSebastian Grimberg   const CeedScalar *d_u;
41*ff1e7120SSebastian Grimberg   CeedScalar       *d_v;
42*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
43*ff1e7120SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
44*ff1e7120SSebastian Grimberg     // Sum into for transpose mode, e-vec to l-vec
45*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
46*ff1e7120SSebastian Grimberg   } else {
47*ff1e7120SSebastian Grimberg     // Overwrite for notranspose mode, l-vec to e-vec
48*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
49*ff1e7120SSebastian Grimberg   }
50*ff1e7120SSebastian Grimberg 
51*ff1e7120SSebastian Grimberg   // Restrict
52*ff1e7120SSebastian Grimberg   if (t_mode == CEED_NOTRANSPOSE) {
53*ff1e7120SSebastian Grimberg     // L-vector -> E-vector
54*ff1e7120SSebastian Grimberg     if (impl->d_ind) {
55*ff1e7120SSebastian Grimberg       // -- Offsets provided
56*ff1e7120SSebastian Grimberg       kernel             = impl->OffsetNoTranspose;
57*ff1e7120SSebastian Grimberg       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
58*ff1e7120SSebastian Grimberg       CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024;
59*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
60*ff1e7120SSebastian Grimberg     } else {
61*ff1e7120SSebastian Grimberg       // -- Strided restriction
62*ff1e7120SSebastian Grimberg       kernel             = impl->StridedNoTranspose;
63*ff1e7120SSebastian Grimberg       void   *args[]     = {&num_elem, &d_u, &d_v};
64*ff1e7120SSebastian Grimberg       CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024;
65*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
66*ff1e7120SSebastian Grimberg     }
67*ff1e7120SSebastian Grimberg   } else {
68*ff1e7120SSebastian Grimberg     // E-vector -> L-vector
69*ff1e7120SSebastian Grimberg     if (impl->d_ind) {
70*ff1e7120SSebastian Grimberg       // -- Offsets provided
71*ff1e7120SSebastian Grimberg       kernel       = impl->OffsetTranspose;
72*ff1e7120SSebastian Grimberg       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
73*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
74*ff1e7120SSebastian Grimberg     } else {
75*ff1e7120SSebastian Grimberg       // -- Strided restriction
76*ff1e7120SSebastian Grimberg       kernel       = impl->StridedTranspose;
77*ff1e7120SSebastian Grimberg       void *args[] = {&num_elem, &d_u, &d_v};
78*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
79*ff1e7120SSebastian Grimberg     }
80*ff1e7120SSebastian Grimberg   }
81*ff1e7120SSebastian Grimberg 
82*ff1e7120SSebastian Grimberg   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
83*ff1e7120SSebastian Grimberg 
84*ff1e7120SSebastian Grimberg   // Restore arrays
85*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
86*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
87*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
88*ff1e7120SSebastian Grimberg }
89*ff1e7120SSebastian Grimberg 
90*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
91*ff1e7120SSebastian Grimberg // Get offsets
92*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
93*ff1e7120SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Cuda(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
94*ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
95*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
96*ff1e7120SSebastian Grimberg 
97*ff1e7120SSebastian Grimberg   switch (mem_type) {
98*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
99*ff1e7120SSebastian Grimberg       *offsets = impl->h_ind;
100*ff1e7120SSebastian Grimberg       break;
101*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
102*ff1e7120SSebastian Grimberg       *offsets = impl->d_ind;
103*ff1e7120SSebastian Grimberg       break;
104*ff1e7120SSebastian Grimberg   }
105*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
106*ff1e7120SSebastian Grimberg }
107*ff1e7120SSebastian Grimberg 
108*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
109*ff1e7120SSebastian Grimberg // Destroy restriction
110*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
111*ff1e7120SSebastian Grimberg static int CeedElemRestrictionDestroy_Cuda(CeedElemRestriction r) {
112*ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
113*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
114*ff1e7120SSebastian Grimberg 
115*ff1e7120SSebastian Grimberg   Ceed ceed;
116*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
117*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cuModuleUnload(impl->module));
118*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
119*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_ind_allocated));
120*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_t_offsets));
121*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_t_indices));
122*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_l_vec_indices));
123*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl));
124*ff1e7120SSebastian Grimberg 
125*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
126*ff1e7120SSebastian Grimberg }
127*ff1e7120SSebastian Grimberg 
128*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
129*ff1e7120SSebastian Grimberg // Create transpose offsets and indices
130*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
131*ff1e7120SSebastian Grimberg static int CeedElemRestrictionOffset_Cuda(const CeedElemRestriction r, const CeedInt *indices) {
132*ff1e7120SSebastian Grimberg   Ceed ceed;
133*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
134*ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
135*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
136*ff1e7120SSebastian Grimberg   CeedSize l_size;
137*ff1e7120SSebastian Grimberg   CeedInt  num_elem, elem_size, num_comp;
138*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
139*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
140*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
141*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
142*ff1e7120SSebastian Grimberg 
143*ff1e7120SSebastian Grimberg   // Count num_nodes
144*ff1e7120SSebastian Grimberg   bool *is_node;
145*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(l_size, &is_node));
146*ff1e7120SSebastian Grimberg   const CeedInt size_indices = num_elem * elem_size;
147*ff1e7120SSebastian Grimberg   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
148*ff1e7120SSebastian Grimberg   CeedInt num_nodes = 0;
149*ff1e7120SSebastian Grimberg   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
150*ff1e7120SSebastian Grimberg   impl->num_nodes = num_nodes;
151*ff1e7120SSebastian Grimberg 
152*ff1e7120SSebastian Grimberg   // L-vector offsets array
153*ff1e7120SSebastian Grimberg   CeedInt *ind_to_offset, *l_vec_indices;
154*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
155*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
156*ff1e7120SSebastian Grimberg   CeedInt j = 0;
157*ff1e7120SSebastian Grimberg   for (CeedInt i = 0; i < l_size; i++) {
158*ff1e7120SSebastian Grimberg     if (is_node[i]) {
159*ff1e7120SSebastian Grimberg       l_vec_indices[j] = i;
160*ff1e7120SSebastian Grimberg       ind_to_offset[i] = j++;
161*ff1e7120SSebastian Grimberg     }
162*ff1e7120SSebastian Grimberg   }
163*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&is_node));
164*ff1e7120SSebastian Grimberg 
165*ff1e7120SSebastian Grimberg   // Compute transpose offsets and indices
166*ff1e7120SSebastian Grimberg   const CeedInt size_offsets = num_nodes + 1;
167*ff1e7120SSebastian Grimberg   CeedInt      *t_offsets;
168*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
169*ff1e7120SSebastian Grimberg   CeedInt *t_indices;
170*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
171*ff1e7120SSebastian Grimberg   // Count node multiplicity
172*ff1e7120SSebastian Grimberg   for (CeedInt e = 0; e < num_elem; ++e) {
173*ff1e7120SSebastian Grimberg     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
174*ff1e7120SSebastian Grimberg   }
175*ff1e7120SSebastian Grimberg   // Convert to running sum
176*ff1e7120SSebastian Grimberg   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
177*ff1e7120SSebastian Grimberg   // List all E-vec indices associated with L-vec node
178*ff1e7120SSebastian Grimberg   for (CeedInt e = 0; e < num_elem; ++e) {
179*ff1e7120SSebastian Grimberg     for (CeedInt i = 0; i < elem_size; ++i) {
180*ff1e7120SSebastian Grimberg       const CeedInt lid                          = elem_size * e + i;
181*ff1e7120SSebastian Grimberg       const CeedInt gid                          = indices[lid];
182*ff1e7120SSebastian Grimberg       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
183*ff1e7120SSebastian Grimberg     }
184*ff1e7120SSebastian Grimberg   }
185*ff1e7120SSebastian Grimberg   // Reset running sum
186*ff1e7120SSebastian Grimberg   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
187*ff1e7120SSebastian Grimberg   t_offsets[0] = 0;
188*ff1e7120SSebastian Grimberg 
189*ff1e7120SSebastian Grimberg   // Copy data to device
190*ff1e7120SSebastian Grimberg   // -- L-vector indices
191*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
192*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), cudaMemcpyHostToDevice));
193*ff1e7120SSebastian Grimberg   // -- Transpose offsets
194*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
195*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), cudaMemcpyHostToDevice));
196*ff1e7120SSebastian Grimberg   // -- Transpose indices
197*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
198*ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), cudaMemcpyHostToDevice));
199*ff1e7120SSebastian Grimberg 
200*ff1e7120SSebastian Grimberg   // Cleanup
201*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&ind_to_offset));
202*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&l_vec_indices));
203*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&t_offsets));
204*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&t_indices));
205*ff1e7120SSebastian Grimberg 
206*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
207*ff1e7120SSebastian Grimberg }
208*ff1e7120SSebastian Grimberg 
209*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
210*ff1e7120SSebastian Grimberg // Create restriction
211*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
212*ff1e7120SSebastian Grimberg int CeedElemRestrictionCreate_Cuda(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, CeedElemRestriction r) {
213*ff1e7120SSebastian Grimberg   Ceed ceed;
214*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
215*ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
216*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
217*ff1e7120SSebastian Grimberg   CeedInt num_elem, num_comp, elem_size;
218*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
219*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
220*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
221*ff1e7120SSebastian Grimberg   CeedInt size        = num_elem * elem_size;
222*ff1e7120SSebastian Grimberg   CeedInt strides[3]  = {1, size, elem_size};
223*ff1e7120SSebastian Grimberg   CeedInt comp_stride = 1;
224*ff1e7120SSebastian Grimberg 
225*ff1e7120SSebastian Grimberg   // Stride data
226*ff1e7120SSebastian Grimberg   bool is_strided;
227*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
228*ff1e7120SSebastian Grimberg   if (is_strided) {
229*ff1e7120SSebastian Grimberg     bool has_backend_strides;
230*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
231*ff1e7120SSebastian Grimberg     if (!has_backend_strides) {
232*ff1e7120SSebastian Grimberg       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
233*ff1e7120SSebastian Grimberg     }
234*ff1e7120SSebastian Grimberg   } else {
235*ff1e7120SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
236*ff1e7120SSebastian Grimberg   }
237*ff1e7120SSebastian Grimberg 
238*ff1e7120SSebastian Grimberg   impl->h_ind           = NULL;
239*ff1e7120SSebastian Grimberg   impl->h_ind_allocated = NULL;
240*ff1e7120SSebastian Grimberg   impl->d_ind           = NULL;
241*ff1e7120SSebastian Grimberg   impl->d_ind_allocated = NULL;
242*ff1e7120SSebastian Grimberg   impl->d_t_indices     = NULL;
243*ff1e7120SSebastian Grimberg   impl->d_t_offsets     = NULL;
244*ff1e7120SSebastian Grimberg   impl->num_nodes       = size;
245*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
246*ff1e7120SSebastian Grimberg   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
247*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
248*ff1e7120SSebastian Grimberg 
249*ff1e7120SSebastian Grimberg   // Set up device indices/offset arrays
250*ff1e7120SSebastian Grimberg   switch (mem_type) {
251*ff1e7120SSebastian Grimberg     case CEED_MEM_HOST: {
252*ff1e7120SSebastian Grimberg       switch (copy_mode) {
253*ff1e7120SSebastian Grimberg         case CEED_OWN_POINTER:
254*ff1e7120SSebastian Grimberg           impl->h_ind_allocated = (CeedInt *)indices;
255*ff1e7120SSebastian Grimberg           impl->h_ind           = (CeedInt *)indices;
256*ff1e7120SSebastian Grimberg           break;
257*ff1e7120SSebastian Grimberg         case CEED_USE_POINTER:
258*ff1e7120SSebastian Grimberg           impl->h_ind = (CeedInt *)indices;
259*ff1e7120SSebastian Grimberg           break;
260*ff1e7120SSebastian Grimberg         case CEED_COPY_VALUES:
261*ff1e7120SSebastian Grimberg           if (indices != NULL) {
262*ff1e7120SSebastian Grimberg             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
263*ff1e7120SSebastian Grimberg             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
264*ff1e7120SSebastian Grimberg             impl->h_ind = impl->h_ind_allocated;
265*ff1e7120SSebastian Grimberg           }
266*ff1e7120SSebastian Grimberg           break;
267*ff1e7120SSebastian Grimberg       }
268*ff1e7120SSebastian Grimberg       if (indices != NULL) {
269*ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
270*ff1e7120SSebastian Grimberg         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
271*ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyHostToDevice));
272*ff1e7120SSebastian Grimberg         CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices));
273*ff1e7120SSebastian Grimberg       }
274*ff1e7120SSebastian Grimberg       break;
275*ff1e7120SSebastian Grimberg     }
276*ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE: {
277*ff1e7120SSebastian Grimberg       switch (copy_mode) {
278*ff1e7120SSebastian Grimberg         case CEED_COPY_VALUES:
279*ff1e7120SSebastian Grimberg           if (indices != NULL) {
280*ff1e7120SSebastian Grimberg             CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
281*ff1e7120SSebastian Grimberg             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
282*ff1e7120SSebastian Grimberg             CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyDeviceToDevice));
283*ff1e7120SSebastian Grimberg           }
284*ff1e7120SSebastian Grimberg           break;
285*ff1e7120SSebastian Grimberg         case CEED_OWN_POINTER:
286*ff1e7120SSebastian Grimberg           impl->d_ind           = (CeedInt *)indices;
287*ff1e7120SSebastian Grimberg           impl->d_ind_allocated = impl->d_ind;
288*ff1e7120SSebastian Grimberg           break;
289*ff1e7120SSebastian Grimberg         case CEED_USE_POINTER:
290*ff1e7120SSebastian Grimberg           impl->d_ind = (CeedInt *)indices;
291*ff1e7120SSebastian Grimberg       }
292*ff1e7120SSebastian Grimberg       if (indices != NULL) {
293*ff1e7120SSebastian Grimberg         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
294*ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), cudaMemcpyDeviceToHost));
295*ff1e7120SSebastian Grimberg         impl->h_ind = impl->h_ind_allocated;
296*ff1e7120SSebastian Grimberg         CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices));
297*ff1e7120SSebastian Grimberg       }
298*ff1e7120SSebastian Grimberg       break;
299*ff1e7120SSebastian Grimberg     }
300*ff1e7120SSebastian Grimberg     // LCOV_EXCL_START
301*ff1e7120SSebastian Grimberg     default:
302*ff1e7120SSebastian Grimberg       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
303*ff1e7120SSebastian Grimberg       // LCOV_EXCL_STOP
304*ff1e7120SSebastian Grimberg   }
305*ff1e7120SSebastian Grimberg 
306*ff1e7120SSebastian Grimberg   // Compile CUDA kernels
307*ff1e7120SSebastian Grimberg   CeedInt num_nodes = impl->num_nodes;
308*ff1e7120SSebastian Grimberg   char   *restriction_kernel_path, *restriction_kernel_source;
309*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-restriction.h", &restriction_kernel_path));
310*ff1e7120SSebastian Grimberg   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n");
311*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
312*ff1e7120SSebastian Grimberg   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source Complete! -----\n");
313*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCompile_Cuda(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
314*ff1e7120SSebastian Grimberg                                    "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
315*ff1e7120SSebastian Grimberg                                    strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
316*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
317*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
318*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
319*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
320*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&restriction_kernel_path));
321*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&restriction_kernel_source));
322*ff1e7120SSebastian Grimberg 
323*ff1e7120SSebastian Grimberg   // Register backend functions
324*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Cuda));
325*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Cuda));
326*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Cuda));
327*ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Cuda));
328*ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
329*ff1e7120SSebastian Grimberg }
330*ff1e7120SSebastian Grimberg 
331*ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
332