xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-restriction.c (revision b7453713e95c1c6eb59ce174cbcb87227e92884e)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
849aac155SJeremy L Thompson #include <ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
10437930d1SJeremy L Thompson #include <ceed/jit-tools.h>
110d0321e0SJeremy L Thompson #include <stdbool.h>
120d0321e0SJeremy L Thompson #include <stddef.h>
1344d7a66cSJeremy L Thompson #include <string.h>
14c85e8640SSebastian Grimberg #include <hip/hip_runtime.h>
152b730f8bSJeremy L Thompson 
1649aac155SJeremy L Thompson #include "../hip/ceed-hip-common.h"
170d0321e0SJeremy L Thompson #include "../hip/ceed-hip-compile.h"
182b730f8bSJeremy L Thompson #include "ceed-hip-ref.h"
190d0321e0SJeremy L Thompson 
200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
210d0321e0SJeremy L Thompson // Apply restriction
220d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
232b730f8bSJeremy L Thompson static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
240d0321e0SJeremy L Thompson   Ceed                     ceed;
250d0321e0SJeremy L Thompson   Ceed_Hip                *data;
26437930d1SJeremy L Thompson   CeedInt                  num_elem, elem_size;
270d0321e0SJeremy L Thompson   const CeedScalar        *d_u;
280d0321e0SJeremy L Thompson   CeedScalar              *d_v;
29*b7453713SJeremy L Thompson   CeedElemRestriction_Hip *impl;
30*b7453713SJeremy L Thompson   hipFunction_t            kernel;
31*b7453713SJeremy L Thompson 
32*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
33*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
34*b7453713SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
35*b7453713SJeremy L Thompson   CeedElemRestrictionGetNumElements(r, &num_elem);
36*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
37*b7453713SJeremy L Thompson   const CeedInt num_nodes = impl->num_nodes;
38*b7453713SJeremy L Thompson 
39*b7453713SJeremy L Thompson   // Get vectors
402b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
41437930d1SJeremy L Thompson   if (t_mode == CEED_TRANSPOSE) {
420d0321e0SJeremy L Thompson     // Sum into for transpose mode, e-vec to l-vec
432b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
440d0321e0SJeremy L Thompson   } else {
450d0321e0SJeremy L Thompson     // Overwrite for notranspose mode, l-vec to e-vec
462b730f8bSJeremy L Thompson     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
470d0321e0SJeremy L Thompson   }
480d0321e0SJeremy L Thompson 
490d0321e0SJeremy L Thompson   // Restrict
50437930d1SJeremy L Thompson   if (t_mode == CEED_NOTRANSPOSE) {
510d0321e0SJeremy L Thompson     // L-vector -> E-vector
520d0321e0SJeremy L Thompson     if (impl->d_ind) {
530d0321e0SJeremy L Thompson       // -- Offsets provided
54437930d1SJeremy L Thompson       kernel             = impl->OffsetNoTranspose;
55437930d1SJeremy L Thompson       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
56437930d1SJeremy L Thompson       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
5758549094SSebastian Grimberg 
58eb7e6cafSJeremy L Thompson       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
590d0321e0SJeremy L Thompson     } else {
600d0321e0SJeremy L Thompson       // -- Strided restriction
61437930d1SJeremy L Thompson       kernel             = impl->StridedNoTranspose;
62437930d1SJeremy L Thompson       void   *args[]     = {&num_elem, &d_u, &d_v};
63437930d1SJeremy L Thompson       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
6458549094SSebastian Grimberg 
65eb7e6cafSJeremy L Thompson       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
660d0321e0SJeremy L Thompson     }
670d0321e0SJeremy L Thompson   } else {
680d0321e0SJeremy L Thompson     // E-vector -> L-vector
690d0321e0SJeremy L Thompson     if (impl->d_ind) {
700d0321e0SJeremy L Thompson       // -- Offsets provided
7158549094SSebastian Grimberg       CeedInt block_size = 64;
72*b7453713SJeremy L Thompson 
7358549094SSebastian Grimberg       if (impl->OffsetTranspose) {
74437930d1SJeremy L Thompson         kernel       = impl->OffsetTranspose;
7558549094SSebastian Grimberg         void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
7658549094SSebastian Grimberg 
77eb7e6cafSJeremy L Thompson         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
780d0321e0SJeremy L Thompson       } else {
7958549094SSebastian Grimberg         kernel       = impl->OffsetTransposeDet;
8058549094SSebastian Grimberg         void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
8158549094SSebastian Grimberg 
8258549094SSebastian Grimberg         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
8358549094SSebastian Grimberg       }
8458549094SSebastian Grimberg     } else {
850d0321e0SJeremy L Thompson       // -- Strided restriction
86437930d1SJeremy L Thompson       kernel             = impl->StridedTranspose;
87437930d1SJeremy L Thompson       void   *args[]     = {&num_elem, &d_u, &d_v};
8858549094SSebastian Grimberg       CeedInt block_size = 64;
8958549094SSebastian Grimberg 
90eb7e6cafSJeremy L Thompson       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
910d0321e0SJeremy L Thompson     }
920d0321e0SJeremy L Thompson   }
930d0321e0SJeremy L Thompson 
942b730f8bSJeremy L Thompson   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
950d0321e0SJeremy L Thompson 
960d0321e0SJeremy L Thompson   // Restore arrays
972b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
982b730f8bSJeremy L Thompson   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
990d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1000d0321e0SJeremy L Thompson }
1010d0321e0SJeremy L Thompson 
1020d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1030d0321e0SJeremy L Thompson // Get offsets
1040d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
105472941f0SJeremy L Thompson static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
1060d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1070d0321e0SJeremy L Thompson 
108*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
109472941f0SJeremy L Thompson   switch (mem_type) {
1100d0321e0SJeremy L Thompson     case CEED_MEM_HOST:
1110d0321e0SJeremy L Thompson       *offsets = impl->h_ind;
1120d0321e0SJeremy L Thompson       break;
1130d0321e0SJeremy L Thompson     case CEED_MEM_DEVICE:
1140d0321e0SJeremy L Thompson       *offsets = impl->d_ind;
1150d0321e0SJeremy L Thompson       break;
1160d0321e0SJeremy L Thompson   }
1170d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1180d0321e0SJeremy L Thompson }
1190d0321e0SJeremy L Thompson 
1200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1210d0321e0SJeremy L Thompson // Destroy restriction
1220d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1230d0321e0SJeremy L Thompson static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
1240d0321e0SJeremy L Thompson   Ceed                     ceed;
125*b7453713SJeremy L Thompson   CeedElemRestriction_Hip *impl;
126*b7453713SJeremy L Thompson 
127*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
1282b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
1292b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipModuleUnload(impl->module));
1302b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
1312b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipFree(impl->d_ind_allocated));
1322b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipFree(impl->d_t_offsets));
1332b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipFree(impl->d_t_indices));
1342b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipFree(impl->d_l_vec_indices));
1352b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&impl));
1360d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1370d0321e0SJeremy L Thompson }
1380d0321e0SJeremy L Thompson 
1390d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1400d0321e0SJeremy L Thompson // Create transpose offsets and indices
1410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1422b730f8bSJeremy L Thompson static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) {
1430d0321e0SJeremy L Thompson   Ceed                     ceed;
144*b7453713SJeremy L Thompson   bool                    *is_node;
145e79b91d9SJeremy L Thompson   CeedSize                 l_size;
146*b7453713SJeremy L Thompson   CeedInt                  num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices;
147*b7453713SJeremy L Thompson   CeedElemRestriction_Hip *impl;
148*b7453713SJeremy L Thompson 
149*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
150*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
1512b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
1522b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
1532b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
1542b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
155*b7453713SJeremy L Thompson   const CeedInt size_indices = num_elem * elem_size;
1560d0321e0SJeremy L Thompson 
157437930d1SJeremy L Thompson   // Count num_nodes
1582b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(l_size, &is_node));
1592b730f8bSJeremy L Thompson   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
1602b730f8bSJeremy L Thompson   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
161437930d1SJeremy L Thompson   impl->num_nodes = num_nodes;
1620d0321e0SJeremy L Thompson 
1630d0321e0SJeremy L Thompson   // L-vector offsets array
1642b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
1652b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
166*b7453713SJeremy L Thompson   for (CeedInt i = 0, j = 0; i < l_size; i++) {
167437930d1SJeremy L Thompson     if (is_node[i]) {
168437930d1SJeremy L Thompson       l_vec_indices[j] = i;
1690d0321e0SJeremy L Thompson       ind_to_offset[i] = j++;
1700d0321e0SJeremy L Thompson     }
1712b730f8bSJeremy L Thompson   }
1722b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&is_node));
1730d0321e0SJeremy L Thompson 
1740d0321e0SJeremy L Thompson   // Compute transpose offsets and indices
175437930d1SJeremy L Thompson   const CeedInt size_offsets = num_nodes + 1;
176*b7453713SJeremy L Thompson 
1772b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
1782b730f8bSJeremy L Thompson   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
1790d0321e0SJeremy L Thompson   // Count node multiplicity
1802b730f8bSJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e) {
1812b730f8bSJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
1822b730f8bSJeremy L Thompson   }
1830d0321e0SJeremy L Thompson   // Convert to running sum
1842b730f8bSJeremy L Thompson   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
1850d0321e0SJeremy L Thompson   // List all E-vec indices associated with L-vec node
186437930d1SJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e) {
187437930d1SJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i) {
188437930d1SJeremy L Thompson       const CeedInt lid = elem_size * e + i;
1890d0321e0SJeremy L Thompson       const CeedInt gid = indices[lid];
190*b7453713SJeremy L Thompson 
191437930d1SJeremy L Thompson       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
1920d0321e0SJeremy L Thompson     }
1930d0321e0SJeremy L Thompson   }
1940d0321e0SJeremy L Thompson   // Reset running sum
1952b730f8bSJeremy L Thompson   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
196437930d1SJeremy L Thompson   t_offsets[0] = 0;
1970d0321e0SJeremy L Thompson 
1980d0321e0SJeremy L Thompson   // Copy data to device
1990d0321e0SJeremy L Thompson   // -- L-vector indices
2002b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
2012b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
2020d0321e0SJeremy L Thompson   // -- Transpose offsets
2032b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
2042b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
2050d0321e0SJeremy L Thompson   // -- Transpose indices
2062b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
2072b730f8bSJeremy L Thompson   CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
2080d0321e0SJeremy L Thompson 
2090d0321e0SJeremy L Thompson   // Cleanup
2102b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&ind_to_offset));
2112b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&l_vec_indices));
2122b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&t_offsets));
2132b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&t_indices));
2140d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2150d0321e0SJeremy L Thompson }
2160d0321e0SJeremy L Thompson 
2170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2180d0321e0SJeremy L Thompson // Create restriction
2190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
220fcbe8c06SSebastian Grimberg int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
2210c73c039SSebastian Grimberg                                   const CeedInt8 *curl_orients, CeedElemRestriction r) {
222*b7453713SJeremy L Thompson   Ceed                     ceed, ceed_parent;
223*b7453713SJeremy L Thompson   bool                     is_deterministic, is_strided;
224*b7453713SJeremy L Thompson   char                    *restriction_kernel_path, *restriction_kernel_source;
225*b7453713SJeremy L Thompson   CeedInt                  num_elem, num_comp, elem_size, comp_stride = 1;
226*b7453713SJeremy L Thompson   CeedRestrictionType      rstr_type;
2270d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
228*b7453713SJeremy L Thompson 
229*b7453713SJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
2302b730f8bSJeremy L Thompson   CeedCallBackend(CeedCalloc(1, &impl));
231ca735530SJeremy L Thompson   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
232ca735530SJeremy L Thompson   CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic));
2332b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
2342b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
2352b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
236437930d1SJeremy L Thompson   CeedInt size       = num_elem * elem_size;
237437930d1SJeremy L Thompson   CeedInt strides[3] = {1, size, elem_size};
238*b7453713SJeremy L Thompson   CeedInt layout[3]  = {1, elem_size * num_elem, elem_size};
2390d0321e0SJeremy L Thompson 
24000125730SSebastian Grimberg   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
24100125730SSebastian Grimberg   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
24200125730SSebastian Grimberg             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
24300125730SSebastian Grimberg 
2440d0321e0SJeremy L Thompson   // Stride data
2452b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
246437930d1SJeremy L Thompson   if (is_strided) {
247437930d1SJeremy L Thompson     bool has_backend_strides;
248*b7453713SJeremy L Thompson 
2492b730f8bSJeremy L Thompson     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
250437930d1SJeremy L Thompson     if (!has_backend_strides) {
2512b730f8bSJeremy L Thompson       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
2520d0321e0SJeremy L Thompson     }
2530d0321e0SJeremy L Thompson   } else {
2542b730f8bSJeremy L Thompson     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
2550d0321e0SJeremy L Thompson   }
2560d0321e0SJeremy L Thompson 
2570d0321e0SJeremy L Thompson   impl->h_ind           = NULL;
2580d0321e0SJeremy L Thompson   impl->h_ind_allocated = NULL;
2590d0321e0SJeremy L Thompson   impl->d_ind           = NULL;
2600d0321e0SJeremy L Thompson   impl->d_ind_allocated = NULL;
261437930d1SJeremy L Thompson   impl->d_t_indices     = NULL;
262437930d1SJeremy L Thompson   impl->d_t_offsets     = NULL;
263437930d1SJeremy L Thompson   impl->num_nodes       = size;
2642b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
2652b730f8bSJeremy L Thompson   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
2660d0321e0SJeremy L Thompson 
2670d0321e0SJeremy L Thompson   // Set up device indices/offset arrays
268472941f0SJeremy L Thompson   switch (mem_type) {
2696574a04fSJeremy L Thompson     case CEED_MEM_HOST: {
270472941f0SJeremy L Thompson       switch (copy_mode) {
2710d0321e0SJeremy L Thompson         case CEED_OWN_POINTER:
2720d0321e0SJeremy L Thompson           impl->h_ind_allocated = (CeedInt *)indices;
2730d0321e0SJeremy L Thompson           impl->h_ind           = (CeedInt *)indices;
2740d0321e0SJeremy L Thompson           break;
2750d0321e0SJeremy L Thompson         case CEED_USE_POINTER:
2760d0321e0SJeremy L Thompson           impl->h_ind = (CeedInt *)indices;
2770d0321e0SJeremy L Thompson           break;
2780d0321e0SJeremy L Thompson         case CEED_COPY_VALUES:
27944d7a66cSJeremy L Thompson           if (indices != NULL) {
2802b730f8bSJeremy L Thompson             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
28144d7a66cSJeremy L Thompson             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
28244d7a66cSJeremy L Thompson             impl->h_ind = impl->h_ind_allocated;
28344d7a66cSJeremy L Thompson           }
2840d0321e0SJeremy L Thompson           break;
2850d0321e0SJeremy L Thompson       }
2860d0321e0SJeremy L Thompson       if (indices != NULL) {
2872b730f8bSJeremy L Thompson         CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
2880d0321e0SJeremy L Thompson         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
2892b730f8bSJeremy L Thompson         CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice));
29058549094SSebastian Grimberg         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
2910d0321e0SJeremy L Thompson       }
2926574a04fSJeremy L Thompson       break;
2936574a04fSJeremy L Thompson     }
2946574a04fSJeremy L Thompson     case CEED_MEM_DEVICE: {
295472941f0SJeremy L Thompson       switch (copy_mode) {
2960d0321e0SJeremy L Thompson         case CEED_COPY_VALUES:
2970d0321e0SJeremy L Thompson           if (indices != NULL) {
2982b730f8bSJeremy L Thompson             CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
2990d0321e0SJeremy L Thompson             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
3002b730f8bSJeremy L Thompson             CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice));
3010d0321e0SJeremy L Thompson           }
3020d0321e0SJeremy L Thompson           break;
3030d0321e0SJeremy L Thompson         case CEED_OWN_POINTER:
3040d0321e0SJeremy L Thompson           impl->d_ind           = (CeedInt *)indices;
3050d0321e0SJeremy L Thompson           impl->d_ind_allocated = impl->d_ind;
3060d0321e0SJeremy L Thompson           break;
3070d0321e0SJeremy L Thompson         case CEED_USE_POINTER:
3080d0321e0SJeremy L Thompson           impl->d_ind = (CeedInt *)indices;
3090d0321e0SJeremy L Thompson       }
3100d0321e0SJeremy L Thompson       if (indices != NULL) {
3112b730f8bSJeremy L Thompson         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
3122b730f8bSJeremy L Thompson         CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost));
31344d7a66cSJeremy L Thompson         impl->h_ind = impl->h_ind_allocated;
31458549094SSebastian Grimberg         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
3150d0321e0SJeremy L Thompson       }
3166574a04fSJeremy L Thompson       break;
3176574a04fSJeremy L Thompson     }
3180d0321e0SJeremy L Thompson     // LCOV_EXCL_START
3196574a04fSJeremy L Thompson     default:
3202b730f8bSJeremy L Thompson       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
3210d0321e0SJeremy L Thompson       // LCOV_EXCL_STOP
3220d0321e0SJeremy L Thompson   }
3230d0321e0SJeremy L Thompson 
3240d0321e0SJeremy L Thompson   // Compile HIP kernels
325437930d1SJeremy L Thompson   CeedInt num_nodes = impl->num_nodes;
326*b7453713SJeremy L Thompson 
3272b730f8bSJeremy L Thompson   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path));
32823d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
3292b730f8bSJeremy L Thompson   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
33023d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
331eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
3322b730f8bSJeremy L Thompson                                   "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
3332b730f8bSJeremy L Thompson                                   strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
334eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
335eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
33658549094SSebastian Grimberg   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
33758549094SSebastian Grimberg   if (!is_deterministic) {
338eb7e6cafSJeremy L Thompson     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
33958549094SSebastian Grimberg   } else {
34058549094SSebastian Grimberg     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet));
34158549094SSebastian Grimberg   }
3422b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&restriction_kernel_path));
3432b730f8bSJeremy L Thompson   CeedCallBackend(CeedFree(&restriction_kernel_source));
3440d0321e0SJeremy L Thompson 
3450d0321e0SJeremy L Thompson   // Register backend functions
3462b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip));
347b17517eeSSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip));
3487c1dbaffSSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip));
3492b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
3502b730f8bSJeremy L Thompson   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip));
3510d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3520d0321e0SJeremy L Thompson }
3530d0321e0SJeremy L Thompson 
3540d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
355