xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-ref/ceed-cuda-ref-restriction.c (revision 58549094d8a305d0f4b066b44680cf34cff212e7)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
7ff1e7120SSebastian Grimberg 
8ff1e7120SSebastian Grimberg #include <ceed.h>
9ff1e7120SSebastian Grimberg #include <ceed/backend.h>
10ff1e7120SSebastian Grimberg #include <ceed/jit-tools.h>
11ff1e7120SSebastian Grimberg #include <cuda.h>
12ff1e7120SSebastian Grimberg #include <cuda_runtime.h>
13ff1e7120SSebastian Grimberg #include <stdbool.h>
14ff1e7120SSebastian Grimberg #include <stddef.h>
15ff1e7120SSebastian Grimberg #include <string.h>
16ff1e7120SSebastian Grimberg 
17ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h"
18ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-compile.h"
19ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h"
20ff1e7120SSebastian Grimberg 
//------------------------------------------------------------------------------
// Apply restriction
//------------------------------------------------------------------------------
24ff1e7120SSebastian Grimberg static int CeedElemRestrictionApply_Cuda(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
25ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
26ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
27ff1e7120SSebastian Grimberg   Ceed ceed;
28ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
29ff1e7120SSebastian Grimberg   Ceed_Cuda *data;
30ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetData(ceed, &data));
31ff1e7120SSebastian Grimberg   const CeedInt num_nodes = impl->num_nodes;
32ff1e7120SSebastian Grimberg   CeedInt       num_elem, elem_size;
33ff1e7120SSebastian Grimberg   CeedElemRestrictionGetNumElements(r, &num_elem);
34ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
35ff1e7120SSebastian Grimberg   CUfunction kernel;
36ff1e7120SSebastian Grimberg 
37ff1e7120SSebastian Grimberg   // Get vectors
38ff1e7120SSebastian Grimberg   const CeedScalar *d_u;
39ff1e7120SSebastian Grimberg   CeedScalar       *d_v;
40ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
41ff1e7120SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
42ff1e7120SSebastian Grimberg     // Sum into for transpose mode, e-vec to l-vec
43ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
44ff1e7120SSebastian Grimberg   } else {
45ff1e7120SSebastian Grimberg     // Overwrite for notranspose mode, l-vec to e-vec
46ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
47ff1e7120SSebastian Grimberg   }
48ff1e7120SSebastian Grimberg 
49ff1e7120SSebastian Grimberg   // Restrict
50ff1e7120SSebastian Grimberg   if (t_mode == CEED_NOTRANSPOSE) {
51ff1e7120SSebastian Grimberg     // L-vector -> E-vector
52ff1e7120SSebastian Grimberg     if (impl->d_ind) {
53ff1e7120SSebastian Grimberg       // -- Offsets provided
54ff1e7120SSebastian Grimberg       kernel             = impl->OffsetNoTranspose;
55ff1e7120SSebastian Grimberg       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
56ff1e7120SSebastian Grimberg       CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024;
57*58549094SSebastian Grimberg 
58ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
59ff1e7120SSebastian Grimberg     } else {
60ff1e7120SSebastian Grimberg       // -- Strided restriction
61ff1e7120SSebastian Grimberg       kernel             = impl->StridedNoTranspose;
62ff1e7120SSebastian Grimberg       void   *args[]     = {&num_elem, &d_u, &d_v};
63ff1e7120SSebastian Grimberg       CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024;
64*58549094SSebastian Grimberg 
65ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
66ff1e7120SSebastian Grimberg     }
67ff1e7120SSebastian Grimberg   } else {
68ff1e7120SSebastian Grimberg     // E-vector -> L-vector
69ff1e7120SSebastian Grimberg     if (impl->d_ind) {
70ff1e7120SSebastian Grimberg       // -- Offsets provided
71*58549094SSebastian Grimberg       CeedInt block_size = 32;
72*58549094SSebastian Grimberg       if (impl->OffsetTranspose) {
73ff1e7120SSebastian Grimberg         kernel       = impl->OffsetTranspose;
74*58549094SSebastian Grimberg         void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
75*58549094SSebastian Grimberg 
76ff1e7120SSebastian Grimberg         CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
77ff1e7120SSebastian Grimberg       } else {
78*58549094SSebastian Grimberg         kernel       = impl->OffsetTransposeDet;
79*58549094SSebastian Grimberg         void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
80*58549094SSebastian Grimberg 
81*58549094SSebastian Grimberg         CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
82*58549094SSebastian Grimberg       }
83*58549094SSebastian Grimberg     } else {
84ff1e7120SSebastian Grimberg       // -- Strided restriction
85ff1e7120SSebastian Grimberg       kernel             = impl->StridedTranspose;
86ff1e7120SSebastian Grimberg       void   *args[]     = {&num_elem, &d_u, &d_v};
87*58549094SSebastian Grimberg       CeedInt block_size = 32;
88*58549094SSebastian Grimberg 
89ff1e7120SSebastian Grimberg       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
90ff1e7120SSebastian Grimberg     }
91ff1e7120SSebastian Grimberg   }
92ff1e7120SSebastian Grimberg 
93ff1e7120SSebastian Grimberg   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
94ff1e7120SSebastian Grimberg 
95ff1e7120SSebastian Grimberg   // Restore arrays
96ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
97ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
98ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
99ff1e7120SSebastian Grimberg }
100ff1e7120SSebastian Grimberg 
//------------------------------------------------------------------------------
// Get offsets
//------------------------------------------------------------------------------
104ff1e7120SSebastian Grimberg static int CeedElemRestrictionGetOffsets_Cuda(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
105ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
106ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
107ff1e7120SSebastian Grimberg 
108ff1e7120SSebastian Grimberg   switch (mem_type) {
109ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
110ff1e7120SSebastian Grimberg       *offsets = impl->h_ind;
111ff1e7120SSebastian Grimberg       break;
112ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
113ff1e7120SSebastian Grimberg       *offsets = impl->d_ind;
114ff1e7120SSebastian Grimberg       break;
115ff1e7120SSebastian Grimberg   }
116ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
117ff1e7120SSebastian Grimberg }
118ff1e7120SSebastian Grimberg 
//------------------------------------------------------------------------------
// Destroy restriction
//------------------------------------------------------------------------------
122ff1e7120SSebastian Grimberg static int CeedElemRestrictionDestroy_Cuda(CeedElemRestriction r) {
123ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
124ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
125ff1e7120SSebastian Grimberg 
126ff1e7120SSebastian Grimberg   Ceed ceed;
127ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
128ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cuModuleUnload(impl->module));
129ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
130ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_ind_allocated));
131ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_t_offsets));
132ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_t_indices));
133ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_l_vec_indices));
134ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl));
135ff1e7120SSebastian Grimberg 
136ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
137ff1e7120SSebastian Grimberg }
138ff1e7120SSebastian Grimberg 
//------------------------------------------------------------------------------
// Create transpose offsets and indices
//------------------------------------------------------------------------------
142ff1e7120SSebastian Grimberg static int CeedElemRestrictionOffset_Cuda(const CeedElemRestriction r, const CeedInt *indices) {
143ff1e7120SSebastian Grimberg   Ceed ceed;
144ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
145ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
146ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
147ff1e7120SSebastian Grimberg   CeedSize l_size;
148ff1e7120SSebastian Grimberg   CeedInt  num_elem, elem_size, num_comp;
149ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
150ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
151ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
152ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
153ff1e7120SSebastian Grimberg 
154ff1e7120SSebastian Grimberg   // Count num_nodes
155ff1e7120SSebastian Grimberg   bool *is_node;
156ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(l_size, &is_node));
157ff1e7120SSebastian Grimberg   const CeedInt size_indices = num_elem * elem_size;
158ff1e7120SSebastian Grimberg   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
159ff1e7120SSebastian Grimberg   CeedInt num_nodes = 0;
160ff1e7120SSebastian Grimberg   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
161ff1e7120SSebastian Grimberg   impl->num_nodes = num_nodes;
162ff1e7120SSebastian Grimberg 
163ff1e7120SSebastian Grimberg   // L-vector offsets array
164ff1e7120SSebastian Grimberg   CeedInt *ind_to_offset, *l_vec_indices;
165ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
166ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
167ff1e7120SSebastian Grimberg   CeedInt j = 0;
168ff1e7120SSebastian Grimberg   for (CeedInt i = 0; i < l_size; i++) {
169ff1e7120SSebastian Grimberg     if (is_node[i]) {
170ff1e7120SSebastian Grimberg       l_vec_indices[j] = i;
171ff1e7120SSebastian Grimberg       ind_to_offset[i] = j++;
172ff1e7120SSebastian Grimberg     }
173ff1e7120SSebastian Grimberg   }
174ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&is_node));
175ff1e7120SSebastian Grimberg 
176ff1e7120SSebastian Grimberg   // Compute transpose offsets and indices
177ff1e7120SSebastian Grimberg   const CeedInt size_offsets = num_nodes + 1;
178ff1e7120SSebastian Grimberg   CeedInt      *t_offsets;
179ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
180ff1e7120SSebastian Grimberg   CeedInt *t_indices;
181ff1e7120SSebastian Grimberg   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
182ff1e7120SSebastian Grimberg   // Count node multiplicity
183ff1e7120SSebastian Grimberg   for (CeedInt e = 0; e < num_elem; ++e) {
184ff1e7120SSebastian Grimberg     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
185ff1e7120SSebastian Grimberg   }
186ff1e7120SSebastian Grimberg   // Convert to running sum
187ff1e7120SSebastian Grimberg   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
188ff1e7120SSebastian Grimberg   // List all E-vec indices associated with L-vec node
189ff1e7120SSebastian Grimberg   for (CeedInt e = 0; e < num_elem; ++e) {
190ff1e7120SSebastian Grimberg     for (CeedInt i = 0; i < elem_size; ++i) {
191ff1e7120SSebastian Grimberg       const CeedInt lid                          = elem_size * e + i;
192ff1e7120SSebastian Grimberg       const CeedInt gid                          = indices[lid];
193ff1e7120SSebastian Grimberg       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
194ff1e7120SSebastian Grimberg     }
195ff1e7120SSebastian Grimberg   }
196ff1e7120SSebastian Grimberg   // Reset running sum
197ff1e7120SSebastian Grimberg   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
198ff1e7120SSebastian Grimberg   t_offsets[0] = 0;
199ff1e7120SSebastian Grimberg 
200ff1e7120SSebastian Grimberg   // Copy data to device
201ff1e7120SSebastian Grimberg   // -- L-vector indices
202ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
203ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), cudaMemcpyHostToDevice));
204ff1e7120SSebastian Grimberg   // -- Transpose offsets
205ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
206ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), cudaMemcpyHostToDevice));
207ff1e7120SSebastian Grimberg   // -- Transpose indices
208ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
209ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), cudaMemcpyHostToDevice));
210ff1e7120SSebastian Grimberg 
211ff1e7120SSebastian Grimberg   // Cleanup
212ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&ind_to_offset));
213ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&l_vec_indices));
214ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&t_offsets));
215ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&t_indices));
216ff1e7120SSebastian Grimberg 
217ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
218ff1e7120SSebastian Grimberg }
219ff1e7120SSebastian Grimberg 
//------------------------------------------------------------------------------
// Create restriction
//------------------------------------------------------------------------------
223fcbe8c06SSebastian Grimberg int CeedElemRestrictionCreate_Cuda(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
2240c73c039SSebastian Grimberg                                    const CeedInt8 *curl_orients, CeedElemRestriction r) {
225ff1e7120SSebastian Grimberg   Ceed ceed;
226ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
227ff1e7120SSebastian Grimberg   CeedElemRestriction_Cuda *impl;
228ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
229*58549094SSebastian Grimberg   Ceed parent;
230*58549094SSebastian Grimberg   CeedCallBackend(CeedGetParent(ceed, &parent));
231*58549094SSebastian Grimberg   bool is_deterministic;
232*58549094SSebastian Grimberg   CeedCallBackend(CeedIsDeterministic(parent, &is_deterministic));
233ff1e7120SSebastian Grimberg   CeedInt num_elem, num_comp, elem_size;
234ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
235ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
236ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
237ff1e7120SSebastian Grimberg   CeedInt size        = num_elem * elem_size;
238ff1e7120SSebastian Grimberg   CeedInt strides[3]  = {1, size, elem_size};
239ff1e7120SSebastian Grimberg   CeedInt comp_stride = 1;
240ff1e7120SSebastian Grimberg 
24100125730SSebastian Grimberg   CeedRestrictionType rstr_type;
24200125730SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
24300125730SSebastian Grimberg   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
24400125730SSebastian Grimberg             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
24500125730SSebastian Grimberg 
246ff1e7120SSebastian Grimberg   // Stride data
247ff1e7120SSebastian Grimberg   bool is_strided;
248ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
249ff1e7120SSebastian Grimberg   if (is_strided) {
250ff1e7120SSebastian Grimberg     bool has_backend_strides;
251ff1e7120SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
252ff1e7120SSebastian Grimberg     if (!has_backend_strides) {
253ff1e7120SSebastian Grimberg       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
254ff1e7120SSebastian Grimberg     }
255ff1e7120SSebastian Grimberg   } else {
256ff1e7120SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
257ff1e7120SSebastian Grimberg   }
258ff1e7120SSebastian Grimberg 
259ff1e7120SSebastian Grimberg   impl->h_ind           = NULL;
260ff1e7120SSebastian Grimberg   impl->h_ind_allocated = NULL;
261ff1e7120SSebastian Grimberg   impl->d_ind           = NULL;
262ff1e7120SSebastian Grimberg   impl->d_ind_allocated = NULL;
263ff1e7120SSebastian Grimberg   impl->d_t_indices     = NULL;
264ff1e7120SSebastian Grimberg   impl->d_t_offsets     = NULL;
265ff1e7120SSebastian Grimberg   impl->num_nodes       = size;
266ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
267ff1e7120SSebastian Grimberg   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
268ff1e7120SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
269ff1e7120SSebastian Grimberg 
270ff1e7120SSebastian Grimberg   // Set up device indices/offset arrays
271ff1e7120SSebastian Grimberg   switch (mem_type) {
272ff1e7120SSebastian Grimberg     case CEED_MEM_HOST: {
273ff1e7120SSebastian Grimberg       switch (copy_mode) {
274ff1e7120SSebastian Grimberg         case CEED_OWN_POINTER:
275ff1e7120SSebastian Grimberg           impl->h_ind_allocated = (CeedInt *)indices;
276ff1e7120SSebastian Grimberg           impl->h_ind           = (CeedInt *)indices;
277ff1e7120SSebastian Grimberg           break;
278ff1e7120SSebastian Grimberg         case CEED_USE_POINTER:
279ff1e7120SSebastian Grimberg           impl->h_ind = (CeedInt *)indices;
280ff1e7120SSebastian Grimberg           break;
281ff1e7120SSebastian Grimberg         case CEED_COPY_VALUES:
282ff1e7120SSebastian Grimberg           if (indices != NULL) {
283ff1e7120SSebastian Grimberg             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
284ff1e7120SSebastian Grimberg             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
285ff1e7120SSebastian Grimberg             impl->h_ind = impl->h_ind_allocated;
286ff1e7120SSebastian Grimberg           }
287ff1e7120SSebastian Grimberg           break;
288ff1e7120SSebastian Grimberg       }
289ff1e7120SSebastian Grimberg       if (indices != NULL) {
290ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
291ff1e7120SSebastian Grimberg         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
292ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyHostToDevice));
293*58549094SSebastian Grimberg         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices));
294ff1e7120SSebastian Grimberg       }
295ff1e7120SSebastian Grimberg       break;
296ff1e7120SSebastian Grimberg     }
297ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE: {
298ff1e7120SSebastian Grimberg       switch (copy_mode) {
299ff1e7120SSebastian Grimberg         case CEED_COPY_VALUES:
300ff1e7120SSebastian Grimberg           if (indices != NULL) {
301ff1e7120SSebastian Grimberg             CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
302ff1e7120SSebastian Grimberg             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
303ff1e7120SSebastian Grimberg             CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyDeviceToDevice));
304ff1e7120SSebastian Grimberg           }
305ff1e7120SSebastian Grimberg           break;
306ff1e7120SSebastian Grimberg         case CEED_OWN_POINTER:
307ff1e7120SSebastian Grimberg           impl->d_ind           = (CeedInt *)indices;
308ff1e7120SSebastian Grimberg           impl->d_ind_allocated = impl->d_ind;
309ff1e7120SSebastian Grimberg           break;
310ff1e7120SSebastian Grimberg         case CEED_USE_POINTER:
311ff1e7120SSebastian Grimberg           impl->d_ind = (CeedInt *)indices;
312ff1e7120SSebastian Grimberg       }
313ff1e7120SSebastian Grimberg       if (indices != NULL) {
314ff1e7120SSebastian Grimberg         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
315ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), cudaMemcpyDeviceToHost));
316ff1e7120SSebastian Grimberg         impl->h_ind = impl->h_ind_allocated;
317*58549094SSebastian Grimberg         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices));
318ff1e7120SSebastian Grimberg       }
319ff1e7120SSebastian Grimberg       break;
320ff1e7120SSebastian Grimberg     }
321ff1e7120SSebastian Grimberg     // LCOV_EXCL_START
322ff1e7120SSebastian Grimberg     default:
323ff1e7120SSebastian Grimberg       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
324ff1e7120SSebastian Grimberg       // LCOV_EXCL_STOP
325ff1e7120SSebastian Grimberg   }
326ff1e7120SSebastian Grimberg 
327*58549094SSebastian Grimberg   // Compile CUDA kernels (add atomicAdd function for old NVidia architectures)
328ff1e7120SSebastian Grimberg   CeedInt num_nodes = impl->num_nodes;
329*58549094SSebastian Grimberg   char   *restriction_kernel_path, *restriction_kernel_source = NULL;
330ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-restriction.h", &restriction_kernel_path));
33123d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
332*58549094SSebastian Grimberg   if (!is_deterministic) {
333*58549094SSebastian Grimberg     struct cudaDeviceProp prop;
334*58549094SSebastian Grimberg     Ceed_Cuda            *ceed_data;
335*58549094SSebastian Grimberg     CeedCallBackend(CeedGetData(ceed, &ceed_data));
336*58549094SSebastian Grimberg     CeedCallBackend(cudaGetDeviceProperties(&prop, ceed_data->device_id));
337*58549094SSebastian Grimberg     if ((prop.major < 6) && (CEED_SCALAR_TYPE != CEED_SCALAR_FP32)) {
338*58549094SSebastian Grimberg       char *atomic_add_path;
339*58549094SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-atomic-add-fallback.h", &atomic_add_path));
340*58549094SSebastian Grimberg       CeedCallBackend(CeedLoadSourceToBuffer(ceed, atomic_add_path, &restriction_kernel_source));
341*58549094SSebastian Grimberg       CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
342*58549094SSebastian Grimberg       CeedCallBackend(CeedFree(&atomic_add_path));
343*58549094SSebastian Grimberg     }
344*58549094SSebastian Grimberg   }
345*58549094SSebastian Grimberg   if (!restriction_kernel_source) {
346ff1e7120SSebastian Grimberg     CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
347*58549094SSebastian Grimberg   }
34823d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
349ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCompile_Cuda(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
350ff1e7120SSebastian Grimberg                                    "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
351ff1e7120SSebastian Grimberg                                    strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
352ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
353*58549094SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
354ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
355*58549094SSebastian Grimberg   if (!is_deterministic) {
356*58549094SSebastian Grimberg     CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
357*58549094SSebastian Grimberg   } else {
358*58549094SSebastian Grimberg     CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet));
359*58549094SSebastian Grimberg   }
360ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&restriction_kernel_path));
361ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&restriction_kernel_source));
362ff1e7120SSebastian Grimberg 
363ff1e7120SSebastian Grimberg   // Register backend functions
364ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Cuda));
365ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Cuda));
3667c1dbaffSSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Cuda));
367ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Cuda));
368ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Cuda));
369ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
370ff1e7120SSebastian Grimberg }
371ff1e7120SSebastian Grimberg 
372ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
373