xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-restriction.c (revision 44d7a66c6073d76a4a1f8dc0795479283da37ec1)
13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
63d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
80d0321e0SJeremy L Thompson #include <ceed/ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
10437930d1SJeremy L Thompson #include <ceed/jit-tools.h>
110d0321e0SJeremy L Thompson #include <hip/hip_runtime.h>
120d0321e0SJeremy L Thompson #include <stdbool.h>
130d0321e0SJeremy L Thompson #include <stddef.h>
14*44d7a66cSJeremy L Thompson #include <string.h>
150d0321e0SJeremy L Thompson #include "ceed-hip-ref.h"
160d0321e0SJeremy L Thompson #include "../hip/ceed-hip-compile.h"
170d0321e0SJeremy L Thompson 
180d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
190d0321e0SJeremy L Thompson // Apply restriction
200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
210d0321e0SJeremy L Thompson static int CeedElemRestrictionApply_Hip(CeedElemRestriction r,
2246dc0734SJeremy L Thompson                                         CeedTransposeMode t_mode, CeedVector u,
2346dc0734SJeremy L Thompson                                         CeedVector v, CeedRequest *request) {
240d0321e0SJeremy L Thompson   int ierr;
250d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
260d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
270d0321e0SJeremy L Thompson   Ceed ceed;
280d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
290d0321e0SJeremy L Thompson   Ceed_Hip *data;
300d0321e0SJeremy L Thompson   ierr = CeedGetData(ceed, &data); CeedChkBackend(ierr);
31437930d1SJeremy L Thompson   const CeedInt block_size = 64;
32437930d1SJeremy L Thompson   const CeedInt num_nodes = impl->num_nodes;
33437930d1SJeremy L Thompson   CeedInt num_elem, elem_size;
34437930d1SJeremy L Thompson   CeedElemRestrictionGetNumElements(r, &num_elem);
35437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
360d0321e0SJeremy L Thompson   hipFunction_t kernel;
370d0321e0SJeremy L Thompson 
380d0321e0SJeremy L Thompson   // Get vectors
390d0321e0SJeremy L Thompson   const CeedScalar *d_u;
400d0321e0SJeremy L Thompson   CeedScalar *d_v;
410d0321e0SJeremy L Thompson   ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
42437930d1SJeremy L Thompson   if (t_mode == CEED_TRANSPOSE) {
430d0321e0SJeremy L Thompson     // Sum into for transpose mode, e-vec to l-vec
440d0321e0SJeremy L Thompson     ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
450d0321e0SJeremy L Thompson   } else {
460d0321e0SJeremy L Thompson     // Overwrite for notranspose mode, l-vec to e-vec
470d0321e0SJeremy L Thompson     ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
480d0321e0SJeremy L Thompson   }
490d0321e0SJeremy L Thompson 
500d0321e0SJeremy L Thompson   // Restrict
51437930d1SJeremy L Thompson   if (t_mode == CEED_NOTRANSPOSE) {
520d0321e0SJeremy L Thompson     // L-vector -> E-vector
530d0321e0SJeremy L Thompson     if (impl->d_ind) {
540d0321e0SJeremy L Thompson       // -- Offsets provided
55437930d1SJeremy L Thompson       kernel = impl->OffsetNoTranspose;
56437930d1SJeremy L Thompson       void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
57437930d1SJeremy L Thompson       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
58437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
59437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
600d0321e0SJeremy L Thompson     } else {
610d0321e0SJeremy L Thompson       // -- Strided restriction
62437930d1SJeremy L Thompson       kernel = impl->StridedNoTranspose;
63437930d1SJeremy L Thompson       void *args[] = {&num_elem, &d_u, &d_v};
64437930d1SJeremy L Thompson       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
65437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
66437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
670d0321e0SJeremy L Thompson     }
680d0321e0SJeremy L Thompson   } else {
690d0321e0SJeremy L Thompson     // E-vector -> L-vector
700d0321e0SJeremy L Thompson     if (impl->d_ind) {
710d0321e0SJeremy L Thompson       // -- Offsets provided
72437930d1SJeremy L Thompson       kernel = impl->OffsetTranspose;
73437930d1SJeremy L Thompson       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices,
74437930d1SJeremy L Thompson                       &impl->d_t_offsets, &d_u, &d_v
750d0321e0SJeremy L Thompson                      };
76437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
77437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
780d0321e0SJeremy L Thompson     } else {
790d0321e0SJeremy L Thompson       // -- Strided restriction
80437930d1SJeremy L Thompson       kernel = impl->StridedTranspose;
81437930d1SJeremy L Thompson       void *args[] = {&num_elem, &d_u, &d_v};
82437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
83437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
840d0321e0SJeremy L Thompson     }
850d0321e0SJeremy L Thompson   }
860d0321e0SJeremy L Thompson 
870d0321e0SJeremy L Thompson   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED)
880d0321e0SJeremy L Thompson     *request = NULL;
890d0321e0SJeremy L Thompson 
900d0321e0SJeremy L Thompson   // Restore arrays
910d0321e0SJeremy L Thompson   ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
920d0321e0SJeremy L Thompson   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
930d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
940d0321e0SJeremy L Thompson }
950d0321e0SJeremy L Thompson 
960d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
970d0321e0SJeremy L Thompson // Blocked not supported
980d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
990d0321e0SJeremy L Thompson int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block,
100437930d1SJeremy L Thompson                                       CeedTransposeMode t_mode, CeedVector u,
1010d0321e0SJeremy L Thompson                                       CeedVector v, CeedRequest *request) {
1020d0321e0SJeremy L Thompson   // LCOV_EXCL_START
1030d0321e0SJeremy L Thompson   int ierr;
1040d0321e0SJeremy L Thompson   Ceed ceed;
1050d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
1060d0321e0SJeremy L Thompson   return CeedError(ceed, CEED_ERROR_BACKEND,
1070d0321e0SJeremy L Thompson                    "Backend does not implement blocked restrictions");
1080d0321e0SJeremy L Thompson   // LCOV_EXCL_STOP
1090d0321e0SJeremy L Thompson }
1100d0321e0SJeremy L Thompson 
1110d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1120d0321e0SJeremy L Thompson // Get offsets
1130d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1140d0321e0SJeremy L Thompson static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr,
1150d0321e0SJeremy L Thompson     CeedMemType mtype, const CeedInt **offsets) {
1160d0321e0SJeremy L Thompson   int ierr;
1170d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1180d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(rstr, &impl); CeedChkBackend(ierr);
1190d0321e0SJeremy L Thompson 
1200d0321e0SJeremy L Thompson   switch (mtype) {
1210d0321e0SJeremy L Thompson   case CEED_MEM_HOST:
1220d0321e0SJeremy L Thompson     *offsets = impl->h_ind;
1230d0321e0SJeremy L Thompson     break;
1240d0321e0SJeremy L Thompson   case CEED_MEM_DEVICE:
1250d0321e0SJeremy L Thompson     *offsets = impl->d_ind;
1260d0321e0SJeremy L Thompson     break;
1270d0321e0SJeremy L Thompson   }
1280d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1290d0321e0SJeremy L Thompson }
1300d0321e0SJeremy L Thompson 
1310d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1320d0321e0SJeremy L Thompson // Destroy restriction
1330d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1340d0321e0SJeremy L Thompson static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
1350d0321e0SJeremy L Thompson   int ierr;
1360d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1370d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
1380d0321e0SJeremy L Thompson 
1390d0321e0SJeremy L Thompson   Ceed ceed;
1400d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
1410d0321e0SJeremy L Thompson   ierr = hipModuleUnload(impl->module); CeedChk_Hip(ceed, ierr);
1420d0321e0SJeremy L Thompson   ierr = CeedFree(&impl->h_ind_allocated); CeedChkBackend(ierr);
1430d0321e0SJeremy L Thompson   ierr = hipFree(impl->d_ind_allocated); CeedChk_Hip(ceed, ierr);
144437930d1SJeremy L Thompson   ierr = hipFree(impl->d_t_offsets); CeedChk_Hip(ceed, ierr);
145437930d1SJeremy L Thompson   ierr = hipFree(impl->d_t_indices); CeedChk_Hip(ceed, ierr);
146437930d1SJeremy L Thompson   ierr = hipFree(impl->d_l_vec_indices); CeedChk_Hip(ceed, ierr);
1470d0321e0SJeremy L Thompson   ierr = CeedFree(&impl); CeedChkBackend(ierr);
148437930d1SJeremy L Thompson 
1490d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1500d0321e0SJeremy L Thompson }
1510d0321e0SJeremy L Thompson 
1520d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1530d0321e0SJeremy L Thompson // Create transpose offsets and indices
1540d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1550d0321e0SJeremy L Thompson static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r,
1560d0321e0SJeremy L Thompson     const CeedInt *indices) {
1570d0321e0SJeremy L Thompson   int ierr;
1580d0321e0SJeremy L Thompson   Ceed ceed;
1590d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
1600d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1610d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
162e79b91d9SJeremy L Thompson   CeedSize l_size;
163e79b91d9SJeremy L Thompson   CeedInt num_elem, elem_size, num_comp;
164437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
165437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
166437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr);
167437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
1680d0321e0SJeremy L Thompson 
169437930d1SJeremy L Thompson   // Count num_nodes
170437930d1SJeremy L Thompson   bool *is_node;
171437930d1SJeremy L Thompson   ierr = CeedCalloc(l_size, &is_node); CeedChkBackend(ierr);
172437930d1SJeremy L Thompson   const CeedInt size_indices = num_elem * elem_size;
173437930d1SJeremy L Thompson   for (CeedInt i = 0; i < size_indices; i++)
174437930d1SJeremy L Thompson     is_node[indices[i]] = 1;
175437930d1SJeremy L Thompson   CeedInt num_nodes = 0;
176437930d1SJeremy L Thompson   for (CeedInt i = 0; i < l_size; i++)
177437930d1SJeremy L Thompson     num_nodes += is_node[i];
178437930d1SJeremy L Thompson   impl->num_nodes = num_nodes;
1790d0321e0SJeremy L Thompson 
1800d0321e0SJeremy L Thompson   // L-vector offsets array
181437930d1SJeremy L Thompson   CeedInt *ind_to_offset, *l_vec_indices;
182437930d1SJeremy L Thompson   ierr = CeedCalloc(l_size, &ind_to_offset); CeedChkBackend(ierr);
183437930d1SJeremy L Thompson   ierr = CeedCalloc(num_nodes, &l_vec_indices); CeedChkBackend(ierr);
1840d0321e0SJeremy L Thompson   CeedInt j = 0;
185437930d1SJeremy L Thompson   for (CeedInt i = 0; i < l_size; i++)
186437930d1SJeremy L Thompson     if (is_node[i]) {
187437930d1SJeremy L Thompson       l_vec_indices[j] = i;
1880d0321e0SJeremy L Thompson       ind_to_offset[i] = j++;
1890d0321e0SJeremy L Thompson     }
190437930d1SJeremy L Thompson   ierr = CeedFree(&is_node); CeedChkBackend(ierr);
1910d0321e0SJeremy L Thompson 
1920d0321e0SJeremy L Thompson   // Compute transpose offsets and indices
193437930d1SJeremy L Thompson   const CeedInt size_offsets = num_nodes + 1;
194437930d1SJeremy L Thompson   CeedInt *t_offsets;
195437930d1SJeremy L Thompson   ierr = CeedCalloc(size_offsets, &t_offsets); CeedChkBackend(ierr);
196437930d1SJeremy L Thompson   CeedInt *t_indices;
197437930d1SJeremy L Thompson   ierr = CeedMalloc(size_indices, &t_indices); CeedChkBackend(ierr);
1980d0321e0SJeremy L Thompson   // Count node multiplicity
199437930d1SJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e)
200437930d1SJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i)
201437930d1SJeremy L Thompson       ++t_offsets[ind_to_offset[indices[elem_size*e + i]] + 1];
2020d0321e0SJeremy L Thompson   // Convert to running sum
203437930d1SJeremy L Thompson   for (CeedInt i = 1; i < size_offsets; ++i)
204437930d1SJeremy L Thompson     t_offsets[i] += t_offsets[i-1];
2050d0321e0SJeremy L Thompson   // List all E-vec indices associated with L-vec node
206437930d1SJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e) {
207437930d1SJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i) {
208437930d1SJeremy L Thompson       const CeedInt lid = elem_size*e + i;
2090d0321e0SJeremy L Thompson       const CeedInt gid = indices[lid];
210437930d1SJeremy L Thompson       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
2110d0321e0SJeremy L Thompson     }
2120d0321e0SJeremy L Thompson   }
2130d0321e0SJeremy L Thompson   // Reset running sum
214437930d1SJeremy L Thompson   for (int i = size_offsets - 1; i > 0; --i)
215437930d1SJeremy L Thompson     t_offsets[i] = t_offsets[i - 1];
216437930d1SJeremy L Thompson   t_offsets[0] = 0;
2170d0321e0SJeremy L Thompson 
2180d0321e0SJeremy L Thompson   // Copy data to device
2190d0321e0SJeremy L Thompson   // -- L-vector indices
220437930d1SJeremy L Thompson   ierr = hipMalloc((void **)&impl->d_l_vec_indices, num_nodes*sizeof(CeedInt));
2210d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
222437930d1SJeremy L Thompson   ierr = hipMemcpy(impl->d_l_vec_indices, l_vec_indices,
223437930d1SJeremy L Thompson                    num_nodes*sizeof(CeedInt), hipMemcpyHostToDevice);
2240d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
2250d0321e0SJeremy L Thompson   // -- Transpose offsets
226437930d1SJeremy L Thompson   ierr = hipMalloc((void **)&impl->d_t_offsets, size_offsets*sizeof(CeedInt));
2270d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
228437930d1SJeremy L Thompson   ierr = hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets*sizeof(CeedInt),
2290d0321e0SJeremy L Thompson                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
2300d0321e0SJeremy L Thompson   // -- Transpose indices
231437930d1SJeremy L Thompson   ierr = hipMalloc((void **)&impl->d_t_indices, size_indices*sizeof(CeedInt));
2320d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
233437930d1SJeremy L Thompson   ierr = hipMemcpy(impl->d_t_indices, t_indices, size_indices*sizeof(CeedInt),
2340d0321e0SJeremy L Thompson                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
2350d0321e0SJeremy L Thompson 
2360d0321e0SJeremy L Thompson   // Cleanup
2370d0321e0SJeremy L Thompson   ierr = CeedFree(&ind_to_offset); CeedChkBackend(ierr);
238437930d1SJeremy L Thompson   ierr = CeedFree(&l_vec_indices); CeedChkBackend(ierr);
239437930d1SJeremy L Thompson   ierr = CeedFree(&t_offsets); CeedChkBackend(ierr);
240437930d1SJeremy L Thompson   ierr = CeedFree(&t_indices); CeedChkBackend(ierr);
241437930d1SJeremy L Thompson 
2420d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2430d0321e0SJeremy L Thompson }
2440d0321e0SJeremy L Thompson 
2450d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2460d0321e0SJeremy L Thompson // Create restriction
2470d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2480d0321e0SJeremy L Thompson int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode,
2490d0321e0SJeremy L Thompson                                   const CeedInt *indices,
2500d0321e0SJeremy L Thompson                                   CeedElemRestriction r) {
2510d0321e0SJeremy L Thompson   int ierr;
2520d0321e0SJeremy L Thompson   Ceed ceed;
2530d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
2540d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
2550d0321e0SJeremy L Thompson   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
256437930d1SJeremy L Thompson   CeedInt num_elem, num_comp, elem_size;
257437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
258437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
259437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
260437930d1SJeremy L Thompson   CeedInt size = num_elem * elem_size;
261437930d1SJeremy L Thompson   CeedInt strides[3] = {1, size, elem_size};
262437930d1SJeremy L Thompson   CeedInt comp_stride = 1;
2630d0321e0SJeremy L Thompson 
2640d0321e0SJeremy L Thompson   // Stride data
265437930d1SJeremy L Thompson   bool is_strided;
266437930d1SJeremy L Thompson   ierr = CeedElemRestrictionIsStrided(r, &is_strided); CeedChkBackend(ierr);
267437930d1SJeremy L Thompson   if (is_strided) {
268437930d1SJeremy L Thompson     bool has_backend_strides;
269437930d1SJeremy L Thompson     ierr = CeedElemRestrictionHasBackendStrides(r, &has_backend_strides);
2700d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
271437930d1SJeremy L Thompson     if (!has_backend_strides) {
2720d0321e0SJeremy L Thompson       ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr);
2730d0321e0SJeremy L Thompson     }
2740d0321e0SJeremy L Thompson   } else {
275437930d1SJeremy L Thompson     ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr);
2760d0321e0SJeremy L Thompson   }
2770d0321e0SJeremy L Thompson 
2780d0321e0SJeremy L Thompson   impl->h_ind           = NULL;
2790d0321e0SJeremy L Thompson   impl->h_ind_allocated = NULL;
2800d0321e0SJeremy L Thompson   impl->d_ind           = NULL;
2810d0321e0SJeremy L Thompson   impl->d_ind_allocated = NULL;
282437930d1SJeremy L Thompson   impl->d_t_indices     = NULL;
283437930d1SJeremy L Thompson   impl->d_t_offsets     = NULL;
284437930d1SJeremy L Thompson   impl->num_nodes = size;
2850d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionSetData(r, impl); CeedChkBackend(ierr);
286437930d1SJeremy L Thompson   CeedInt layout[3] = {1, elem_size*num_elem, elem_size};
2870d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionSetELayout(r, layout); CeedChkBackend(ierr);
2880d0321e0SJeremy L Thompson 
2890d0321e0SJeremy L Thompson   // Set up device indices/offset arrays
2900d0321e0SJeremy L Thompson   if (mtype == CEED_MEM_HOST) {
2910d0321e0SJeremy L Thompson     switch (cmode) {
2920d0321e0SJeremy L Thompson     case CEED_OWN_POINTER:
2930d0321e0SJeremy L Thompson       impl->h_ind_allocated = (CeedInt *)indices;
2940d0321e0SJeremy L Thompson       impl->h_ind = (CeedInt *)indices;
2950d0321e0SJeremy L Thompson       break;
2960d0321e0SJeremy L Thompson     case CEED_USE_POINTER:
2970d0321e0SJeremy L Thompson       impl->h_ind = (CeedInt *)indices;
2980d0321e0SJeremy L Thompson       break;
2990d0321e0SJeremy L Thompson     case CEED_COPY_VALUES:
300*44d7a66cSJeremy L Thompson       if (indices != NULL) {
301*44d7a66cSJeremy L Thompson         ierr = CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated);
302*44d7a66cSJeremy L Thompson         CeedChkBackend(ierr);
303*44d7a66cSJeremy L Thompson         memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
304*44d7a66cSJeremy L Thompson         impl->h_ind = impl->h_ind_allocated;
305*44d7a66cSJeremy L Thompson       }
3060d0321e0SJeremy L Thompson       break;
3070d0321e0SJeremy L Thompson     }
3080d0321e0SJeremy L Thompson     if (indices != NULL) {
3090d0321e0SJeremy L Thompson       ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
3100d0321e0SJeremy L Thompson       CeedChk_Hip(ceed, ierr);
3110d0321e0SJeremy L Thompson       impl->d_ind_allocated = impl->d_ind; // We own the device memory
3120d0321e0SJeremy L Thompson       ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
3130d0321e0SJeremy L Thompson                        hipMemcpyHostToDevice);
3140d0321e0SJeremy L Thompson       CeedChk_Hip(ceed, ierr);
3150d0321e0SJeremy L Thompson       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
3160d0321e0SJeremy L Thompson     }
3170d0321e0SJeremy L Thompson   } else if (mtype == CEED_MEM_DEVICE) {
3180d0321e0SJeremy L Thompson     switch (cmode) {
3190d0321e0SJeremy L Thompson     case CEED_COPY_VALUES:
3200d0321e0SJeremy L Thompson       if (indices != NULL) {
3210d0321e0SJeremy L Thompson         ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
3220d0321e0SJeremy L Thompson         CeedChk_Hip(ceed, ierr);
3230d0321e0SJeremy L Thompson         impl->d_ind_allocated = impl->d_ind; // We own the device memory
3240d0321e0SJeremy L Thompson         ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
3250d0321e0SJeremy L Thompson                          hipMemcpyDeviceToDevice);
3260d0321e0SJeremy L Thompson         CeedChk_Hip(ceed, ierr);
3270d0321e0SJeremy L Thompson       }
3280d0321e0SJeremy L Thompson       break;
3290d0321e0SJeremy L Thompson     case CEED_OWN_POINTER:
3300d0321e0SJeremy L Thompson       impl->d_ind = (CeedInt *)indices;
3310d0321e0SJeremy L Thompson       impl->d_ind_allocated = impl->d_ind;
3320d0321e0SJeremy L Thompson       break;
3330d0321e0SJeremy L Thompson     case CEED_USE_POINTER:
3340d0321e0SJeremy L Thompson       impl->d_ind = (CeedInt *)indices;
3350d0321e0SJeremy L Thompson     }
3360d0321e0SJeremy L Thompson     if (indices != NULL) {
337*44d7a66cSJeremy L Thompson       ierr = CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated);
338*44d7a66cSJeremy L Thompson       CeedChkBackend(ierr);
339*44d7a66cSJeremy L Thompson       ierr = hipMemcpy(impl->h_ind_allocated, impl->d_ind,
340*44d7a66cSJeremy L Thompson                        elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost);
341*44d7a66cSJeremy L Thompson       CeedChk_Hip(ceed, ierr);
342*44d7a66cSJeremy L Thompson       impl->h_ind = impl->h_ind_allocated;
3430d0321e0SJeremy L Thompson       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
3440d0321e0SJeremy L Thompson     }
3450d0321e0SJeremy L Thompson   } else {
3460d0321e0SJeremy L Thompson     // LCOV_EXCL_START
3470d0321e0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND,
3480d0321e0SJeremy L Thompson                      "Only MemType = HOST or DEVICE supported");
3490d0321e0SJeremy L Thompson     // LCOV_EXCL_STOP
3500d0321e0SJeremy L Thompson   }
3510d0321e0SJeremy L Thompson 
3520d0321e0SJeremy L Thompson   // Compile HIP kernels
353437930d1SJeremy L Thompson   CeedInt num_nodes = impl->num_nodes;
354437930d1SJeremy L Thompson   char *restriction_kernel_path, *restriction_kernel_source;
355ee5a26f2SJeremy L Thompson   ierr = CeedGetJitAbsolutePath(ceed,
356a0154adeSJed Brown                                 "ceed/jit-source/hip/hip-ref-restriction.h",
357437930d1SJeremy L Thompson                                 &restriction_kernel_path); CeedChkBackend(ierr);
35846dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n");
359437930d1SJeremy L Thompson   ierr = CeedLoadSourceToBuffer(ceed, restriction_kernel_path,
360437930d1SJeremy L Thompson                                 &restriction_kernel_source);
361437930d1SJeremy L Thompson   CeedChkBackend(ierr);
36246dc0734SJeremy L Thompson   CeedDebug256(ceed, 2,
36346dc0734SJeremy L Thompson                "----- Loading Restriction Kernel Source Complete! -----\n");
364437930d1SJeremy L Thompson   ierr = CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8,
365d7d111ecSJeremy L Thompson                         "RESTR_ELEM_SIZE", elem_size,
366d7d111ecSJeremy L Thompson                         "RESTR_NUM_ELEM", num_elem,
367d7d111ecSJeremy L Thompson                         "RESTR_NUM_COMP", num_comp,
368d7d111ecSJeremy L Thompson                         "RESTR_NUM_NODES", num_nodes,
369d7d111ecSJeremy L Thompson                         "RESTR_COMP_STRIDE", comp_stride,
370d7d111ecSJeremy L Thompson                         "RESTR_STRIDE_NODES", strides[0],
371d7d111ecSJeremy L Thompson                         "RESTR_STRIDE_COMP", strides[1],
372d7d111ecSJeremy L Thompson                         "RESTR_STRIDE_ELEM", strides[2]); CeedChkBackend(ierr);
373437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose",
374437930d1SJeremy L Thompson                           &impl->StridedNoTranspose); CeedChkBackend(ierr);
375437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose",
376437930d1SJeremy L Thompson                           &impl->OffsetNoTranspose); CeedChkBackend(ierr);
377437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "StridedTranspose",
378437930d1SJeremy L Thompson                           &impl->StridedTranspose); CeedChkBackend(ierr);
379437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetTranspose",
380437930d1SJeremy L Thompson                           &impl->OffsetTranspose); CeedChkBackend(ierr);
381437930d1SJeremy L Thompson   ierr = CeedFree(&restriction_kernel_path); CeedChkBackend(ierr);
382437930d1SJeremy L Thompson   ierr = CeedFree(&restriction_kernel_source); CeedChkBackend(ierr);
3830d0321e0SJeremy L Thompson 
3840d0321e0SJeremy L Thompson   // Register backend functions
3850d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply",
3860d0321e0SJeremy L Thompson                                 CeedElemRestrictionApply_Hip);
3870d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3880d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock",
3890d0321e0SJeremy L Thompson                                 CeedElemRestrictionApplyBlock_Hip);
3900d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3910d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets",
3920d0321e0SJeremy L Thompson                                 CeedElemRestrictionGetOffsets_Hip);
3930d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3940d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy",
3950d0321e0SJeremy L Thompson                                 CeedElemRestrictionDestroy_Hip);
3960d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3970d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3980d0321e0SJeremy L Thompson }
3990d0321e0SJeremy L Thompson 
4000d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4010d0321e0SJeremy L Thompson // Blocked not supported
4020d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
4030d0321e0SJeremy L Thompson int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype,
4040d0321e0SJeremy L Thompson     const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
4050d0321e0SJeremy L Thompson   int ierr;
4060d0321e0SJeremy L Thompson   Ceed ceed;
4070d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
4080d0321e0SJeremy L Thompson   return CeedError(ceed, CEED_ERROR_BACKEND,
4090d0321e0SJeremy L Thompson                    "Backend does not implement blocked restrictions");
4100d0321e0SJeremy L Thompson }
4110d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
412