1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Apply restriction 22 //------------------------------------------------------------------------------ 23 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 24 Ceed ceed; 25 Ceed_Hip *data; 26 CeedInt num_elem, elem_size; 27 const CeedScalar *d_u; 28 CeedScalar *d_v; 29 CeedElemRestriction_Hip *impl; 30 hipFunction_t kernel; 31 32 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 33 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 34 CeedCallBackend(CeedGetData(ceed, &data)); 35 CeedElemRestrictionGetNumElements(r, &num_elem); 36 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 37 const CeedInt num_nodes = impl->num_nodes; 38 39 // Get vectors 40 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 41 if (t_mode == CEED_TRANSPOSE) { 42 // Sum into for transpose mode, e-vec to l-vec 43 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 44 } else { 45 // Overwrite for notranspose mode, l-vec to e-vec 46 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 47 } 48 49 // Restrict 50 if (t_mode == CEED_NOTRANSPOSE) { 51 // L-vector -> E-vector 52 if (impl->d_ind) { 53 // -- Offsets provided 54 kernel = impl->OffsetNoTranspose; 55 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 56 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 57 58 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 59 } else { 60 // -- Strided restriction 61 kernel = impl->StridedNoTranspose; 62 void *args[] = {&num_elem, &d_u, &d_v}; 63 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 64 65 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 66 } 67 } else { 68 // E-vector -> L-vector 69 if (impl->d_ind) { 70 // -- Offsets provided 71 CeedInt block_size = 64; 72 73 if (impl->OffsetTranspose) { 74 kernel = impl->OffsetTranspose; 75 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 76 77 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 78 } else { 79 kernel = impl->OffsetTransposeDet; 80 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 81 82 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 83 } 84 } else { 85 // -- Strided restriction 86 kernel = impl->StridedTranspose; 87 void *args[] = {&num_elem, &d_u, &d_v}; 88 CeedInt block_size = 64; 89 90 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 91 } 92 } 93 94 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 95 96 // Restore arrays 97 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 98 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 99 return CEED_ERROR_SUCCESS; 100 } 101 102 //------------------------------------------------------------------------------ 103 // Get offsets 104 //------------------------------------------------------------------------------ 105 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 106 CeedElemRestriction_Hip *impl; 107 108 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 109 switch (mem_type) { 110 case CEED_MEM_HOST: 111 *offsets = impl->h_ind; 112 break; 113 case CEED_MEM_DEVICE: 114 *offsets = impl->d_ind; 115 break; 116 } 117 return CEED_ERROR_SUCCESS; 118 } 119 120 //------------------------------------------------------------------------------ 121 // Destroy restriction 122 //------------------------------------------------------------------------------ 123 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) { 124 Ceed ceed; 125 CeedElemRestriction_Hip *impl; 126 127 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 128 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 129 CeedCallHip(ceed, hipModuleUnload(impl->module)); 130 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 131 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 132 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 133 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 134 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 135 CeedCallBackend(CeedFree(&impl)); 136 return CEED_ERROR_SUCCESS; 137 } 138 139 //------------------------------------------------------------------------------ 140 // Create transpose offsets and indices 141 //------------------------------------------------------------------------------ 142 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) { 143 Ceed ceed; 144 bool *is_node; 145 CeedSize l_size; 146 CeedInt num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 147 CeedElemRestriction_Hip *impl; 148 149 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 150 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 151 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 152 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 153 CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 154 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 155 const CeedInt size_indices = num_elem * elem_size; 156 157 // Count num_nodes 158 CeedCallBackend(CeedCalloc(l_size, &is_node)); 159 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 160 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 161 impl->num_nodes = num_nodes; 162 163 // L-vector offsets array 164 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 165 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 166 for (CeedInt i = 0, j = 0; i < l_size; i++) { 167 if (is_node[i]) { 168 l_vec_indices[j] = i; 169 ind_to_offset[i] = j++; 170 } 171 } 172 CeedCallBackend(CeedFree(&is_node)); 173 174 // Compute transpose offsets and indices 175 const CeedInt size_offsets = num_nodes + 1; 176 177 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 178 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 179 // Count node multiplicity 180 for (CeedInt e = 0; e < num_elem; ++e) { 181 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 182 } 183 // Convert to running sum 184 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 185 // List all E-vec indices associated with L-vec node 186 for (CeedInt e = 0; e < num_elem; ++e) { 187 for (CeedInt i = 0; i < elem_size; ++i) { 188 const CeedInt lid = elem_size * e + i; 189 const CeedInt gid = indices[lid]; 190 191 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 192 } 193 } 194 // Reset running sum 195 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 196 t_offsets[0] = 0; 197 198 // Copy data to device 199 // -- L-vector indices 200 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 201 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 202 // -- Transpose offsets 203 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 204 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 205 // -- Transpose indices 206 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 207 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 208 209 // Cleanup 210 CeedCallBackend(CeedFree(&ind_to_offset)); 211 CeedCallBackend(CeedFree(&l_vec_indices)); 212 CeedCallBackend(CeedFree(&t_offsets)); 213 CeedCallBackend(CeedFree(&t_indices)); 214 return CEED_ERROR_SUCCESS; 215 } 216 217 //------------------------------------------------------------------------------ 218 // Create restriction 219 //------------------------------------------------------------------------------ 220 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 221 const CeedInt8 *curl_orients, CeedElemRestriction r) { 222 Ceed ceed, ceed_parent; 223 bool is_deterministic, is_strided; 224 char *restriction_kernel_path, *restriction_kernel_source; 225 CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 226 CeedRestrictionType rstr_type; 227 CeedElemRestriction_Hip *impl; 228 229 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 230 CeedCallBackend(CeedCalloc(1, &impl)); 231 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 232 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 233 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 234 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 235 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 236 CeedInt size = num_elem * elem_size; 237 CeedInt strides[3] = {1, size, elem_size}; 238 CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 239 240 CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type)); 241 CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 242 "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 243 244 // Stride data 245 CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 246 if (is_strided) { 247 bool has_backend_strides; 248 249 CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 250 if (!has_backend_strides) { 251 CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 252 } 253 } else { 254 CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 255 } 256 257 impl->h_ind = NULL; 258 impl->h_ind_allocated = NULL; 259 impl->d_ind = NULL; 260 impl->d_ind_allocated = NULL; 261 impl->d_t_indices = NULL; 262 impl->d_t_offsets = NULL; 263 impl->num_nodes = size; 264 CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 265 CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 266 267 // Set up device indices/offset arrays 268 switch (mem_type) { 269 case CEED_MEM_HOST: { 270 switch (copy_mode) { 271 case CEED_OWN_POINTER: 272 impl->h_ind_allocated = (CeedInt *)indices; 273 impl->h_ind = (CeedInt *)indices; 274 break; 275 case CEED_USE_POINTER: 276 impl->h_ind = (CeedInt *)indices; 277 break; 278 case CEED_COPY_VALUES: 279 if (indices != NULL) { 280 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 281 memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 282 impl->h_ind = impl->h_ind_allocated; 283 } 284 break; 285 } 286 if (indices != NULL) { 287 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 288 impl->d_ind_allocated = impl->d_ind; // We own the device memory 289 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 290 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 291 } 292 break; 293 } 294 case CEED_MEM_DEVICE: { 295 switch (copy_mode) { 296 case CEED_COPY_VALUES: 297 if (indices != NULL) { 298 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 299 impl->d_ind_allocated = impl->d_ind; // We own the device memory 300 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 301 } 302 break; 303 case CEED_OWN_POINTER: 304 impl->d_ind = (CeedInt *)indices; 305 impl->d_ind_allocated = impl->d_ind; 306 break; 307 case CEED_USE_POINTER: 308 impl->d_ind = (CeedInt *)indices; 309 } 310 if (indices != NULL) { 311 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 312 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost)); 313 impl->h_ind = impl->h_ind_allocated; 314 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 315 } 316 break; 317 } 318 // LCOV_EXCL_START 319 default: 320 return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 321 // LCOV_EXCL_STOP 322 } 323 324 // Compile HIP kernels 325 CeedInt num_nodes = impl->num_nodes; 326 327 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path)); 328 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 329 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 330 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 331 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem, 332 "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES", 333 strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2])); 334 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 335 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 336 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 337 if (!is_deterministic) { 338 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 339 } else { 340 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet)); 341 } 342 CeedCallBackend(CeedFree(&restriction_kernel_path)); 343 CeedCallBackend(CeedFree(&restriction_kernel_source)); 344 345 // Register backend functions 346 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip)); 347 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip)); 348 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip)); 349 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 350 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip)); 351 return CEED_ERROR_SUCCESS; 352 } 353 354 //------------------------------------------------------------------------------ 355