xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-restriction.c (revision dce496930433734c64b503b2c0057d67691234a0) !
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <string.h>
14 #include <hip/hip_runtime.h>
15 
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-ref.h"
19 
20 //------------------------------------------------------------------------------
21 // Core apply restriction code
22 //------------------------------------------------------------------------------
23 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients,
24                                                     CeedVector u, CeedVector v, CeedRequest *request) {
25   Ceed                     ceed;
26   CeedInt                  num_elem, elem_size;
27   CeedRestrictionType      rstr_type;
28   const CeedScalar        *d_u;
29   CeedScalar              *d_v;
30   CeedElemRestriction_Hip *impl;
31   hipFunction_t            kernel;
32 
33   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
34   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
35   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
36   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
37   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
38   const CeedInt num_nodes = impl->num_nodes;
39 
40   // Get vectors
41   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
42   if (t_mode == CEED_TRANSPOSE) {
43     // Sum into for transpose mode, e-vec to l-vec
44     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
45   } else {
46     // Overwrite for notranspose mode, l-vec to e-vec
47     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
48   }
49 
50   // Restrict
51   if (t_mode == CEED_NOTRANSPOSE) {
52     // L-vector -> E-vector
53     const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
54     const CeedInt grid       = CeedDivUpInt(num_nodes, block_size);
55 
56     switch (rstr_type) {
57       case CEED_RESTRICTION_STRIDED: {
58         kernel       = impl->StridedNoTranspose;
59         void *args[] = {&num_elem, &d_u, &d_v};
60 
61         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
62       } break;
63       case CEED_RESTRICTION_STANDARD: {
64         kernel       = impl->OffsetNoTranspose;
65         void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
66 
67         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
68       } break;
69       case CEED_RESTRICTION_ORIENTED: {
70         if (use_signs) {
71           kernel       = impl->OrientedNoTranspose;
72           void *args[] = {&num_elem, &impl->d_ind, &impl->d_orients, &d_u, &d_v};
73 
74           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
75         } else {
76           kernel       = impl->OffsetNoTranspose;
77           void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
78 
79           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
80         }
81       } break;
82       case CEED_RESTRICTION_CURL_ORIENTED: {
83         if (use_signs && use_orients) {
84           kernel       = impl->CurlOrientedNoTranspose;
85           void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v};
86 
87           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
88         } else if (use_orients) {
89           kernel       = impl->CurlOrientedUnsignedNoTranspose;
90           void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v};
91 
92           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
93         } else {
94           kernel       = impl->OffsetNoTranspose;
95           void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
96 
97           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
98         }
99       } break;
100     }
101   } else {
102     // E-vector -> L-vector
103     const CeedInt block_size = 64;
104     const CeedInt grid       = CeedDivUpInt(num_nodes, block_size);
105 
106     switch (rstr_type) {
107       case CEED_RESTRICTION_STRIDED: {
108         kernel       = impl->StridedTranspose;
109         void *args[] = {&num_elem, &d_u, &d_v};
110 
111         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
112       } break;
113       case CEED_RESTRICTION_STANDARD: {
114         if (impl->OffsetTranspose) {
115           kernel       = impl->OffsetTranspose;
116           void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
117 
118           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
119         } else {
120           kernel       = impl->OffsetTransposeDet;
121           void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
122 
123           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
124         }
125       } break;
126       case CEED_RESTRICTION_ORIENTED: {
127         if (use_signs) {
128           kernel       = impl->OrientedTranspose;
129           void *args[] = {&num_elem, &impl->d_ind, &impl->d_orients, &d_u, &d_v};
130 
131           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
132         } else {
133           if (impl->OffsetTranspose) {
134             kernel       = impl->OffsetTranspose;
135             void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
136 
137             CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
138           } else {
139             kernel       = impl->OffsetTransposeDet;
140             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
141 
142             CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
143           }
144         }
145       } break;
146       case CEED_RESTRICTION_CURL_ORIENTED: {
147         if (use_signs && use_orients) {
148           kernel       = impl->CurlOrientedTranspose;
149           void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v};
150 
151           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
152         } else if (use_orients) {
153           kernel       = impl->CurlOrientedUnsignedTranspose;
154           void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v};
155 
156           CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
157         } else {
158           if (impl->OffsetTranspose) {
159             kernel       = impl->OffsetTranspose;
160             void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
161 
162             CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
163           } else {
164             kernel       = impl->OffsetTransposeDet;
165             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
166 
167             CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args));
168           }
169         }
170       } break;
171     }
172   }
173 
174   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
175 
176   // Restore arrays
177   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
178   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
179   return CEED_ERROR_SUCCESS;
180 }
181 
182 //------------------------------------------------------------------------------
183 // Apply restriction
184 //------------------------------------------------------------------------------
185 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
186   return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request);
187 }
188 
189 //------------------------------------------------------------------------------
190 // Apply unsigned restriction
191 //------------------------------------------------------------------------------
192 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v,
193                                                 CeedRequest *request) {
194   return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request);
195 }
196 
197 //------------------------------------------------------------------------------
198 // Apply unoriented restriction
199 //------------------------------------------------------------------------------
200 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v,
201                                                   CeedRequest *request) {
202   return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request);
203 }
204 
205 //------------------------------------------------------------------------------
206 // Get offsets
207 //------------------------------------------------------------------------------
208 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
209   CeedElemRestriction_Hip *impl;
210 
211   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
212   switch (mem_type) {
213     case CEED_MEM_HOST:
214       *offsets = impl->h_ind;
215       break;
216     case CEED_MEM_DEVICE:
217       *offsets = impl->d_ind;
218       break;
219   }
220   return CEED_ERROR_SUCCESS;
221 }
222 
223 //------------------------------------------------------------------------------
224 // Get orientations
225 //------------------------------------------------------------------------------
226 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) {
227   CeedElemRestriction_Hip *impl;
228   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
229 
230   switch (mem_type) {
231     case CEED_MEM_HOST:
232       *orients = impl->h_orients;
233       break;
234     case CEED_MEM_DEVICE:
235       *orients = impl->d_orients;
236       break;
237   }
238   return CEED_ERROR_SUCCESS;
239 }
240 
241 //------------------------------------------------------------------------------
242 // Get curl-conforming orientations
243 //------------------------------------------------------------------------------
244 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) {
245   CeedElemRestriction_Hip *impl;
246   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
247 
248   switch (mem_type) {
249     case CEED_MEM_HOST:
250       *curl_orients = impl->h_curl_orients;
251       break;
252     case CEED_MEM_DEVICE:
253       *curl_orients = impl->d_curl_orients;
254       break;
255   }
256   return CEED_ERROR_SUCCESS;
257 }
258 
259 //------------------------------------------------------------------------------
260 // Destroy restriction
261 //------------------------------------------------------------------------------
262 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) {
263   Ceed                     ceed;
264   CeedElemRestriction_Hip *impl;
265 
266   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
267   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
268   CeedCallHip(ceed, hipModuleUnload(impl->module));
269   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
270   CeedCallHip(ceed, hipFree(impl->d_ind_allocated));
271   CeedCallHip(ceed, hipFree(impl->d_t_offsets));
272   CeedCallHip(ceed, hipFree(impl->d_t_indices));
273   CeedCallHip(ceed, hipFree(impl->d_l_vec_indices));
274   CeedCallBackend(CeedFree(&impl->h_orients_allocated));
275   CeedCallHip(ceed, hipFree(impl->d_orients_allocated));
276   CeedCallBackend(CeedFree(&impl->h_curl_orients_allocated));
277   CeedCallHip(ceed, hipFree(impl->d_curl_orients_allocated));
278   CeedCallBackend(CeedFree(&impl));
279   return CEED_ERROR_SUCCESS;
280 }
281 
282 //------------------------------------------------------------------------------
283 // Create transpose offsets and indices
284 //------------------------------------------------------------------------------
285 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt *indices) {
286   Ceed                     ceed;
287   bool                    *is_node;
288   CeedSize                 l_size;
289   CeedInt                  num_elem, elem_size, num_comp, num_nodes = 0;
290   CeedInt                 *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices;
291   CeedElemRestriction_Hip *impl;
292 
293   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
294   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
295   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
296   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
297   CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
298   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
299   const CeedInt size_indices = num_elem * elem_size;
300 
301   // Count num_nodes
302   CeedCallBackend(CeedCalloc(l_size, &is_node));
303 
304   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
305   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
306   impl->num_nodes = num_nodes;
307 
308   // L-vector offsets array
309   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
310   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
311   for (CeedInt i = 0, j = 0; i < l_size; i++) {
312     if (is_node[i]) {
313       l_vec_indices[j] = i;
314       ind_to_offset[i] = j++;
315     }
316   }
317   CeedCallBackend(CeedFree(&is_node));
318 
319   // Compute transpose offsets and indices
320   const CeedInt size_offsets = num_nodes + 1;
321 
322   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
323   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
324   // Count node multiplicity
325   for (CeedInt e = 0; e < num_elem; ++e) {
326     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
327   }
328   // Convert to running sum
329   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
330   // List all E-vec indices associated with L-vec node
331   for (CeedInt e = 0; e < num_elem; ++e) {
332     for (CeedInt i = 0; i < elem_size; ++i) {
333       const CeedInt lid = elem_size * e + i;
334       const CeedInt gid = indices[lid];
335 
336       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
337     }
338   }
339   // Reset running sum
340   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
341   t_offsets[0] = 0;
342 
343   // Copy data to device
344   // -- L-vector indices
345   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
346   CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
347   // -- Transpose offsets
348   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
349   CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
350   // -- Transpose indices
351   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
352   CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
353 
354   // Cleanup
355   CeedCallBackend(CeedFree(&ind_to_offset));
356   CeedCallBackend(CeedFree(&l_vec_indices));
357   CeedCallBackend(CeedFree(&t_offsets));
358   CeedCallBackend(CeedFree(&t_indices));
359   return CEED_ERROR_SUCCESS;
360 }
361 
362 //------------------------------------------------------------------------------
363 // Create restriction
364 //------------------------------------------------------------------------------
365 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
366                                   const CeedInt8 *curl_orients, CeedElemRestriction rstr) {
367   Ceed                     ceed, ceed_parent;
368   bool                     is_deterministic;
369   CeedInt                  num_elem, num_comp, elem_size, comp_stride = 1;
370   CeedRestrictionType      rstr_type;
371   char                    *restriction_kernel_path, *restriction_kernel_source;
372   CeedElemRestriction_Hip *impl;
373 
374   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
375   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
376   CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic));
377   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
378   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
379   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
380   const CeedInt size       = num_elem * elem_size;
381   CeedInt       strides[3] = {1, size, elem_size};
382   CeedInt       layout[3]  = {1, elem_size * num_elem, elem_size};
383 
384   // Stride data
385   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
386   if (rstr_type == CEED_RESTRICTION_STRIDED) {
387     bool has_backend_strides;
388 
389     CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
390     if (!has_backend_strides) {
391       CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides));
392     }
393   } else {
394     CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
395   }
396 
397   CeedCallBackend(CeedCalloc(1, &impl));
398   impl->num_nodes                = size;
399   impl->h_ind                    = NULL;
400   impl->h_ind_allocated          = NULL;
401   impl->d_ind                    = NULL;
402   impl->d_ind_allocated          = NULL;
403   impl->d_t_indices              = NULL;
404   impl->d_t_offsets              = NULL;
405   impl->h_orients                = NULL;
406   impl->h_orients_allocated      = NULL;
407   impl->d_orients                = NULL;
408   impl->d_orients_allocated      = NULL;
409   impl->h_curl_orients           = NULL;
410   impl->h_curl_orients_allocated = NULL;
411   impl->d_curl_orients           = NULL;
412   impl->d_curl_orients_allocated = NULL;
413   CeedCallBackend(CeedElemRestrictionSetData(rstr, impl));
414   CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout));
415 
416   // Set up device offset/orientation arrays
417   if (rstr_type != CEED_RESTRICTION_STRIDED) {
418     switch (mem_type) {
419       case CEED_MEM_HOST: {
420         switch (copy_mode) {
421           case CEED_OWN_POINTER:
422             impl->h_ind_allocated = (CeedInt *)indices;
423             impl->h_ind           = (CeedInt *)indices;
424             break;
425           case CEED_USE_POINTER:
426             impl->h_ind = (CeedInt *)indices;
427             break;
428           case CEED_COPY_VALUES:
429             CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated));
430             memcpy(impl->h_ind_allocated, indices, size * sizeof(CeedInt));
431             impl->h_ind = impl->h_ind_allocated;
432             break;
433         }
434         CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
435         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
436         CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice));
437         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices));
438       } break;
439       case CEED_MEM_DEVICE: {
440         switch (copy_mode) {
441           case CEED_COPY_VALUES:
442             CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
443             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
444             CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice));
445             break;
446           case CEED_OWN_POINTER:
447             impl->d_ind           = (CeedInt *)indices;
448             impl->d_ind_allocated = impl->d_ind;
449             break;
450           case CEED_USE_POINTER:
451             impl->d_ind = (CeedInt *)indices;
452             break;
453         }
454         CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated));
455         CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, size * sizeof(CeedInt), hipMemcpyDeviceToHost));
456         impl->h_ind = impl->h_ind_allocated;
457         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices));
458       } break;
459     }
460 
461     // Orientation data
462     if (rstr_type == CEED_RESTRICTION_ORIENTED) {
463       switch (mem_type) {
464         case CEED_MEM_HOST: {
465           switch (copy_mode) {
466             case CEED_OWN_POINTER:
467               impl->h_orients_allocated = (bool *)orients;
468               impl->h_orients           = (bool *)orients;
469               break;
470             case CEED_USE_POINTER:
471               impl->h_orients = (bool *)orients;
472               break;
473             case CEED_COPY_VALUES:
474               CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated));
475               memcpy(impl->h_orients_allocated, orients, size * sizeof(bool));
476               impl->h_orients = impl->h_orients_allocated;
477               break;
478           }
479           CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool)));
480           impl->d_orients_allocated = impl->d_orients;  // We own the device memory
481           CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyHostToDevice));
482         } break;
483         case CEED_MEM_DEVICE: {
484           switch (copy_mode) {
485             case CEED_COPY_VALUES:
486               CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool)));
487               impl->d_orients_allocated = impl->d_orients;  // We own the device memory
488               CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyDeviceToDevice));
489               break;
490             case CEED_OWN_POINTER:
491               impl->d_orients           = (bool *)orients;
492               impl->d_orients_allocated = impl->d_orients;
493               break;
494             case CEED_USE_POINTER:
495               impl->d_orients = (bool *)orients;
496               break;
497           }
498           CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated));
499           CeedCallHip(ceed, hipMemcpy(impl->h_orients_allocated, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost));
500           impl->h_orients = impl->h_orients_allocated;
501         } break;
502       }
503     } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) {
504       switch (mem_type) {
505         case CEED_MEM_HOST: {
506           switch (copy_mode) {
507             case CEED_OWN_POINTER:
508               impl->h_curl_orients_allocated = (CeedInt8 *)curl_orients;
509               impl->h_curl_orients           = (CeedInt8 *)curl_orients;
510               break;
511             case CEED_USE_POINTER:
512               impl->h_curl_orients = (CeedInt8 *)curl_orients;
513               break;
514             case CEED_COPY_VALUES:
515               CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated));
516               memcpy(impl->h_curl_orients_allocated, curl_orients, 3 * size * sizeof(CeedInt8));
517               impl->h_curl_orients = impl->h_curl_orients_allocated;
518               break;
519           }
520           CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8)));
521           impl->d_curl_orients_allocated = impl->d_curl_orients;  // We own the device memory
522           CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice));
523         } break;
524         case CEED_MEM_DEVICE: {
525           switch (copy_mode) {
526             case CEED_COPY_VALUES:
527               CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8)));
528               impl->d_curl_orients_allocated = impl->d_curl_orients;  // We own the device memory
529               CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToDevice));
530               break;
531             case CEED_OWN_POINTER:
532               impl->d_curl_orients           = (CeedInt8 *)curl_orients;
533               impl->d_curl_orients_allocated = impl->d_curl_orients;
534               break;
535             case CEED_USE_POINTER:
536               impl->d_curl_orients = (CeedInt8 *)curl_orients;
537               break;
538           }
539           CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated));
540           CeedCallHip(ceed, hipMemcpy(impl->h_curl_orients_allocated, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost));
541           impl->h_curl_orients = impl->h_curl_orients_allocated;
542         } break;
543       }
544     }
545   }
546 
547   // Compile HIP kernels
548   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path));
549   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
550   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
551   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
552   CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem,
553                                   "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, "RSTR_STRIDE_NODES",
554                                   strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", strides[2]));
555   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
556   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
557   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
558   if (!is_deterministic) {
559     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
560   } else {
561     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet));
562   }
563   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->OrientedNoTranspose));
564   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->OrientedTranspose));
565   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->CurlOrientedNoTranspose));
566   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->CurlOrientedUnsignedNoTranspose));
567   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->CurlOrientedTranspose));
568   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->CurlOrientedUnsignedTranspose));
569   CeedCallBackend(CeedFree(&restriction_kernel_path));
570   CeedCallBackend(CeedFree(&restriction_kernel_source));
571 
572   // Register backend functions
573   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip));
574   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip));
575   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip));
576   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
577   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip));
578   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip));
579   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip));
580   return CEED_ERROR_SUCCESS;
581 }
582 
583 //------------------------------------------------------------------------------
584