13d8e8822SJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 23d8e8822SJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 30d0321e0SJeremy L Thompson // 43d8e8822SJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause 50d0321e0SJeremy L Thompson // 63d8e8822SJeremy L Thompson // This file is part of CEED: http://github.com/ceed 70d0321e0SJeremy L Thompson 849aac155SJeremy L Thompson #include <ceed.h> 90d0321e0SJeremy L Thompson #include <ceed/backend.h> 10437930d1SJeremy L Thompson #include <ceed/jit-tools.h> 110d0321e0SJeremy L Thompson #include <stdbool.h> 120d0321e0SJeremy L Thompson #include <stddef.h> 1344d7a66cSJeremy L Thompson #include <string.h> 14c85e8640SSebastian Grimberg #include <hip/hip_runtime.h> 152b730f8bSJeremy L Thompson 1649aac155SJeremy L Thompson #include "../hip/ceed-hip-common.h" 170d0321e0SJeremy L Thompson #include "../hip/ceed-hip-compile.h" 182b730f8bSJeremy L Thompson #include "ceed-hip-ref.h" 190d0321e0SJeremy L Thompson 200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 210d0321e0SJeremy L Thompson // Apply restriction 220d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 232b730f8bSJeremy L Thompson static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 240d0321e0SJeremy L Thompson Ceed ceed; 250d0321e0SJeremy L Thompson Ceed_Hip *data; 26437930d1SJeremy L Thompson CeedInt num_elem, elem_size; 270d0321e0SJeremy L Thompson const CeedScalar *d_u; 280d0321e0SJeremy L Thompson CeedScalar *d_v; 29*b7453713SJeremy L Thompson CeedElemRestriction_Hip *impl; 30*b7453713SJeremy L Thompson hipFunction_t kernel; 31*b7453713SJeremy L Thompson 32*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 33*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 34*b7453713SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 35*b7453713SJeremy L Thompson CeedElemRestrictionGetNumElements(r, &num_elem); 36*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 37*b7453713SJeremy L Thompson const CeedInt num_nodes = impl->num_nodes; 38*b7453713SJeremy L Thompson 39*b7453713SJeremy L Thompson // Get vectors 402b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 41437930d1SJeremy L Thompson if (t_mode == CEED_TRANSPOSE) { 420d0321e0SJeremy L Thompson // Sum into for transpose mode, e-vec to l-vec 432b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 440d0321e0SJeremy L Thompson } else { 450d0321e0SJeremy L Thompson // Overwrite for notranspose mode, l-vec to e-vec 462b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 470d0321e0SJeremy L Thompson } 480d0321e0SJeremy L Thompson 490d0321e0SJeremy L Thompson // Restrict 50437930d1SJeremy L Thompson if (t_mode == CEED_NOTRANSPOSE) { 510d0321e0SJeremy L Thompson // L-vector -> E-vector 520d0321e0SJeremy L Thompson if (impl->d_ind) { 530d0321e0SJeremy L Thompson // -- Offsets provided 54437930d1SJeremy L Thompson kernel = impl->OffsetNoTranspose; 55437930d1SJeremy L Thompson void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 56437930d1SJeremy L Thompson CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 5758549094SSebastian Grimberg 58eb7e6cafSJeremy L Thompson CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 590d0321e0SJeremy L Thompson } else { 600d0321e0SJeremy L Thompson // -- Strided restriction 61437930d1SJeremy L Thompson kernel = impl->StridedNoTranspose; 62437930d1SJeremy L Thompson void *args[] = {&num_elem, &d_u, &d_v}; 63437930d1SJeremy L Thompson CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 6458549094SSebastian Grimberg 65eb7e6cafSJeremy L Thompson CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 660d0321e0SJeremy L Thompson } 670d0321e0SJeremy L Thompson } else { 680d0321e0SJeremy L Thompson // E-vector -> L-vector 690d0321e0SJeremy L Thompson if (impl->d_ind) { 700d0321e0SJeremy L Thompson // -- Offsets provided 7158549094SSebastian Grimberg CeedInt block_size = 64; 72*b7453713SJeremy L Thompson 7358549094SSebastian Grimberg if (impl->OffsetTranspose) { 74437930d1SJeremy L Thompson kernel = impl->OffsetTranspose; 7558549094SSebastian Grimberg void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 7658549094SSebastian Grimberg 77eb7e6cafSJeremy L Thompson CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 780d0321e0SJeremy L Thompson } else { 7958549094SSebastian Grimberg kernel = impl->OffsetTransposeDet; 8058549094SSebastian Grimberg void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 8158549094SSebastian Grimberg 8258549094SSebastian Grimberg CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 8358549094SSebastian Grimberg } 8458549094SSebastian Grimberg } else { 850d0321e0SJeremy L Thompson // -- Strided restriction 86437930d1SJeremy L Thompson kernel = impl->StridedTranspose; 87437930d1SJeremy L Thompson void *args[] = {&num_elem, &d_u, &d_v}; 8858549094SSebastian Grimberg CeedInt block_size = 64; 8958549094SSebastian Grimberg 90eb7e6cafSJeremy L Thompson CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 910d0321e0SJeremy L Thompson } 920d0321e0SJeremy L Thompson } 930d0321e0SJeremy L Thompson 942b730f8bSJeremy L Thompson if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 950d0321e0SJeremy L Thompson 960d0321e0SJeremy L Thompson // Restore arrays 972b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 982b730f8bSJeremy L Thompson CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 990d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1000d0321e0SJeremy L Thompson } 1010d0321e0SJeremy L Thompson 1020d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1030d0321e0SJeremy L Thompson // Get offsets 1040d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 105472941f0SJeremy L Thompson static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 1060d0321e0SJeremy L Thompson CeedElemRestriction_Hip *impl; 1070d0321e0SJeremy L Thompson 108*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 109472941f0SJeremy L Thompson switch (mem_type) { 1100d0321e0SJeremy L Thompson case CEED_MEM_HOST: 1110d0321e0SJeremy L Thompson *offsets = impl->h_ind; 1120d0321e0SJeremy L Thompson break; 1130d0321e0SJeremy L Thompson case CEED_MEM_DEVICE: 1140d0321e0SJeremy L Thompson *offsets = impl->d_ind; 1150d0321e0SJeremy L Thompson break; 1160d0321e0SJeremy L Thompson } 1170d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1180d0321e0SJeremy L Thompson } 1190d0321e0SJeremy L Thompson 1200d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1210d0321e0SJeremy L Thompson // Destroy restriction 1220d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1230d0321e0SJeremy L Thompson static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) { 1240d0321e0SJeremy L Thompson Ceed ceed; 125*b7453713SJeremy L Thompson CeedElemRestriction_Hip *impl; 126*b7453713SJeremy L Thompson 127*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 1282b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 1292b730f8bSJeremy L Thompson CeedCallHip(ceed, hipModuleUnload(impl->module)); 1302b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 1312b730f8bSJeremy L Thompson CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 1322b730f8bSJeremy L Thompson CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 1332b730f8bSJeremy L Thompson CeedCallHip(ceed, hipFree(impl->d_t_indices)); 1342b730f8bSJeremy L Thompson CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 1352b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&impl)); 1360d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 1370d0321e0SJeremy L Thompson } 1380d0321e0SJeremy L Thompson 1390d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1400d0321e0SJeremy L Thompson // Create transpose offsets and indices 1410d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 1422b730f8bSJeremy L Thompson static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) { 1430d0321e0SJeremy L Thompson Ceed ceed; 144*b7453713SJeremy L Thompson bool *is_node; 145e79b91d9SJeremy L Thompson CeedSize l_size; 146*b7453713SJeremy L Thompson CeedInt num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 147*b7453713SJeremy L Thompson CeedElemRestriction_Hip *impl; 148*b7453713SJeremy L Thompson 149*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 150*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 1512b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 1522b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 1532b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 1542b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 155*b7453713SJeremy L Thompson const CeedInt size_indices = num_elem * elem_size; 1560d0321e0SJeremy L Thompson 157437930d1SJeremy L Thompson // Count num_nodes 1582b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(l_size, &is_node)); 1592b730f8bSJeremy L Thompson for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 1602b730f8bSJeremy L Thompson for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 161437930d1SJeremy L Thompson impl->num_nodes = num_nodes; 1620d0321e0SJeremy L Thompson 1630d0321e0SJeremy L Thompson // L-vector offsets array 1642b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 1652b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 166*b7453713SJeremy L Thompson for (CeedInt i = 0, j = 0; i < l_size; i++) { 167437930d1SJeremy L Thompson if (is_node[i]) { 168437930d1SJeremy L Thompson l_vec_indices[j] = i; 1690d0321e0SJeremy L Thompson ind_to_offset[i] = j++; 1700d0321e0SJeremy L Thompson } 1712b730f8bSJeremy L Thompson } 1722b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&is_node)); 1730d0321e0SJeremy L Thompson 1740d0321e0SJeremy L Thompson // Compute transpose offsets and indices 175437930d1SJeremy L Thompson const CeedInt size_offsets = num_nodes + 1; 176*b7453713SJeremy L Thompson 1772b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 1782b730f8bSJeremy L Thompson CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 1790d0321e0SJeremy L Thompson // Count node multiplicity 1802b730f8bSJeremy L Thompson for (CeedInt e = 0; e < num_elem; ++e) { 1812b730f8bSJeremy L Thompson for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 1822b730f8bSJeremy L Thompson } 1830d0321e0SJeremy L Thompson // Convert to running sum 1842b730f8bSJeremy L Thompson for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 1850d0321e0SJeremy L Thompson // List all E-vec indices associated with L-vec node 186437930d1SJeremy L Thompson for (CeedInt e = 0; e < num_elem; ++e) { 187437930d1SJeremy L Thompson for (CeedInt i = 0; i < elem_size; ++i) { 188437930d1SJeremy L Thompson const CeedInt lid = elem_size * e + i; 1890d0321e0SJeremy L Thompson const CeedInt gid = indices[lid]; 190*b7453713SJeremy L Thompson 191437930d1SJeremy L Thompson t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 1920d0321e0SJeremy L Thompson } 1930d0321e0SJeremy L Thompson } 1940d0321e0SJeremy L Thompson // Reset running sum 1952b730f8bSJeremy L Thompson for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 196437930d1SJeremy L Thompson t_offsets[0] = 0; 1970d0321e0SJeremy L Thompson 1980d0321e0SJeremy L Thompson // Copy data to device 1990d0321e0SJeremy L Thompson // -- L-vector indices 2002b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 2012b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 2020d0321e0SJeremy L Thompson // -- Transpose offsets 2032b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 2042b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 2050d0321e0SJeremy L Thompson // -- Transpose indices 2062b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 2072b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 2080d0321e0SJeremy L Thompson 2090d0321e0SJeremy L Thompson // Cleanup 2102b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&ind_to_offset)); 2112b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&l_vec_indices)); 2122b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&t_offsets)); 2132b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&t_indices)); 2140d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 2150d0321e0SJeremy L Thompson } 2160d0321e0SJeremy L Thompson 2170d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 2180d0321e0SJeremy L Thompson // Create restriction 2190d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 220fcbe8c06SSebastian Grimberg int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 2210c73c039SSebastian Grimberg const CeedInt8 *curl_orients, CeedElemRestriction r) { 222*b7453713SJeremy L Thompson Ceed ceed, ceed_parent; 223*b7453713SJeremy L Thompson bool is_deterministic, is_strided; 224*b7453713SJeremy L Thompson char *restriction_kernel_path, *restriction_kernel_source; 225*b7453713SJeremy L Thompson CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 226*b7453713SJeremy L Thompson CeedRestrictionType rstr_type; 2270d0321e0SJeremy L Thompson CeedElemRestriction_Hip *impl; 228*b7453713SJeremy L Thompson 229*b7453713SJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 2302b730f8bSJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 231ca735530SJeremy L Thompson CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 232ca735530SJeremy L Thompson CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 2332b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 2342b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 2352b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 236437930d1SJeremy L Thompson CeedInt size = num_elem * elem_size; 237437930d1SJeremy L Thompson CeedInt strides[3] = {1, size, elem_size}; 238*b7453713SJeremy L Thompson CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 2390d0321e0SJeremy L Thompson 24000125730SSebastian Grimberg CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type)); 24100125730SSebastian Grimberg CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 24200125730SSebastian Grimberg "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 24300125730SSebastian Grimberg 2440d0321e0SJeremy L Thompson // Stride data 2452b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 246437930d1SJeremy L Thompson if (is_strided) { 247437930d1SJeremy L Thompson bool has_backend_strides; 248*b7453713SJeremy L Thompson 2492b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 250437930d1SJeremy L Thompson if (!has_backend_strides) { 2512b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 2520d0321e0SJeremy L Thompson } 2530d0321e0SJeremy L Thompson } else { 2542b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 2550d0321e0SJeremy L Thompson } 2560d0321e0SJeremy L Thompson 2570d0321e0SJeremy L Thompson impl->h_ind = NULL; 2580d0321e0SJeremy L Thompson impl->h_ind_allocated = NULL; 2590d0321e0SJeremy L Thompson impl->d_ind = NULL; 2600d0321e0SJeremy L Thompson impl->d_ind_allocated = NULL; 261437930d1SJeremy L Thompson impl->d_t_indices = NULL; 262437930d1SJeremy L Thompson impl->d_t_offsets = NULL; 263437930d1SJeremy L Thompson impl->num_nodes = size; 2642b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 2652b730f8bSJeremy L Thompson CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 2660d0321e0SJeremy L Thompson 2670d0321e0SJeremy L Thompson // Set up device indices/offset arrays 268472941f0SJeremy L Thompson switch (mem_type) { 2696574a04fSJeremy L Thompson case CEED_MEM_HOST: { 270472941f0SJeremy L Thompson switch (copy_mode) { 2710d0321e0SJeremy L Thompson case CEED_OWN_POINTER: 2720d0321e0SJeremy L Thompson impl->h_ind_allocated = (CeedInt *)indices; 2730d0321e0SJeremy L Thompson impl->h_ind = (CeedInt *)indices; 2740d0321e0SJeremy L Thompson break; 2750d0321e0SJeremy L Thompson case CEED_USE_POINTER: 2760d0321e0SJeremy L Thompson impl->h_ind = (CeedInt *)indices; 2770d0321e0SJeremy L Thompson break; 2780d0321e0SJeremy L Thompson case CEED_COPY_VALUES: 27944d7a66cSJeremy L Thompson if (indices != NULL) { 2802b730f8bSJeremy L Thompson CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 28144d7a66cSJeremy L Thompson memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 28244d7a66cSJeremy L Thompson impl->h_ind = impl->h_ind_allocated; 28344d7a66cSJeremy L Thompson } 2840d0321e0SJeremy L Thompson break; 2850d0321e0SJeremy L Thompson } 2860d0321e0SJeremy L Thompson if (indices != NULL) { 2872b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 2880d0321e0SJeremy L Thompson impl->d_ind_allocated = impl->d_ind; // We own the device memory 2892b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 29058549094SSebastian Grimberg if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 2910d0321e0SJeremy L Thompson } 2926574a04fSJeremy L Thompson break; 2936574a04fSJeremy L Thompson } 2946574a04fSJeremy L Thompson case CEED_MEM_DEVICE: { 295472941f0SJeremy L Thompson switch (copy_mode) { 2960d0321e0SJeremy L Thompson case CEED_COPY_VALUES: 2970d0321e0SJeremy L Thompson if (indices != NULL) { 2982b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 2990d0321e0SJeremy L Thompson impl->d_ind_allocated = impl->d_ind; // We own the device memory 3002b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 3010d0321e0SJeremy L Thompson } 3020d0321e0SJeremy L Thompson break; 3030d0321e0SJeremy L Thompson case CEED_OWN_POINTER: 3040d0321e0SJeremy L Thompson impl->d_ind = (CeedInt *)indices; 3050d0321e0SJeremy L Thompson impl->d_ind_allocated = impl->d_ind; 3060d0321e0SJeremy L Thompson break; 3070d0321e0SJeremy L Thompson case CEED_USE_POINTER: 3080d0321e0SJeremy L Thompson impl->d_ind = (CeedInt *)indices; 3090d0321e0SJeremy L Thompson } 3100d0321e0SJeremy L Thompson if (indices != NULL) { 3112b730f8bSJeremy L Thompson CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 3122b730f8bSJeremy L Thompson CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost)); 31344d7a66cSJeremy L Thompson impl->h_ind = impl->h_ind_allocated; 31458549094SSebastian Grimberg if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 3150d0321e0SJeremy L Thompson } 3166574a04fSJeremy L Thompson break; 3176574a04fSJeremy L Thompson } 3180d0321e0SJeremy L Thompson // LCOV_EXCL_START 3196574a04fSJeremy L Thompson default: 3202b730f8bSJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 3210d0321e0SJeremy L Thompson // LCOV_EXCL_STOP 3220d0321e0SJeremy L Thompson } 3230d0321e0SJeremy L Thompson 3240d0321e0SJeremy L Thompson // Compile HIP kernels 325437930d1SJeremy L Thompson CeedInt num_nodes = impl->num_nodes; 326*b7453713SJeremy L Thompson 3272b730f8bSJeremy L Thompson CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path)); 32823d4529eSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 3292b730f8bSJeremy L Thompson CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 33023d4529eSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 331eb7e6cafSJeremy L Thompson CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem, 3322b730f8bSJeremy L Thompson "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES", 3332b730f8bSJeremy L Thompson strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2])); 334eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 335eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 33658549094SSebastian Grimberg CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 33758549094SSebastian Grimberg if (!is_deterministic) { 338eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 33958549094SSebastian Grimberg } else { 34058549094SSebastian Grimberg CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet)); 34158549094SSebastian Grimberg } 3422b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&restriction_kernel_path)); 3432b730f8bSJeremy L Thompson CeedCallBackend(CeedFree(&restriction_kernel_source)); 3440d0321e0SJeremy L Thompson 3450d0321e0SJeremy L Thompson // Register backend functions 3462b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip)); 347b17517eeSSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip)); 3487c1dbaffSSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip)); 3492b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 3502b730f8bSJeremy L Thompson CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip)); 3510d0321e0SJeremy L Thompson return CEED_ERROR_SUCCESS; 3520d0321e0SJeremy L Thompson } 3530d0321e0SJeremy L Thompson 3540d0321e0SJeremy L Thompson //------------------------------------------------------------------------------ 355