xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-restriction.c (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1*3d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*3d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
30d0321e0SJeremy L Thompson //
4*3d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
50d0321e0SJeremy L Thompson //
6*3d8e8822SJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
70d0321e0SJeremy L Thompson 
80d0321e0SJeremy L Thompson #include <ceed/ceed.h>
90d0321e0SJeremy L Thompson #include <ceed/backend.h>
10437930d1SJeremy L Thompson #include <ceed/jit-tools.h>
110d0321e0SJeremy L Thompson #include <hip/hip_runtime.h>
120d0321e0SJeremy L Thompson #include <stdbool.h>
130d0321e0SJeremy L Thompson #include <stddef.h>
140d0321e0SJeremy L Thompson #include "ceed-hip-ref.h"
150d0321e0SJeremy L Thompson #include "../hip/ceed-hip-compile.h"
160d0321e0SJeremy L Thompson 
170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
180d0321e0SJeremy L Thompson // Apply restriction
190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
200d0321e0SJeremy L Thompson static int CeedElemRestrictionApply_Hip(CeedElemRestriction r,
2146dc0734SJeremy L Thompson                                         CeedTransposeMode t_mode, CeedVector u,
2246dc0734SJeremy L Thompson                                         CeedVector v, CeedRequest *request) {
230d0321e0SJeremy L Thompson   int ierr;
240d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
250d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
260d0321e0SJeremy L Thompson   Ceed ceed;
270d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
280d0321e0SJeremy L Thompson   Ceed_Hip *data;
290d0321e0SJeremy L Thompson   ierr = CeedGetData(ceed, &data); CeedChkBackend(ierr);
30437930d1SJeremy L Thompson   const CeedInt block_size = 64;
31437930d1SJeremy L Thompson   const CeedInt num_nodes = impl->num_nodes;
32437930d1SJeremy L Thompson   CeedInt num_elem, elem_size;
33437930d1SJeremy L Thompson   CeedElemRestrictionGetNumElements(r, &num_elem);
34437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
350d0321e0SJeremy L Thompson   hipFunction_t kernel;
360d0321e0SJeremy L Thompson 
370d0321e0SJeremy L Thompson   // Get vectors
380d0321e0SJeremy L Thompson   const CeedScalar *d_u;
390d0321e0SJeremy L Thompson   CeedScalar *d_v;
400d0321e0SJeremy L Thompson   ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
41437930d1SJeremy L Thompson   if (t_mode == CEED_TRANSPOSE) {
420d0321e0SJeremy L Thompson     // Sum into for transpose mode, e-vec to l-vec
430d0321e0SJeremy L Thompson     ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
440d0321e0SJeremy L Thompson   } else {
450d0321e0SJeremy L Thompson     // Overwrite for notranspose mode, l-vec to e-vec
460d0321e0SJeremy L Thompson     ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
470d0321e0SJeremy L Thompson   }
480d0321e0SJeremy L Thompson 
490d0321e0SJeremy L Thompson   // Restrict
50437930d1SJeremy L Thompson   if (t_mode == CEED_NOTRANSPOSE) {
510d0321e0SJeremy L Thompson     // L-vector -> E-vector
520d0321e0SJeremy L Thompson     if (impl->d_ind) {
530d0321e0SJeremy L Thompson       // -- Offsets provided
54437930d1SJeremy L Thompson       kernel = impl->OffsetNoTranspose;
55437930d1SJeremy L Thompson       void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
56437930d1SJeremy L Thompson       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
57437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
58437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
590d0321e0SJeremy L Thompson     } else {
600d0321e0SJeremy L Thompson       // -- Strided restriction
61437930d1SJeremy L Thompson       kernel = impl->StridedNoTranspose;
62437930d1SJeremy L Thompson       void *args[] = {&num_elem, &d_u, &d_v};
63437930d1SJeremy L Thompson       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
64437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
65437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
660d0321e0SJeremy L Thompson     }
670d0321e0SJeremy L Thompson   } else {
680d0321e0SJeremy L Thompson     // E-vector -> L-vector
690d0321e0SJeremy L Thompson     if (impl->d_ind) {
700d0321e0SJeremy L Thompson       // -- Offsets provided
71437930d1SJeremy L Thompson       kernel = impl->OffsetTranspose;
72437930d1SJeremy L Thompson       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices,
73437930d1SJeremy L Thompson                       &impl->d_t_offsets, &d_u, &d_v
740d0321e0SJeremy L Thompson                      };
75437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
76437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
770d0321e0SJeremy L Thompson     } else {
780d0321e0SJeremy L Thompson       // -- Strided restriction
79437930d1SJeremy L Thompson       kernel = impl->StridedTranspose;
80437930d1SJeremy L Thompson       void *args[] = {&num_elem, &d_u, &d_v};
81437930d1SJeremy L Thompson       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
82437930d1SJeremy L Thompson                               block_size, args); CeedChkBackend(ierr);
830d0321e0SJeremy L Thompson     }
840d0321e0SJeremy L Thompson   }
850d0321e0SJeremy L Thompson 
860d0321e0SJeremy L Thompson   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED)
870d0321e0SJeremy L Thompson     *request = NULL;
880d0321e0SJeremy L Thompson 
890d0321e0SJeremy L Thompson   // Restore arrays
900d0321e0SJeremy L Thompson   ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
910d0321e0SJeremy L Thompson   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
920d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
930d0321e0SJeremy L Thompson }
940d0321e0SJeremy L Thompson 
950d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
960d0321e0SJeremy L Thompson // Blocked not supported
970d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
980d0321e0SJeremy L Thompson int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block,
99437930d1SJeremy L Thompson                                       CeedTransposeMode t_mode, CeedVector u,
1000d0321e0SJeremy L Thompson                                       CeedVector v, CeedRequest *request) {
1010d0321e0SJeremy L Thompson   // LCOV_EXCL_START
1020d0321e0SJeremy L Thompson   int ierr;
1030d0321e0SJeremy L Thompson   Ceed ceed;
1040d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
1050d0321e0SJeremy L Thompson   return CeedError(ceed, CEED_ERROR_BACKEND,
1060d0321e0SJeremy L Thompson                    "Backend does not implement blocked restrictions");
1070d0321e0SJeremy L Thompson   // LCOV_EXCL_STOP
1080d0321e0SJeremy L Thompson }
1090d0321e0SJeremy L Thompson 
1100d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1110d0321e0SJeremy L Thompson // Get offsets
1120d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1130d0321e0SJeremy L Thompson static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr,
1140d0321e0SJeremy L Thompson     CeedMemType mtype, const CeedInt **offsets) {
1150d0321e0SJeremy L Thompson   int ierr;
1160d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1170d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(rstr, &impl); CeedChkBackend(ierr);
1180d0321e0SJeremy L Thompson 
1190d0321e0SJeremy L Thompson   switch (mtype) {
1200d0321e0SJeremy L Thompson   case CEED_MEM_HOST:
1210d0321e0SJeremy L Thompson     *offsets = impl->h_ind;
1220d0321e0SJeremy L Thompson     break;
1230d0321e0SJeremy L Thompson   case CEED_MEM_DEVICE:
1240d0321e0SJeremy L Thompson     *offsets = impl->d_ind;
1250d0321e0SJeremy L Thompson     break;
1260d0321e0SJeremy L Thompson   }
1270d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1280d0321e0SJeremy L Thompson }
1290d0321e0SJeremy L Thompson 
1300d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1310d0321e0SJeremy L Thompson // Destroy restriction
1320d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1330d0321e0SJeremy L Thompson static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
1340d0321e0SJeremy L Thompson   int ierr;
1350d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1360d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
1370d0321e0SJeremy L Thompson 
1380d0321e0SJeremy L Thompson   Ceed ceed;
1390d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
1400d0321e0SJeremy L Thompson   ierr = hipModuleUnload(impl->module); CeedChk_Hip(ceed, ierr);
1410d0321e0SJeremy L Thompson   ierr = CeedFree(&impl->h_ind_allocated); CeedChkBackend(ierr);
1420d0321e0SJeremy L Thompson   ierr = hipFree(impl->d_ind_allocated); CeedChk_Hip(ceed, ierr);
143437930d1SJeremy L Thompson   ierr = hipFree(impl->d_t_offsets); CeedChk_Hip(ceed, ierr);
144437930d1SJeremy L Thompson   ierr = hipFree(impl->d_t_indices); CeedChk_Hip(ceed, ierr);
145437930d1SJeremy L Thompson   ierr = hipFree(impl->d_l_vec_indices); CeedChk_Hip(ceed, ierr);
1460d0321e0SJeremy L Thompson   ierr = CeedFree(&impl); CeedChkBackend(ierr);
147437930d1SJeremy L Thompson 
1480d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1490d0321e0SJeremy L Thompson }
1500d0321e0SJeremy L Thompson 
1510d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1520d0321e0SJeremy L Thompson // Create transpose offsets and indices
1530d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
1540d0321e0SJeremy L Thompson static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r,
1550d0321e0SJeremy L Thompson     const CeedInt *indices) {
1560d0321e0SJeremy L Thompson   int ierr;
1570d0321e0SJeremy L Thompson   Ceed ceed;
1580d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
1590d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
1600d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
161e79b91d9SJeremy L Thompson   CeedSize l_size;
162e79b91d9SJeremy L Thompson   CeedInt num_elem, elem_size, num_comp;
163437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
164437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
165437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr);
166437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
1670d0321e0SJeremy L Thompson 
168437930d1SJeremy L Thompson   // Count num_nodes
169437930d1SJeremy L Thompson   bool *is_node;
170437930d1SJeremy L Thompson   ierr = CeedCalloc(l_size, &is_node); CeedChkBackend(ierr);
171437930d1SJeremy L Thompson   const CeedInt size_indices = num_elem * elem_size;
172437930d1SJeremy L Thompson   for (CeedInt i = 0; i < size_indices; i++)
173437930d1SJeremy L Thompson     is_node[indices[i]] = 1;
174437930d1SJeremy L Thompson   CeedInt num_nodes = 0;
175437930d1SJeremy L Thompson   for (CeedInt i = 0; i < l_size; i++)
176437930d1SJeremy L Thompson     num_nodes += is_node[i];
177437930d1SJeremy L Thompson   impl->num_nodes = num_nodes;
1780d0321e0SJeremy L Thompson 
1790d0321e0SJeremy L Thompson   // L-vector offsets array
180437930d1SJeremy L Thompson   CeedInt *ind_to_offset, *l_vec_indices;
181437930d1SJeremy L Thompson   ierr = CeedCalloc(l_size, &ind_to_offset); CeedChkBackend(ierr);
182437930d1SJeremy L Thompson   ierr = CeedCalloc(num_nodes, &l_vec_indices); CeedChkBackend(ierr);
1830d0321e0SJeremy L Thompson   CeedInt j = 0;
184437930d1SJeremy L Thompson   for (CeedInt i = 0; i < l_size; i++)
185437930d1SJeremy L Thompson     if (is_node[i]) {
186437930d1SJeremy L Thompson       l_vec_indices[j] = i;
1870d0321e0SJeremy L Thompson       ind_to_offset[i] = j++;
1880d0321e0SJeremy L Thompson     }
189437930d1SJeremy L Thompson   ierr = CeedFree(&is_node); CeedChkBackend(ierr);
1900d0321e0SJeremy L Thompson 
1910d0321e0SJeremy L Thompson   // Compute transpose offsets and indices
192437930d1SJeremy L Thompson   const CeedInt size_offsets = num_nodes + 1;
193437930d1SJeremy L Thompson   CeedInt *t_offsets;
194437930d1SJeremy L Thompson   ierr = CeedCalloc(size_offsets, &t_offsets); CeedChkBackend(ierr);
195437930d1SJeremy L Thompson   CeedInt *t_indices;
196437930d1SJeremy L Thompson   ierr = CeedMalloc(size_indices, &t_indices); CeedChkBackend(ierr);
1970d0321e0SJeremy L Thompson   // Count node multiplicity
198437930d1SJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e)
199437930d1SJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i)
200437930d1SJeremy L Thompson       ++t_offsets[ind_to_offset[indices[elem_size*e + i]] + 1];
2010d0321e0SJeremy L Thompson   // Convert to running sum
202437930d1SJeremy L Thompson   for (CeedInt i = 1; i < size_offsets; ++i)
203437930d1SJeremy L Thompson     t_offsets[i] += t_offsets[i-1];
2040d0321e0SJeremy L Thompson   // List all E-vec indices associated with L-vec node
205437930d1SJeremy L Thompson   for (CeedInt e = 0; e < num_elem; ++e) {
206437930d1SJeremy L Thompson     for (CeedInt i = 0; i < elem_size; ++i) {
207437930d1SJeremy L Thompson       const CeedInt lid = elem_size*e + i;
2080d0321e0SJeremy L Thompson       const CeedInt gid = indices[lid];
209437930d1SJeremy L Thompson       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
2100d0321e0SJeremy L Thompson     }
2110d0321e0SJeremy L Thompson   }
2120d0321e0SJeremy L Thompson   // Reset running sum
213437930d1SJeremy L Thompson   for (int i = size_offsets - 1; i > 0; --i)
214437930d1SJeremy L Thompson     t_offsets[i] = t_offsets[i - 1];
215437930d1SJeremy L Thompson   t_offsets[0] = 0;
2160d0321e0SJeremy L Thompson 
2170d0321e0SJeremy L Thompson   // Copy data to device
2180d0321e0SJeremy L Thompson   // -- L-vector indices
219437930d1SJeremy L Thompson   ierr = hipMalloc((void **)&impl->d_l_vec_indices, num_nodes*sizeof(CeedInt));
2200d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
221437930d1SJeremy L Thompson   ierr = hipMemcpy(impl->d_l_vec_indices, l_vec_indices,
222437930d1SJeremy L Thompson                    num_nodes*sizeof(CeedInt), hipMemcpyHostToDevice);
2230d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
2240d0321e0SJeremy L Thompson   // -- Transpose offsets
225437930d1SJeremy L Thompson   ierr = hipMalloc((void **)&impl->d_t_offsets, size_offsets*sizeof(CeedInt));
2260d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
227437930d1SJeremy L Thompson   ierr = hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets*sizeof(CeedInt),
2280d0321e0SJeremy L Thompson                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
2290d0321e0SJeremy L Thompson   // -- Transpose indices
230437930d1SJeremy L Thompson   ierr = hipMalloc((void **)&impl->d_t_indices, size_indices*sizeof(CeedInt));
2310d0321e0SJeremy L Thompson   CeedChk_Hip(ceed, ierr);
232437930d1SJeremy L Thompson   ierr = hipMemcpy(impl->d_t_indices, t_indices, size_indices*sizeof(CeedInt),
2330d0321e0SJeremy L Thompson                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
2340d0321e0SJeremy L Thompson 
2350d0321e0SJeremy L Thompson   // Cleanup
2360d0321e0SJeremy L Thompson   ierr = CeedFree(&ind_to_offset); CeedChkBackend(ierr);
237437930d1SJeremy L Thompson   ierr = CeedFree(&l_vec_indices); CeedChkBackend(ierr);
238437930d1SJeremy L Thompson   ierr = CeedFree(&t_offsets); CeedChkBackend(ierr);
239437930d1SJeremy L Thompson   ierr = CeedFree(&t_indices); CeedChkBackend(ierr);
240437930d1SJeremy L Thompson 
2410d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
2420d0321e0SJeremy L Thompson }
2430d0321e0SJeremy L Thompson 
2440d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2450d0321e0SJeremy L Thompson // Create restriction
2460d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
2470d0321e0SJeremy L Thompson int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode,
2480d0321e0SJeremy L Thompson                                   const CeedInt *indices,
2490d0321e0SJeremy L Thompson                                   CeedElemRestriction r) {
2500d0321e0SJeremy L Thompson   int ierr;
2510d0321e0SJeremy L Thompson   Ceed ceed;
2520d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
2530d0321e0SJeremy L Thompson   CeedElemRestriction_Hip *impl;
2540d0321e0SJeremy L Thompson   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
255437930d1SJeremy L Thompson   CeedInt num_elem, num_comp, elem_size;
256437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
257437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
258437930d1SJeremy L Thompson   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
259437930d1SJeremy L Thompson   CeedInt size = num_elem * elem_size;
260437930d1SJeremy L Thompson   CeedInt strides[3] = {1, size, elem_size};
261437930d1SJeremy L Thompson   CeedInt comp_stride = 1;
2620d0321e0SJeremy L Thompson 
2630d0321e0SJeremy L Thompson   // Stride data
264437930d1SJeremy L Thompson   bool is_strided;
265437930d1SJeremy L Thompson   ierr = CeedElemRestrictionIsStrided(r, &is_strided); CeedChkBackend(ierr);
266437930d1SJeremy L Thompson   if (is_strided) {
267437930d1SJeremy L Thompson     bool has_backend_strides;
268437930d1SJeremy L Thompson     ierr = CeedElemRestrictionHasBackendStrides(r, &has_backend_strides);
2690d0321e0SJeremy L Thompson     CeedChkBackend(ierr);
270437930d1SJeremy L Thompson     if (!has_backend_strides) {
2710d0321e0SJeremy L Thompson       ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr);
2720d0321e0SJeremy L Thompson     }
2730d0321e0SJeremy L Thompson   } else {
274437930d1SJeremy L Thompson     ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr);
2750d0321e0SJeremy L Thompson   }
2760d0321e0SJeremy L Thompson 
2770d0321e0SJeremy L Thompson   impl->h_ind           = NULL;
2780d0321e0SJeremy L Thompson   impl->h_ind_allocated = NULL;
2790d0321e0SJeremy L Thompson   impl->d_ind           = NULL;
2800d0321e0SJeremy L Thompson   impl->d_ind_allocated = NULL;
281437930d1SJeremy L Thompson   impl->d_t_indices     = NULL;
282437930d1SJeremy L Thompson   impl->d_t_offsets     = NULL;
283437930d1SJeremy L Thompson   impl->num_nodes = size;
2840d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionSetData(r, impl); CeedChkBackend(ierr);
285437930d1SJeremy L Thompson   CeedInt layout[3] = {1, elem_size*num_elem, elem_size};
2860d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionSetELayout(r, layout); CeedChkBackend(ierr);
2870d0321e0SJeremy L Thompson 
2880d0321e0SJeremy L Thompson   // Set up device indices/offset arrays
2890d0321e0SJeremy L Thompson   if (mtype == CEED_MEM_HOST) {
2900d0321e0SJeremy L Thompson     switch (cmode) {
2910d0321e0SJeremy L Thompson     case CEED_OWN_POINTER:
2920d0321e0SJeremy L Thompson       impl->h_ind_allocated = (CeedInt *)indices;
2930d0321e0SJeremy L Thompson       impl->h_ind = (CeedInt *)indices;
2940d0321e0SJeremy L Thompson       break;
2950d0321e0SJeremy L Thompson     case CEED_USE_POINTER:
2960d0321e0SJeremy L Thompson       impl->h_ind = (CeedInt *)indices;
2970d0321e0SJeremy L Thompson       break;
2980d0321e0SJeremy L Thompson     case CEED_COPY_VALUES:
2990d0321e0SJeremy L Thompson       break;
3000d0321e0SJeremy L Thompson     }
3010d0321e0SJeremy L Thompson     if (indices != NULL) {
3020d0321e0SJeremy L Thompson       ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
3030d0321e0SJeremy L Thompson       CeedChk_Hip(ceed, ierr);
3040d0321e0SJeremy L Thompson       impl->d_ind_allocated = impl->d_ind; // We own the device memory
3050d0321e0SJeremy L Thompson       ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
3060d0321e0SJeremy L Thompson                        hipMemcpyHostToDevice);
3070d0321e0SJeremy L Thompson       CeedChk_Hip(ceed, ierr);
3080d0321e0SJeremy L Thompson       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
3090d0321e0SJeremy L Thompson     }
3100d0321e0SJeremy L Thompson   } else if (mtype == CEED_MEM_DEVICE) {
3110d0321e0SJeremy L Thompson     switch (cmode) {
3120d0321e0SJeremy L Thompson     case CEED_COPY_VALUES:
3130d0321e0SJeremy L Thompson       if (indices != NULL) {
3140d0321e0SJeremy L Thompson         ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
3150d0321e0SJeremy L Thompson         CeedChk_Hip(ceed, ierr);
3160d0321e0SJeremy L Thompson         impl->d_ind_allocated = impl->d_ind; // We own the device memory
3170d0321e0SJeremy L Thompson         ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
3180d0321e0SJeremy L Thompson                          hipMemcpyDeviceToDevice);
3190d0321e0SJeremy L Thompson         CeedChk_Hip(ceed, ierr);
3200d0321e0SJeremy L Thompson       }
3210d0321e0SJeremy L Thompson       break;
3220d0321e0SJeremy L Thompson     case CEED_OWN_POINTER:
3230d0321e0SJeremy L Thompson       impl->d_ind = (CeedInt *)indices;
3240d0321e0SJeremy L Thompson       impl->d_ind_allocated = impl->d_ind;
3250d0321e0SJeremy L Thompson       break;
3260d0321e0SJeremy L Thompson     case CEED_USE_POINTER:
3270d0321e0SJeremy L Thompson       impl->d_ind = (CeedInt *)indices;
3280d0321e0SJeremy L Thompson     }
3290d0321e0SJeremy L Thompson     if (indices != NULL) {
3300d0321e0SJeremy L Thompson       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
3310d0321e0SJeremy L Thompson     }
3320d0321e0SJeremy L Thompson   } else {
3330d0321e0SJeremy L Thompson     // LCOV_EXCL_START
3340d0321e0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND,
3350d0321e0SJeremy L Thompson                      "Only MemType = HOST or DEVICE supported");
3360d0321e0SJeremy L Thompson     // LCOV_EXCL_STOP
3370d0321e0SJeremy L Thompson   }
3380d0321e0SJeremy L Thompson 
3390d0321e0SJeremy L Thompson   // Compile HIP kernels
340437930d1SJeremy L Thompson   CeedInt num_nodes = impl->num_nodes;
341437930d1SJeremy L Thompson   char *restriction_kernel_path, *restriction_kernel_source;
342437930d1SJeremy L Thompson   ierr = CeedPathConcatenate(ceed, __FILE__, "kernels/hip-ref-restriction.h",
343437930d1SJeremy L Thompson                              &restriction_kernel_path); CeedChkBackend(ierr);
34446dc0734SJeremy L Thompson   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n");
345437930d1SJeremy L Thompson   ierr = CeedLoadSourceToBuffer(ceed, restriction_kernel_path,
346437930d1SJeremy L Thompson                                 &restriction_kernel_source);
347437930d1SJeremy L Thompson   CeedChkBackend(ierr);
34846dc0734SJeremy L Thompson   CeedDebug256(ceed, 2,
34946dc0734SJeremy L Thompson                "----- Loading Restriction Kernel Source Complete! -----\n");
350437930d1SJeremy L Thompson   ierr = CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8,
351d7d111ecSJeremy L Thompson                         "RESTR_ELEM_SIZE", elem_size,
352d7d111ecSJeremy L Thompson                         "RESTR_NUM_ELEM", num_elem,
353d7d111ecSJeremy L Thompson                         "RESTR_NUM_COMP", num_comp,
354d7d111ecSJeremy L Thompson                         "RESTR_NUM_NODES", num_nodes,
355d7d111ecSJeremy L Thompson                         "RESTR_COMP_STRIDE", comp_stride,
356d7d111ecSJeremy L Thompson                         "RESTR_STRIDE_NODES", strides[0],
357d7d111ecSJeremy L Thompson                         "RESTR_STRIDE_COMP", strides[1],
358d7d111ecSJeremy L Thompson                         "RESTR_STRIDE_ELEM", strides[2]); CeedChkBackend(ierr);
359437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose",
360437930d1SJeremy L Thompson                           &impl->StridedNoTranspose); CeedChkBackend(ierr);
361437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose",
362437930d1SJeremy L Thompson                           &impl->OffsetNoTranspose); CeedChkBackend(ierr);
363437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "StridedTranspose",
364437930d1SJeremy L Thompson                           &impl->StridedTranspose); CeedChkBackend(ierr);
365437930d1SJeremy L Thompson   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetTranspose",
366437930d1SJeremy L Thompson                           &impl->OffsetTranspose); CeedChkBackend(ierr);
367437930d1SJeremy L Thompson   ierr = CeedFree(&restriction_kernel_path); CeedChkBackend(ierr);
368437930d1SJeremy L Thompson   ierr = CeedFree(&restriction_kernel_source); CeedChkBackend(ierr);
3690d0321e0SJeremy L Thompson 
3700d0321e0SJeremy L Thompson   // Register backend functions
3710d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply",
3720d0321e0SJeremy L Thompson                                 CeedElemRestrictionApply_Hip);
3730d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3740d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock",
3750d0321e0SJeremy L Thompson                                 CeedElemRestrictionApplyBlock_Hip);
3760d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3770d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets",
3780d0321e0SJeremy L Thompson                                 CeedElemRestrictionGetOffsets_Hip);
3790d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3800d0321e0SJeremy L Thompson   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy",
3810d0321e0SJeremy L Thompson                                 CeedElemRestrictionDestroy_Hip);
3820d0321e0SJeremy L Thompson   CeedChkBackend(ierr);
3830d0321e0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
3840d0321e0SJeremy L Thompson }
3850d0321e0SJeremy L Thompson 
3860d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3870d0321e0SJeremy L Thompson // Blocked not supported
3880d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
3890d0321e0SJeremy L Thompson int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype,
3900d0321e0SJeremy L Thompson     const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
3910d0321e0SJeremy L Thompson   int ierr;
3920d0321e0SJeremy L Thompson   Ceed ceed;
3930d0321e0SJeremy L Thompson   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
3940d0321e0SJeremy L Thompson   return CeedError(ceed, CEED_ERROR_BACKEND,
3950d0321e0SJeremy L Thompson                    "Backend does not implement blocked restrictions");
3960d0321e0SJeremy L Thompson }
3970d0321e0SJeremy L Thompson //------------------------------------------------------------------------------
398