xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-restriction.c (revision fe96005463bdbb79b892d21a5c89e2b475ecf62b)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
849aac155SJeremy L Thompson #include <ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
10437930d1SJeremy L Thompson #include <ceed/jit-tools.h>
110d0321e0SJeremy L Thompson #include <stdbool.h>
120d0321e0SJeremy L Thompson #include <stddef.h>
1344d7a66cSJeremy L Thompson #include <string.h>
14c85e8640SSebastian Grimberg #include <hip/hip_runtime.h>
152b730f8bSJeremy L Thompson 
1649aac155SJeremy L Thompson #include "../hip/ceed-hip-common.h"
170d0321e0SJeremy L Thompson #include "../hip/ceed-hip-compile.h"
182b730f8bSJeremy L Thompson #include "ceed-hip-ref.h"
190d0321e0SJeremy L Thompson 
200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
21cf8cbdd6SSebastian Grimberg // Compile restriction kernels
22cf8cbdd6SSebastian Grimberg //------------------------------------------------------------------------------
23cf8cbdd6SSebastian Grimberg static inline int CeedElemRestrictionSetupCompile_Hip(CeedElemRestriction rstr) {
24cf8cbdd6SSebastian Grimberg   Ceed                     ceed;
25cf8cbdd6SSebastian Grimberg   bool                     is_deterministic;
2622070f95SJeremy L Thompson   char                    *restriction_kernel_source;
2722070f95SJeremy L Thompson   const char              *restriction_kernel_path;
28cf8cbdd6SSebastian Grimberg   CeedInt                  num_elem, num_comp, elem_size, comp_stride;
29cf8cbdd6SSebastian Grimberg   CeedRestrictionType      rstr_type;
30cf8cbdd6SSebastian Grimberg   CeedElemRestriction_Hip *impl;
31cf8cbdd6SSebastian Grimberg 
32cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
33cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
34*fe960054SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
35cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
36cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
37cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride));
38*fe960054SJeremy L Thompson   if (rstr_type == CEED_RESTRICTION_POINTS) {
39*fe960054SJeremy L Thompson     CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr, &elem_size));
40*fe960054SJeremy L Thompson   } else {
41*fe960054SJeremy L Thompson     CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
42*fe960054SJeremy L Thompson   }
43cf8cbdd6SSebastian Grimberg   is_deterministic = impl->d_l_vec_indices != NULL;
44cf8cbdd6SSebastian Grimberg 
45cf8cbdd6SSebastian Grimberg   // Compile HIP kernels
46cf8cbdd6SSebastian Grimberg   switch (rstr_type) {
47cf8cbdd6SSebastian Grimberg     case CEED_RESTRICTION_STRIDED: {
48cf8cbdd6SSebastian Grimberg       bool    has_backend_strides;
49509d4af6SJeremy L Thompson       CeedInt strides[3] = {1, num_elem * elem_size, elem_size};
50cf8cbdd6SSebastian Grimberg 
51cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
52cf8cbdd6SSebastian Grimberg       if (!has_backend_strides) {
5356c48462SJeremy L Thompson         CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides));
54cf8cbdd6SSebastian Grimberg       }
55cf8cbdd6SSebastian Grimberg 
56cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-strided.h", &restriction_kernel_path));
57cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
58cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
59cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
60cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem,
61cf8cbdd6SSebastian Grimberg                                       "RSTR_NUM_COMP", num_comp, "RSTR_STRIDE_NODES", strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM",
62cf8cbdd6SSebastian Grimberg                                       strides[2]));
63cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->ApplyNoTranspose));
64cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->ApplyTranspose));
65cf8cbdd6SSebastian Grimberg     } break;
66*fe960054SJeremy L Thompson     case CEED_RESTRICTION_POINTS:
67cf8cbdd6SSebastian Grimberg     case CEED_RESTRICTION_STANDARD: {
68cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &restriction_kernel_path));
69cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
70cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
71cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
72cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem,
73cf8cbdd6SSebastian Grimberg                                       "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride,
74cf8cbdd6SSebastian Grimberg                                       "USE_DETERMINISTIC", is_deterministic ? 1 : 0));
75cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyNoTranspose));
76cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyTranspose));
77cf8cbdd6SSebastian Grimberg     } break;
78cf8cbdd6SSebastian Grimberg     case CEED_RESTRICTION_ORIENTED: {
7922070f95SJeremy L Thompson       const char *offset_kernel_path;
80509d4af6SJeremy L Thompson       char      **file_paths     = NULL;
81509d4af6SJeremy L Thompson       CeedInt     num_file_paths = 0;
82cf8cbdd6SSebastian Grimberg 
83cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-oriented.h", &restriction_kernel_path));
84cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
85509d4af6SJeremy L Thompson       CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, restriction_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source));
86cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path));
87509d4af6SJeremy L Thompson       CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source));
88cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
89cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem,
90cf8cbdd6SSebastian Grimberg                                       "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride,
91cf8cbdd6SSebastian Grimberg                                       "USE_DETERMINISTIC", is_deterministic ? 1 : 0));
92cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->ApplyNoTranspose));
93cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnsignedNoTranspose));
94cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->ApplyTranspose));
95cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnsignedTranspose));
96509d4af6SJeremy L Thompson       // Cleanup
97cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedFree(&offset_kernel_path));
98509d4af6SJeremy L Thompson       for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i]));
99509d4af6SJeremy L Thompson       CeedCall(CeedFree(&file_paths));
100cf8cbdd6SSebastian Grimberg     } break;
101cf8cbdd6SSebastian Grimberg     case CEED_RESTRICTION_CURL_ORIENTED: {
10222070f95SJeremy L Thompson       const char *offset_kernel_path;
103509d4af6SJeremy L Thompson       char      **file_paths     = NULL;
104509d4af6SJeremy L Thompson       CeedInt     num_file_paths = 0;
105cf8cbdd6SSebastian Grimberg 
106cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-curl-oriented.h", &restriction_kernel_path));
107cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
108509d4af6SJeremy L Thompson       CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, restriction_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source));
109cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path));
110509d4af6SJeremy L Thompson       CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source));
111cf8cbdd6SSebastian Grimberg       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
112cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem,
113cf8cbdd6SSebastian Grimberg                                       "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride,
114cf8cbdd6SSebastian Grimberg                                       "USE_DETERMINISTIC", is_deterministic ? 1 : 0));
115cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->ApplyNoTranspose));
116cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->ApplyUnsignedNoTranspose));
117cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnorientedNoTranspose));
118cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->ApplyTranspose));
119cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->ApplyUnsignedTranspose));
120cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnorientedTranspose));
121509d4af6SJeremy L Thompson       // Cleanup
122cf8cbdd6SSebastian Grimberg       CeedCallBackend(CeedFree(&offset_kernel_path));
123509d4af6SJeremy L Thompson       for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i]));
124509d4af6SJeremy L Thompson       CeedCall(CeedFree(&file_paths));
125cf8cbdd6SSebastian Grimberg     } break;
126cf8cbdd6SSebastian Grimberg   }
127cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedFree(&restriction_kernel_path));
128cf8cbdd6SSebastian Grimberg   CeedCallBackend(CeedFree(&restriction_kernel_source));
129cf8cbdd6SSebastian Grimberg   return CEED_ERROR_SUCCESS;
130cf8cbdd6SSebastian Grimberg }
131cf8cbdd6SSebastian Grimberg 
132cf8cbdd6SSebastian Grimberg //------------------------------------------------------------------------------
133dce49693SSebastian Grimberg // Core apply restriction code
1340d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
135dce49693SSebastian Grimberg static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients,
136dce49693SSebastian Grimberg                                                     CeedVector u, CeedVector v, CeedRequest *request) {
1370d0321e0SJeremy L Thompson   Ceed                     ceed;
138dce49693SSebastian Grimberg   CeedRestrictionType      rstr_type;
1390d0321e0SJeremy L Thompson   const CeedScalar        *d_u;
1400d0321e0SJeremy L Thompson   CeedScalar              *d_v;
141b7453713SJeremy L Thompson   CeedElemRestriction_Hip *impl;
142b7453713SJeremy L Thompson 
143dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
144dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
145dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
146cf8cbdd6SSebastian Grimberg 
147cf8cbdd6SSebastian Grimberg   // Assemble kernel if needed
148cf8cbdd6SSebastian Grimberg   if (!impl->module) {
149cf8cbdd6SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionSetupCompile_Hip(rstr));
150cf8cbdd6SSebastian Grimberg   }
151b7453713SJeremy L Thompson 
152b7453713SJeremy L Thompson   // Get vectors
1532b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
154437930d1SJeremy L Thompson   if (t_mode == CEED_TRANSPOSE) {
1550d0321e0SJeremy L Thompson     // Sum into for transpose mode, e-vec to l-vec
1562b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
1570d0321e0SJeremy L Thompson   } else {
1580d0321e0SJeremy L Thompson     // Overwrite for notranspose mode, l-vec to e-vec
1592b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
1600d0321e0SJeremy L Thompson   }
1610d0321e0SJeremy L Thompson 
1620d0321e0SJeremy L Thompson   // Restrict
163437930d1SJeremy L Thompson   if (t_mode == CEED_NOTRANSPOSE) {
1640d0321e0SJeremy L Thompson     // L-vector -> E-vector
165cf8cbdd6SSebastian Grimberg     CeedInt elem_size;
166cf8cbdd6SSebastian Grimberg 
167cf8cbdd6SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
168dce49693SSebastian Grimberg     const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
169cf8cbdd6SSebastian Grimberg     const CeedInt grid       = CeedDivUpInt(impl->num_nodes, block_size);
17058549094SSebastian Grimberg 
171dce49693SSebastian Grimberg     switch (rstr_type) {
172dce49693SSebastian Grimberg       case CEED_RESTRICTION_STRIDED: {
173cf8cbdd6SSebastian Grimberg         void *args[] = {&d_u, &d_v};
17458549094SSebastian Grimberg 
175cf8cbdd6SSebastian Grimberg         CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args));
176dce49693SSebastian Grimberg       } break;
177*fe960054SJeremy L Thompson       case CEED_RESTRICTION_POINTS:
178dce49693SSebastian Grimberg       case CEED_RESTRICTION_STANDARD: {
179a267acd1SJeremy L Thompson         void *args[] = {&impl->d_offsets, &d_u, &d_v};
180dce49693SSebastian Grimberg 
181cf8cbdd6SSebastian Grimberg         CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args));
182dce49693SSebastian Grimberg       } break;
183dce49693SSebastian Grimberg       case CEED_RESTRICTION_ORIENTED: {
184dce49693SSebastian Grimberg         if (use_signs) {
185a267acd1SJeremy L Thompson           void *args[] = {&impl->d_offsets, &impl->d_orients, &d_u, &d_v};
186dce49693SSebastian Grimberg 
187cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args));
188dce49693SSebastian Grimberg         } else {
189a267acd1SJeremy L Thompson           void *args[] = {&impl->d_offsets, &d_u, &d_v};
190dce49693SSebastian Grimberg 
191cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args));
192dce49693SSebastian Grimberg         }
193dce49693SSebastian Grimberg       } break;
194dce49693SSebastian Grimberg       case CEED_RESTRICTION_CURL_ORIENTED: {
195dce49693SSebastian Grimberg         if (use_signs && use_orients) {
196a267acd1SJeremy L Thompson           void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v};
197dce49693SSebastian Grimberg 
198cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args));
199dce49693SSebastian Grimberg         } else if (use_orients) {
200a267acd1SJeremy L Thompson           void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v};
201dce49693SSebastian Grimberg 
202cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args));
203dce49693SSebastian Grimberg         } else {
204a267acd1SJeremy L Thompson           void *args[] = {&impl->d_offsets, &d_u, &d_v};
205dce49693SSebastian Grimberg 
206cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedNoTranspose, grid, block_size, args));
207dce49693SSebastian Grimberg         }
208dce49693SSebastian Grimberg       } break;
2090d0321e0SJeremy L Thompson     }
2100d0321e0SJeremy L Thompson   } else {
2110d0321e0SJeremy L Thompson     // E-vector -> L-vector
212cf8cbdd6SSebastian Grimberg     const bool    is_deterministic = impl->d_l_vec_indices != NULL;
213dce49693SSebastian Grimberg     const CeedInt block_size       = 64;
214cf8cbdd6SSebastian Grimberg     const CeedInt grid             = CeedDivUpInt(impl->num_nodes, block_size);
215b7453713SJeremy L Thompson 
216dce49693SSebastian Grimberg     switch (rstr_type) {
217dce49693SSebastian Grimberg       case CEED_RESTRICTION_STRIDED: {
218cf8cbdd6SSebastian Grimberg         void *args[] = {&d_u, &d_v};
219dce49693SSebastian Grimberg 
220cf8cbdd6SSebastian Grimberg         CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
221dce49693SSebastian Grimberg       } break;
222*fe960054SJeremy L Thompson       case CEED_RESTRICTION_POINTS:
223dce49693SSebastian Grimberg       case CEED_RESTRICTION_STANDARD: {
224cf8cbdd6SSebastian Grimberg         if (!is_deterministic) {
225a267acd1SJeremy L Thompson           void *args[] = {&impl->d_offsets, &d_u, &d_v};
22658549094SSebastian Grimberg 
227cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
2280d0321e0SJeremy L Thompson         } else {
22958549094SSebastian Grimberg           void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
23058549094SSebastian Grimberg 
231cf8cbdd6SSebastian Grimberg           CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
23258549094SSebastian Grimberg         }
233dce49693SSebastian Grimberg       } break;
234dce49693SSebastian Grimberg       case CEED_RESTRICTION_ORIENTED: {
235dce49693SSebastian Grimberg         if (use_signs) {
236cf8cbdd6SSebastian Grimberg           if (!is_deterministic) {
237a267acd1SJeremy L Thompson             void *args[] = {&impl->d_offsets, &impl->d_orients, &d_u, &d_v};
23858549094SSebastian Grimberg 
239cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
240dce49693SSebastian Grimberg           } else {
2417aa91133SSebastian Grimberg             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_orients, &d_u, &d_v};
2427aa91133SSebastian Grimberg 
243cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
2447aa91133SSebastian Grimberg           }
2457aa91133SSebastian Grimberg         } else {
246cf8cbdd6SSebastian Grimberg           if (!is_deterministic) {
247a267acd1SJeremy L Thompson             void *args[] = {&impl->d_offsets, &d_u, &d_v};
248dce49693SSebastian Grimberg 
249cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args));
250dce49693SSebastian Grimberg           } else {
251dce49693SSebastian Grimberg             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
252dce49693SSebastian Grimberg 
253cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args));
254dce49693SSebastian Grimberg           }
255dce49693SSebastian Grimberg         }
256dce49693SSebastian Grimberg       } break;
257dce49693SSebastian Grimberg       case CEED_RESTRICTION_CURL_ORIENTED: {
258dce49693SSebastian Grimberg         if (use_signs && use_orients) {
259cf8cbdd6SSebastian Grimberg           if (!is_deterministic) {
260a267acd1SJeremy L Thompson             void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v};
261dce49693SSebastian Grimberg 
262cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
2637aa91133SSebastian Grimberg           } else {
2647aa91133SSebastian Grimberg             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v};
2657aa91133SSebastian Grimberg 
266cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args));
2677aa91133SSebastian Grimberg           }
268dce49693SSebastian Grimberg         } else if (use_orients) {
269cf8cbdd6SSebastian Grimberg           if (!is_deterministic) {
270a267acd1SJeremy L Thompson             void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v};
271dce49693SSebastian Grimberg 
272cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args));
273dce49693SSebastian Grimberg           } else {
2747aa91133SSebastian Grimberg             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v};
2757aa91133SSebastian Grimberg 
276cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args));
2777aa91133SSebastian Grimberg           }
2787aa91133SSebastian Grimberg         } else {
279cf8cbdd6SSebastian Grimberg           if (!is_deterministic) {
280a267acd1SJeremy L Thompson             void *args[] = {&impl->d_offsets, &d_u, &d_v};
281dce49693SSebastian Grimberg 
282cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args));
283dce49693SSebastian Grimberg           } else {
284dce49693SSebastian Grimberg             void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
285dce49693SSebastian Grimberg 
286cf8cbdd6SSebastian Grimberg             CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args));
287dce49693SSebastian Grimberg           }
288dce49693SSebastian Grimberg         }
289dce49693SSebastian Grimberg       } break;
2900d0321e0SJeremy L Thompson     }
2910d0321e0SJeremy L Thompson   }
2920d0321e0SJeremy L Thompson 
2932b730f8bSJeremy L Thompson   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
2940d0321e0SJeremy L Thompson 
2950d0321e0SJeremy L Thompson   // Restore arrays
2962b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
2972b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
2980d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2990d0321e0SJeremy L Thompson }
3000d0321e0SJeremy L Thompson 
3010d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
302dce49693SSebastian Grimberg // Apply restriction
303dce49693SSebastian Grimberg //------------------------------------------------------------------------------
304dce49693SSebastian Grimberg static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
305dce49693SSebastian Grimberg   return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request);
306dce49693SSebastian Grimberg }
307dce49693SSebastian Grimberg 
308dce49693SSebastian Grimberg //------------------------------------------------------------------------------
309dce49693SSebastian Grimberg // Apply unsigned restriction
310dce49693SSebastian Grimberg //------------------------------------------------------------------------------
311dce49693SSebastian Grimberg static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v,
312dce49693SSebastian Grimberg                                                 CeedRequest *request) {
313dce49693SSebastian Grimberg   return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request);
314dce49693SSebastian Grimberg }
315dce49693SSebastian Grimberg 
316dce49693SSebastian Grimberg //------------------------------------------------------------------------------
317dce49693SSebastian Grimberg // Apply unoriented restriction
318dce49693SSebastian Grimberg //------------------------------------------------------------------------------
319dce49693SSebastian Grimberg static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v,
320dce49693SSebastian Grimberg                                                   CeedRequest *request) {
321dce49693SSebastian Grimberg   return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request);
322dce49693SSebastian Grimberg }
323dce49693SSebastian Grimberg 
324dce49693SSebastian Grimberg //------------------------------------------------------------------------------
3250d0321e0SJeremy L Thompson // Get offsets
3260d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
327472941f0SJeremy L Thompson static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
3280d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
329*fe960054SJeremy L Thompson   CeedRestrictionType      rstr_type;
3300d0321e0SJeremy L Thompson 
331b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
332*fe960054SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
333472941f0SJeremy L Thompson   switch (mem_type) {
3340d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
335*fe960054SJeremy L Thompson       *offsets = rstr_type == CEED_RESTRICTION_POINTS ? impl->h_offsets_at_points : impl->h_offsets;
3360d0321e0SJeremy L Thompson       break;
3370d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
338*fe960054SJeremy L Thompson       *offsets = rstr_type == CEED_RESTRICTION_POINTS ? impl->d_offsets_at_points : impl->d_offsets;
3390d0321e0SJeremy L Thompson       break;
3400d0321e0SJeremy L Thompson   }
3410d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3420d0321e0SJeremy L Thompson }
3430d0321e0SJeremy L Thompson 
3440d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
345dce49693SSebastian Grimberg // Get orientations
346dce49693SSebastian Grimberg //------------------------------------------------------------------------------
347dce49693SSebastian Grimberg static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) {
348dce49693SSebastian Grimberg   CeedElemRestriction_Hip *impl;
349dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
350dce49693SSebastian Grimberg 
351dce49693SSebastian Grimberg   switch (mem_type) {
352dce49693SSebastian Grimberg     case CEED_MEM_HOST:
353dce49693SSebastian Grimberg       *orients = impl->h_orients;
354dce49693SSebastian Grimberg       break;
355dce49693SSebastian Grimberg     case CEED_MEM_DEVICE:
356dce49693SSebastian Grimberg       *orients = impl->d_orients;
357dce49693SSebastian Grimberg       break;
358dce49693SSebastian Grimberg   }
359dce49693SSebastian Grimberg   return CEED_ERROR_SUCCESS;
360dce49693SSebastian Grimberg }
361dce49693SSebastian Grimberg 
362dce49693SSebastian Grimberg //------------------------------------------------------------------------------
363dce49693SSebastian Grimberg // Get curl-conforming orientations
364dce49693SSebastian Grimberg //------------------------------------------------------------------------------
365dce49693SSebastian Grimberg static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) {
366dce49693SSebastian Grimberg   CeedElemRestriction_Hip *impl;
367dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
368dce49693SSebastian Grimberg 
369dce49693SSebastian Grimberg   switch (mem_type) {
370dce49693SSebastian Grimberg     case CEED_MEM_HOST:
371dce49693SSebastian Grimberg       *curl_orients = impl->h_curl_orients;
372dce49693SSebastian Grimberg       break;
373dce49693SSebastian Grimberg     case CEED_MEM_DEVICE:
374dce49693SSebastian Grimberg       *curl_orients = impl->d_curl_orients;
375dce49693SSebastian Grimberg       break;
376dce49693SSebastian Grimberg   }
377dce49693SSebastian Grimberg   return CEED_ERROR_SUCCESS;
378dce49693SSebastian Grimberg }
379dce49693SSebastian Grimberg 
380dce49693SSebastian Grimberg //------------------------------------------------------------------------------
381*fe960054SJeremy L Thompson // Get offset for padded AtPoints E-layout
382*fe960054SJeremy L Thompson //------------------------------------------------------------------------------
383*fe960054SJeremy L Thompson static int CeedElemRestrictionGetAtPointsElementOffset_Hip(CeedElemRestriction rstr, CeedInt elem, CeedSize *elem_offset) {
384*fe960054SJeremy L Thompson   CeedInt layout[3];
385*fe960054SJeremy L Thompson 
386*fe960054SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetELayout(rstr, layout));
387*fe960054SJeremy L Thompson   *elem_offset = 0 * layout[0] + 0 * layout[1] + elem * layout[2];
388*fe960054SJeremy L Thompson   return CEED_ERROR_SUCCESS;
389*fe960054SJeremy L Thompson }
390*fe960054SJeremy L Thompson 
391*fe960054SJeremy L Thompson //------------------------------------------------------------------------------
3920d0321e0SJeremy L Thompson // Destroy restriction
3930d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
394dce49693SSebastian Grimberg static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) {
3950d0321e0SJeremy L Thompson   Ceed                     ceed;
396b7453713SJeremy L Thompson   CeedElemRestriction_Hip *impl;
397b7453713SJeremy L Thompson 
398dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
399dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
400cf8cbdd6SSebastian Grimberg   if (impl->module) {
4012b730f8bSJeremy L Thompson     CeedCallHip(ceed, hipModuleUnload(impl->module));
402cf8cbdd6SSebastian Grimberg   }
403a267acd1SJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_offsets_owned));
404f5d1e504SJeremy L Thompson   CeedCallHip(ceed, hipFree((CeedInt *)impl->d_offsets_owned));
405081aa29dSJeremy L Thompson   CeedCallHip(ceed, hipFree((CeedInt *)impl->d_t_offsets));
406081aa29dSJeremy L Thompson   CeedCallHip(ceed, hipFree((CeedInt *)impl->d_t_indices));
407081aa29dSJeremy L Thompson   CeedCallHip(ceed, hipFree((CeedInt *)impl->d_l_vec_indices));
408a267acd1SJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_orients_owned));
409f5d1e504SJeremy L Thompson   CeedCallHip(ceed, hipFree((bool *)impl->d_orients_owned));
410a267acd1SJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_curl_orients_owned));
411f5d1e504SJeremy L Thompson   CeedCallHip(ceed, hipFree((CeedInt8 *)impl->d_curl_orients_owned));
412*fe960054SJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_offsets_at_points_owned));
413*fe960054SJeremy L Thompson   CeedCallHip(ceed, hipFree((CeedInt8 *)impl->d_offsets_at_points_owned));
4142b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl));
4150d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
4160d0321e0SJeremy L Thompson }
4170d0321e0SJeremy L Thompson 
4180d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4190d0321e0SJeremy L Thompson // Create transpose offsets and indices
4200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
421*fe960054SJeremy L Thompson static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt elem_size, const CeedInt *indices) {
4220d0321e0SJeremy L Thompson   Ceed                     ceed;
423b7453713SJeremy L Thompson   bool                    *is_node;
424e79b91d9SJeremy L Thompson   CeedSize                 l_size;
425*fe960054SJeremy L Thompson   CeedInt                  num_elem, num_comp, num_nodes = 0;
426dce49693SSebastian Grimberg   CeedInt                 *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices;
427*fe960054SJeremy L Thompson   CeedRestrictionType      rstr_type;
428b7453713SJeremy L Thompson   CeedElemRestriction_Hip *impl;
429b7453713SJeremy L Thompson 
430dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
431dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
432dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
433*fe960054SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
434dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size));
435dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp));
436b7453713SJeremy L Thompson   const CeedInt size_indices = num_elem * elem_size;
4370d0321e0SJeremy L Thompson 
438437930d1SJeremy L Thompson   // Count num_nodes
4392b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(l_size, &is_node));
440dce49693SSebastian Grimberg 
4412b730f8bSJeremy L Thompson   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
4422b730f8bSJeremy L Thompson   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
443437930d1SJeremy L Thompson   impl->num_nodes = num_nodes;
4440d0321e0SJeremy L Thompson 
4450d0321e0SJeremy L Thompson   // L-vector offsets array
4462b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
4472b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
448b7453713SJeremy L Thompson   for (CeedInt i = 0, j = 0; i < l_size; i++) {
449437930d1SJeremy L Thompson     if (is_node[i]) {
450437930d1SJeremy L Thompson       l_vec_indices[j] = i;
4510d0321e0SJeremy L Thompson       ind_to_offset[i] = j++;
4520d0321e0SJeremy L Thompson     }
4532b730f8bSJeremy L Thompson   }
4542b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&is_node));
4550d0321e0SJeremy L Thompson 
4560d0321e0SJeremy L Thompson   // Compute transpose offsets and indices
457437930d1SJeremy L Thompson   const CeedInt size_offsets = num_nodes + 1;
458b7453713SJeremy L Thompson 
4592b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
4602b730f8bSJeremy L Thompson   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
4610d0321e0SJeremy L Thompson   // Count node multiplicity
4622b730f8bSJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e) {
4632b730f8bSJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
4642b730f8bSJeremy L Thompson   }
4650d0321e0SJeremy L Thompson   // Convert to running sum
4662b730f8bSJeremy L Thompson   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
4670d0321e0SJeremy L Thompson   // List all E-vec indices associated with L-vec node
468437930d1SJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e) {
469437930d1SJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i) {
470437930d1SJeremy L Thompson       const CeedInt lid = elem_size * e + i;
4710d0321e0SJeremy L Thompson       const CeedInt gid = indices[lid];
472b7453713SJeremy L Thompson 
473437930d1SJeremy L Thompson       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
4740d0321e0SJeremy L Thompson     }
4750d0321e0SJeremy L Thompson   }
4760d0321e0SJeremy L Thompson   // Reset running sum
4772b730f8bSJeremy L Thompson   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
478437930d1SJeremy L Thompson   t_offsets[0] = 0;
4790d0321e0SJeremy L Thompson 
4800d0321e0SJeremy L Thompson   // Copy data to device
4810d0321e0SJeremy L Thompson   // -- L-vector indices
4822b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
483081aa29dSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
4840d0321e0SJeremy L Thompson   // -- Transpose offsets
4852b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
486081aa29dSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
4870d0321e0SJeremy L Thompson   // -- Transpose indices
4882b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
489081aa29dSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
4900d0321e0SJeremy L Thompson 
4910d0321e0SJeremy L Thompson   // Cleanup
4922b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&ind_to_offset));
4932b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&l_vec_indices));
4942b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&t_offsets));
4952b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&t_indices));
4960d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
4970d0321e0SJeremy L Thompson }
4980d0321e0SJeremy L Thompson 
4990d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
5000d0321e0SJeremy L Thompson // Create restriction
5010d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
502a267acd1SJeremy L Thompson int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients,
503dce49693SSebastian Grimberg                                   const CeedInt8 *curl_orients, CeedElemRestriction rstr) {
504b7453713SJeremy L Thompson   Ceed                     ceed, ceed_parent;
505dce49693SSebastian Grimberg   bool                     is_deterministic;
506cf8cbdd6SSebastian Grimberg   CeedInt                  num_elem, elem_size;
507b7453713SJeremy L Thompson   CeedRestrictionType      rstr_type;
5080d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
509b7453713SJeremy L Thompson 
510dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed));
511ca735530SJeremy L Thompson   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
512ca735530SJeremy L Thompson   CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic));
513dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem));
514dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
51522eb1385SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type));
516*fe960054SJeremy L Thompson   // Use max number of points as elem size for AtPoints restrictions
517*fe960054SJeremy L Thompson   if (rstr_type == CEED_RESTRICTION_POINTS) {
518*fe960054SJeremy L Thompson     CeedInt max_points = 0;
519*fe960054SJeremy L Thompson 
520*fe960054SJeremy L Thompson     for (CeedInt i = 0; i < num_elem; i++) {
521*fe960054SJeremy L Thompson       max_points = CeedIntMax(max_points, offsets[i + 1] - offsets[i]);
522*fe960054SJeremy L Thompson     }
523*fe960054SJeremy L Thompson     elem_size = max_points;
524*fe960054SJeremy L Thompson   }
525dce49693SSebastian Grimberg   const CeedInt size = num_elem * elem_size;
5260d0321e0SJeremy L Thompson 
527dce49693SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
528dce49693SSebastian Grimberg   impl->num_nodes = size;
529dce49693SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionSetData(rstr, impl));
53022eb1385SJeremy L Thompson 
53122eb1385SJeremy L Thompson   // Set layouts
53222eb1385SJeremy L Thompson   {
53322eb1385SJeremy L Thompson     bool    has_backend_strides;
53422eb1385SJeremy L Thompson     CeedInt layout[3] = {1, size, elem_size};
53522eb1385SJeremy L Thompson 
536dce49693SSebastian Grimberg     CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout));
53722eb1385SJeremy L Thompson     if (rstr_type == CEED_RESTRICTION_STRIDED) {
53822eb1385SJeremy L Thompson       CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides));
53922eb1385SJeremy L Thompson       if (has_backend_strides) {
54022eb1385SJeremy L Thompson         CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout));
54122eb1385SJeremy L Thompson       }
54222eb1385SJeremy L Thompson     }
54322eb1385SJeremy L Thompson   }
5440d0321e0SJeremy L Thompson 
545*fe960054SJeremy L Thompson   // Pad AtPoints indices
546*fe960054SJeremy L Thompson   if (rstr_type == CEED_RESTRICTION_POINTS) {
547*fe960054SJeremy L Thompson     CeedSize offsets_len = elem_size * num_elem, at_points_size = num_elem + 1;
548*fe960054SJeremy L Thompson     CeedInt  max_points = elem_size, *offsets_padded;
549*fe960054SJeremy L Thompson 
550*fe960054SJeremy L Thompson     CeedCheck(mem_type == CEED_MEM_HOST, ceed, CEED_ERROR_BACKEND, "only MemType Host supported when creating AtPoints restriction");
551*fe960054SJeremy L Thompson     CeedCallBackend(CeedMalloc(offsets_len, &offsets_padded));
552*fe960054SJeremy L Thompson     for (CeedInt i = 0; i < num_elem; i++) {
553*fe960054SJeremy L Thompson       CeedInt num_points = offsets[i + 1] - offsets[i];
554*fe960054SJeremy L Thompson 
555*fe960054SJeremy L Thompson       at_points_size += num_points;
556*fe960054SJeremy L Thompson       // -- Copy all points in element
557*fe960054SJeremy L Thompson       for (CeedInt j = 0; j < num_points; j++) {
558*fe960054SJeremy L Thompson         offsets_padded[i * max_points + j] = offsets[offsets[i] + j];
559*fe960054SJeremy L Thompson       }
560*fe960054SJeremy L Thompson       // -- Replicate out last point in element
561*fe960054SJeremy L Thompson       for (CeedInt j = num_points; j < max_points; j++) {
562*fe960054SJeremy L Thompson         offsets_padded[i * max_points + j] = offsets[offsets[i] + num_points - 1];
563*fe960054SJeremy L Thompson       }
564*fe960054SJeremy L Thompson     }
565*fe960054SJeremy L Thompson     CeedCallBackend(CeedSetHostCeedIntArray(offsets, copy_mode, at_points_size, &impl->h_offsets_at_points_owned, &impl->h_offsets_at_points_borrowed,
566*fe960054SJeremy L Thompson                                             &impl->h_offsets_at_points));
567*fe960054SJeremy L Thompson     CeedCallHip(ceed, hipMalloc((void **)&impl->d_offsets_at_points_owned, at_points_size * sizeof(CeedInt)));
568*fe960054SJeremy L Thompson     CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->d_offsets_at_points_owned, impl->h_offsets_at_points, at_points_size * sizeof(CeedInt),
569*fe960054SJeremy L Thompson                                 hipMemcpyHostToDevice));
570*fe960054SJeremy L Thompson     impl->d_offsets_at_points = (CeedInt *)impl->d_offsets_at_points_owned;
571*fe960054SJeremy L Thompson     // -- Use padded offsets for the rest of the setup
572*fe960054SJeremy L Thompson     offsets   = (const CeedInt *)offsets_padded;
573*fe960054SJeremy L Thompson     copy_mode = CEED_OWN_POINTER;
574*fe960054SJeremy L Thompson   }
575*fe960054SJeremy L Thompson 
576dce49693SSebastian Grimberg   // Set up device offset/orientation arrays
577dce49693SSebastian Grimberg   if (rstr_type != CEED_RESTRICTION_STRIDED) {
578472941f0SJeremy L Thompson     switch (mem_type) {
5796574a04fSJeremy L Thompson       case CEED_MEM_HOST: {
580f5d1e504SJeremy L Thompson         CeedCallBackend(CeedSetHostCeedIntArray(offsets, copy_mode, size, &impl->h_offsets_owned, &impl->h_offsets_borrowed, &impl->h_offsets));
581a267acd1SJeremy L Thompson         CeedCallHip(ceed, hipMalloc((void **)&impl->d_offsets_owned, size * sizeof(CeedInt)));
582f5d1e504SJeremy L Thompson         CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->d_offsets_owned, impl->h_offsets, size * sizeof(CeedInt), hipMemcpyHostToDevice));
583f5d1e504SJeremy L Thompson         impl->d_offsets = (CeedInt *)impl->d_offsets_owned;
584*fe960054SJeremy L Thompson         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, elem_size, offsets));
585dce49693SSebastian Grimberg       } break;
5866574a04fSJeremy L Thompson       case CEED_MEM_DEVICE: {
587f5d1e504SJeremy L Thompson         CeedCallBackend(CeedSetDeviceCeedIntArray_Hip(ceed, offsets, copy_mode, size, &impl->d_offsets_owned, &impl->d_offsets_borrowed,
588f5d1e504SJeremy L Thompson                                                       (const CeedInt **)&impl->d_offsets));
589a267acd1SJeremy L Thompson         CeedCallBackend(CeedMalloc(size, &impl->h_offsets_owned));
590f5d1e504SJeremy L Thompson         CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->h_offsets_owned, impl->d_offsets, size * sizeof(CeedInt), hipMemcpyDeviceToHost));
591a267acd1SJeremy L Thompson         impl->h_offsets = impl->h_offsets_owned;
592*fe960054SJeremy L Thompson         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, elem_size, offsets));
593dce49693SSebastian Grimberg       } break;
594dce49693SSebastian Grimberg     }
595dce49693SSebastian Grimberg 
596dce49693SSebastian Grimberg     // Orientation data
597dce49693SSebastian Grimberg     if (rstr_type == CEED_RESTRICTION_ORIENTED) {
598dce49693SSebastian Grimberg       switch (mem_type) {
599dce49693SSebastian Grimberg         case CEED_MEM_HOST: {
600f5d1e504SJeremy L Thompson           CeedCallBackend(CeedSetHostBoolArray(orients, copy_mode, size, &impl->h_orients_owned, &impl->h_orients_borrowed, &impl->h_orients));
601a267acd1SJeremy L Thompson           CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients_owned, size * sizeof(bool)));
602f5d1e504SJeremy L Thompson           CeedCallHip(ceed, hipMemcpy((bool *)impl->d_orients_owned, impl->h_orients, size * sizeof(bool), hipMemcpyHostToDevice));
603a267acd1SJeremy L Thompson           impl->d_orients = impl->d_orients_owned;
604dce49693SSebastian Grimberg         } break;
605dce49693SSebastian Grimberg         case CEED_MEM_DEVICE: {
606f5d1e504SJeremy L Thompson           CeedCallBackend(CeedSetDeviceBoolArray_Hip(ceed, orients, copy_mode, size, &impl->d_orients_owned, &impl->d_orients_borrowed,
607f5d1e504SJeremy L Thompson                                                      (const bool **)&impl->d_orients));
608a267acd1SJeremy L Thompson           CeedCallBackend(CeedMalloc(size, &impl->h_orients_owned));
609f5d1e504SJeremy L Thompson           CeedCallHip(ceed, hipMemcpy((bool *)impl->h_orients_owned, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost));
610a267acd1SJeremy L Thompson           impl->h_orients = impl->h_orients_owned;
611dce49693SSebastian Grimberg         } break;
612dce49693SSebastian Grimberg       }
613dce49693SSebastian Grimberg     } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) {
614dce49693SSebastian Grimberg       switch (mem_type) {
615dce49693SSebastian Grimberg         case CEED_MEM_HOST: {
616f5d1e504SJeremy L Thompson           CeedCallBackend(CeedSetHostCeedInt8Array(curl_orients, copy_mode, 3 * size, &impl->h_curl_orients_owned, &impl->h_curl_orients_borrowed,
617f5d1e504SJeremy L Thompson                                                    &impl->h_curl_orients));
618a267acd1SJeremy L Thompson           CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients_owned, 3 * size * sizeof(CeedInt8)));
619f5d1e504SJeremy L Thompson           CeedCallHip(ceed,
620f5d1e504SJeremy L Thompson                       hipMemcpy((CeedInt8 *)impl->d_curl_orients_owned, impl->h_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice));
621a267acd1SJeremy L Thompson           impl->d_curl_orients = impl->d_curl_orients_owned;
622dce49693SSebastian Grimberg         } break;
623dce49693SSebastian Grimberg         case CEED_MEM_DEVICE: {
624f5d1e504SJeremy L Thompson           CeedCallBackend(CeedSetDeviceCeedInt8Array_Hip(ceed, curl_orients, copy_mode, 3 * size, &impl->d_curl_orients_owned,
625f5d1e504SJeremy L Thompson                                                          &impl->d_curl_orients_borrowed, (const CeedInt8 **)&impl->d_curl_orients));
626a267acd1SJeremy L Thompson           CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_owned));
627f5d1e504SJeremy L Thompson           CeedCallHip(ceed,
628f5d1e504SJeremy L Thompson                       hipMemcpy((CeedInt8 *)impl->h_curl_orients_owned, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost));
629a267acd1SJeremy L Thompson           impl->h_curl_orients = impl->h_curl_orients_owned;
630dce49693SSebastian Grimberg         } break;
631dce49693SSebastian Grimberg       }
632dce49693SSebastian Grimberg     }
6330d0321e0SJeremy L Thompson   }
6340d0321e0SJeremy L Thompson 
6350d0321e0SJeremy L Thompson   // Register backend functions
636dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip));
637dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip));
638dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip));
639dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
640dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip));
641dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip));
642*fe960054SJeremy L Thompson   if (rstr_type == CEED_RESTRICTION_POINTS) {
643*fe960054SJeremy L Thompson     CeedCallBackend(
644*fe960054SJeremy L Thompson         CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetAtPointsElementOffset", CeedElemRestrictionGetAtPointsElementOffset_Hip));
645*fe960054SJeremy L Thompson   }
646dce49693SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip));
6470d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
6480d0321e0SJeremy L Thompson }
6490d0321e0SJeremy L Thompson 
6500d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
651