1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Compile restriction kernels 22 //------------------------------------------------------------------------------ 23 static inline int CeedElemRestrictionSetupCompile_Hip(CeedElemRestriction rstr) { 24 Ceed ceed; 25 bool is_deterministic; 26 char *restriction_kernel_source; 27 const char *restriction_kernel_path; 28 CeedInt num_elem, num_comp, elem_size, comp_stride; 29 CeedRestrictionType rstr_type; 30 CeedElemRestriction_Hip *impl; 31 32 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 33 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 34 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 35 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 36 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 37 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 38 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 39 is_deterministic = impl->d_l_vec_indices != NULL; 40 41 // Compile HIP kernels 42 switch (rstr_type) { 43 case CEED_RESTRICTION_STRIDED: { 44 bool has_backend_strides; 45 CeedInt strides[3] = {1, num_elem * elem_size, elem_size}; 46 47 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 48 if (!has_backend_strides) { 49 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides)); 50 } 51 52 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-strided.h", &restriction_kernel_path)); 53 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 54 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 55 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 56 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 57 "RSTR_NUM_COMP", num_comp, "RSTR_STRIDE_NODES", strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", 58 strides[2])); 59 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->ApplyNoTranspose)); 60 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->ApplyTranspose)); 61 } break; 62 case CEED_RESTRICTION_STANDARD: { 63 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &restriction_kernel_path)); 64 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 65 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 66 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 67 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 68 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 69 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 70 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyNoTranspose)); 71 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyTranspose)); 72 } break; 73 case CEED_RESTRICTION_ORIENTED: { 74 const char *offset_kernel_path; 75 char **file_paths = NULL; 76 CeedInt num_file_paths = 0; 77 78 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-oriented.h", &restriction_kernel_path)); 79 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 80 CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, restriction_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source)); 81 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path)); 82 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source)); 83 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 84 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 85 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 86 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 87 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->ApplyNoTranspose)); 88 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnsignedNoTranspose)); 89 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->ApplyTranspose)); 90 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnsignedTranspose)); 91 // Cleanup 92 CeedCallBackend(CeedFree(&offset_kernel_path)); 93 for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i])); 94 CeedCall(CeedFree(&file_paths)); 95 } break; 96 case CEED_RESTRICTION_CURL_ORIENTED: { 97 const char *offset_kernel_path; 98 char **file_paths = NULL; 99 CeedInt num_file_paths = 0; 100 101 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-curl-oriented.h", &restriction_kernel_path)); 102 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 103 CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, restriction_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source)); 104 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path)); 105 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &num_file_paths, &file_paths, &restriction_kernel_source)); 106 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 107 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 108 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 109 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 110 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->ApplyNoTranspose)); 111 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->ApplyUnsignedNoTranspose)); 112 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnorientedNoTranspose)); 113 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->ApplyTranspose)); 114 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->ApplyUnsignedTranspose)); 115 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnorientedTranspose)); 116 // Cleanup 117 CeedCallBackend(CeedFree(&offset_kernel_path)); 118 for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i])); 119 CeedCall(CeedFree(&file_paths)); 120 } break; 121 case CEED_RESTRICTION_POINTS: { 122 // LCOV_EXCL_START 123 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 124 // LCOV_EXCL_STOP 125 } break; 126 } 127 CeedCallBackend(CeedFree(&restriction_kernel_path)); 128 CeedCallBackend(CeedFree(&restriction_kernel_source)); 129 return CEED_ERROR_SUCCESS; 130 } 131 132 //------------------------------------------------------------------------------ 133 // Core apply restriction code 134 //------------------------------------------------------------------------------ 135 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients, 136 CeedVector u, CeedVector v, CeedRequest *request) { 137 Ceed ceed; 138 CeedRestrictionType rstr_type; 139 const CeedScalar *d_u; 140 CeedScalar *d_v; 141 CeedElemRestriction_Hip *impl; 142 143 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 144 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 145 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 146 147 // Assemble kernel if needed 148 if (!impl->module) { 149 CeedCallBackend(CeedElemRestrictionSetupCompile_Hip(rstr)); 150 } 151 152 // Get vectors 153 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 154 if (t_mode == CEED_TRANSPOSE) { 155 // Sum into for transpose mode, e-vec to l-vec 156 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 157 } else { 158 // Overwrite for notranspose mode, l-vec to e-vec 159 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 160 } 161 162 // Restrict 163 if (t_mode == CEED_NOTRANSPOSE) { 164 // L-vector -> E-vector 165 CeedInt elem_size; 166 167 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 168 const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 169 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 170 171 switch (rstr_type) { 172 case CEED_RESTRICTION_STRIDED: { 173 void *args[] = {&d_u, &d_v}; 174 175 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 176 } break; 177 case CEED_RESTRICTION_STANDARD: { 178 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 179 180 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 181 } break; 182 case CEED_RESTRICTION_ORIENTED: { 183 if (use_signs) { 184 void *args[] = {&impl->d_offsets, &impl->d_orients, &d_u, &d_v}; 185 186 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 187 } else { 188 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 189 190 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 191 } 192 } break; 193 case CEED_RESTRICTION_CURL_ORIENTED: { 194 if (use_signs && use_orients) { 195 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 196 197 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 198 } else if (use_orients) { 199 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 200 201 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 202 } else { 203 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 204 205 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedNoTranspose, grid, block_size, args)); 206 } 207 } break; 208 case CEED_RESTRICTION_POINTS: { 209 // LCOV_EXCL_START 210 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 211 // LCOV_EXCL_STOP 212 } break; 213 } 214 } else { 215 // E-vector -> L-vector 216 const bool is_deterministic = impl->d_l_vec_indices != NULL; 217 const CeedInt block_size = 64; 218 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 219 220 switch (rstr_type) { 221 case CEED_RESTRICTION_STRIDED: { 222 void *args[] = {&d_u, &d_v}; 223 224 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 225 } break; 226 case CEED_RESTRICTION_STANDARD: { 227 if (!is_deterministic) { 228 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 229 230 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 231 } else { 232 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 233 234 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 235 } 236 } break; 237 case CEED_RESTRICTION_ORIENTED: { 238 if (use_signs) { 239 if (!is_deterministic) { 240 void *args[] = {&impl->d_offsets, &impl->d_orients, &d_u, &d_v}; 241 242 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 243 } else { 244 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_orients, &d_u, &d_v}; 245 246 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 247 } 248 } else { 249 if (!is_deterministic) { 250 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 251 252 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 253 } else { 254 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 255 256 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 257 } 258 } 259 } break; 260 case CEED_RESTRICTION_CURL_ORIENTED: { 261 if (use_signs && use_orients) { 262 if (!is_deterministic) { 263 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 264 265 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 266 } else { 267 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 268 269 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 270 } 271 } else if (use_orients) { 272 if (!is_deterministic) { 273 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 274 275 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 276 } else { 277 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 278 279 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 280 } 281 } else { 282 if (!is_deterministic) { 283 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 284 285 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 286 } else { 287 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 288 289 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 290 } 291 } 292 } break; 293 case CEED_RESTRICTION_POINTS: { 294 // LCOV_EXCL_START 295 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 296 // LCOV_EXCL_STOP 297 } break; 298 } 299 } 300 301 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 302 303 // Restore arrays 304 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 305 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 306 return CEED_ERROR_SUCCESS; 307 } 308 309 //------------------------------------------------------------------------------ 310 // Apply restriction 311 //------------------------------------------------------------------------------ 312 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 313 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request); 314 } 315 316 //------------------------------------------------------------------------------ 317 // Apply unsigned restriction 318 //------------------------------------------------------------------------------ 319 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 320 CeedRequest *request) { 321 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request); 322 } 323 324 //------------------------------------------------------------------------------ 325 // Apply unoriented restriction 326 //------------------------------------------------------------------------------ 327 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 328 CeedRequest *request) { 329 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request); 330 } 331 332 //------------------------------------------------------------------------------ 333 // Get offsets 334 //------------------------------------------------------------------------------ 335 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 336 CeedElemRestriction_Hip *impl; 337 338 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 339 switch (mem_type) { 340 case CEED_MEM_HOST: 341 *offsets = impl->h_offsets; 342 break; 343 case CEED_MEM_DEVICE: 344 *offsets = impl->d_offsets; 345 break; 346 } 347 return CEED_ERROR_SUCCESS; 348 } 349 350 //------------------------------------------------------------------------------ 351 // Get orientations 352 //------------------------------------------------------------------------------ 353 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) { 354 CeedElemRestriction_Hip *impl; 355 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 356 357 switch (mem_type) { 358 case CEED_MEM_HOST: 359 *orients = impl->h_orients; 360 break; 361 case CEED_MEM_DEVICE: 362 *orients = impl->d_orients; 363 break; 364 } 365 return CEED_ERROR_SUCCESS; 366 } 367 368 //------------------------------------------------------------------------------ 369 // Get curl-conforming orientations 370 //------------------------------------------------------------------------------ 371 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) { 372 CeedElemRestriction_Hip *impl; 373 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 374 375 switch (mem_type) { 376 case CEED_MEM_HOST: 377 *curl_orients = impl->h_curl_orients; 378 break; 379 case CEED_MEM_DEVICE: 380 *curl_orients = impl->d_curl_orients; 381 break; 382 } 383 return CEED_ERROR_SUCCESS; 384 } 385 386 //------------------------------------------------------------------------------ 387 // Destroy restriction 388 //------------------------------------------------------------------------------ 389 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) { 390 Ceed ceed; 391 CeedElemRestriction_Hip *impl; 392 393 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 394 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 395 if (impl->module) { 396 CeedCallHip(ceed, hipModuleUnload(impl->module)); 397 } 398 CeedCallBackend(CeedFree(&impl->h_offsets_owned)); 399 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_offsets_owned)); 400 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_t_offsets)); 401 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_t_indices)); 402 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_l_vec_indices)); 403 CeedCallBackend(CeedFree(&impl->h_orients_owned)); 404 CeedCallHip(ceed, hipFree((bool *)impl->d_orients_owned)); 405 CeedCallBackend(CeedFree(&impl->h_curl_orients_owned)); 406 CeedCallHip(ceed, hipFree((CeedInt8 *)impl->d_curl_orients_owned)); 407 CeedCallBackend(CeedFree(&impl)); 408 return CEED_ERROR_SUCCESS; 409 } 410 411 //------------------------------------------------------------------------------ 412 // Create transpose offsets and indices 413 //------------------------------------------------------------------------------ 414 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt *indices) { 415 Ceed ceed; 416 bool *is_node; 417 CeedSize l_size; 418 CeedInt num_elem, elem_size, num_comp, num_nodes = 0; 419 CeedInt *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 420 CeedElemRestriction_Hip *impl; 421 422 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 423 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 424 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 425 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 426 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 427 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 428 const CeedInt size_indices = num_elem * elem_size; 429 430 // Count num_nodes 431 CeedCallBackend(CeedCalloc(l_size, &is_node)); 432 433 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 434 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 435 impl->num_nodes = num_nodes; 436 437 // L-vector offsets array 438 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 439 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 440 for (CeedInt i = 0, j = 0; i < l_size; i++) { 441 if (is_node[i]) { 442 l_vec_indices[j] = i; 443 ind_to_offset[i] = j++; 444 } 445 } 446 CeedCallBackend(CeedFree(&is_node)); 447 448 // Compute transpose offsets and indices 449 const CeedInt size_offsets = num_nodes + 1; 450 451 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 452 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 453 // Count node multiplicity 454 for (CeedInt e = 0; e < num_elem; ++e) { 455 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 456 } 457 // Convert to running sum 458 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 459 // List all E-vec indices associated with L-vec node 460 for (CeedInt e = 0; e < num_elem; ++e) { 461 for (CeedInt i = 0; i < elem_size; ++i) { 462 const CeedInt lid = elem_size * e + i; 463 const CeedInt gid = indices[lid]; 464 465 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 466 } 467 } 468 // Reset running sum 469 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 470 t_offsets[0] = 0; 471 472 // Copy data to device 473 // -- L-vector indices 474 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 475 CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 476 // -- Transpose offsets 477 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 478 CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 479 // -- Transpose indices 480 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 481 CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 482 483 // Cleanup 484 CeedCallBackend(CeedFree(&ind_to_offset)); 485 CeedCallBackend(CeedFree(&l_vec_indices)); 486 CeedCallBackend(CeedFree(&t_offsets)); 487 CeedCallBackend(CeedFree(&t_indices)); 488 return CEED_ERROR_SUCCESS; 489 } 490 491 //------------------------------------------------------------------------------ 492 // Create restriction 493 //------------------------------------------------------------------------------ 494 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients, 495 const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 496 Ceed ceed, ceed_parent; 497 bool is_deterministic; 498 CeedInt num_elem, elem_size; 499 CeedRestrictionType rstr_type; 500 CeedElemRestriction_Hip *impl; 501 502 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 503 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 504 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 505 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 506 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 507 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 508 const CeedInt size = num_elem * elem_size; 509 510 CeedCallBackend(CeedCalloc(1, &impl)); 511 impl->num_nodes = size; 512 CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 513 514 // Set layouts 515 { 516 bool has_backend_strides; 517 CeedInt layout[3] = {1, size, elem_size}; 518 519 CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 520 if (rstr_type == CEED_RESTRICTION_STRIDED) { 521 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 522 if (has_backend_strides) { 523 CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout)); 524 } 525 } 526 } 527 528 // Set up device offset/orientation arrays 529 if (rstr_type != CEED_RESTRICTION_STRIDED) { 530 switch (mem_type) { 531 case CEED_MEM_HOST: { 532 CeedCallBackend(CeedSetHostCeedIntArray(offsets, copy_mode, size, &impl->h_offsets_owned, &impl->h_offsets_borrowed, &impl->h_offsets)); 533 CeedCallHip(ceed, hipMalloc((void **)&impl->d_offsets_owned, size * sizeof(CeedInt))); 534 CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->d_offsets_owned, impl->h_offsets, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 535 impl->d_offsets = (CeedInt *)impl->d_offsets_owned; 536 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, offsets)); 537 } break; 538 case CEED_MEM_DEVICE: { 539 CeedCallBackend(CeedSetDeviceCeedIntArray_Hip(ceed, offsets, copy_mode, size, &impl->d_offsets_owned, &impl->d_offsets_borrowed, 540 (const CeedInt **)&impl->d_offsets)); 541 CeedCallBackend(CeedMalloc(size, &impl->h_offsets_owned)); 542 CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->h_offsets_owned, impl->d_offsets, size * sizeof(CeedInt), hipMemcpyDeviceToHost)); 543 impl->h_offsets = impl->h_offsets_owned; 544 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, offsets)); 545 } break; 546 } 547 548 // Orientation data 549 if (rstr_type == CEED_RESTRICTION_ORIENTED) { 550 switch (mem_type) { 551 case CEED_MEM_HOST: { 552 CeedCallBackend(CeedSetHostBoolArray(orients, copy_mode, size, &impl->h_orients_owned, &impl->h_orients_borrowed, &impl->h_orients)); 553 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients_owned, size * sizeof(bool))); 554 CeedCallHip(ceed, hipMemcpy((bool *)impl->d_orients_owned, impl->h_orients, size * sizeof(bool), hipMemcpyHostToDevice)); 555 impl->d_orients = impl->d_orients_owned; 556 } break; 557 case CEED_MEM_DEVICE: { 558 CeedCallBackend(CeedSetDeviceBoolArray_Hip(ceed, orients, copy_mode, size, &impl->d_orients_owned, &impl->d_orients_borrowed, 559 (const bool **)&impl->d_orients)); 560 CeedCallBackend(CeedMalloc(size, &impl->h_orients_owned)); 561 CeedCallHip(ceed, hipMemcpy((bool *)impl->h_orients_owned, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost)); 562 impl->h_orients = impl->h_orients_owned; 563 } break; 564 } 565 } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) { 566 switch (mem_type) { 567 case CEED_MEM_HOST: { 568 CeedCallBackend(CeedSetHostCeedInt8Array(curl_orients, copy_mode, 3 * size, &impl->h_curl_orients_owned, &impl->h_curl_orients_borrowed, 569 &impl->h_curl_orients)); 570 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients_owned, 3 * size * sizeof(CeedInt8))); 571 CeedCallHip(ceed, 572 hipMemcpy((CeedInt8 *)impl->d_curl_orients_owned, impl->h_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice)); 573 impl->d_curl_orients = impl->d_curl_orients_owned; 574 } break; 575 case CEED_MEM_DEVICE: { 576 CeedCallBackend(CeedSetDeviceCeedInt8Array_Hip(ceed, curl_orients, copy_mode, 3 * size, &impl->d_curl_orients_owned, 577 &impl->d_curl_orients_borrowed, (const CeedInt8 **)&impl->d_curl_orients)); 578 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_owned)); 579 CeedCallHip(ceed, 580 hipMemcpy((CeedInt8 *)impl->h_curl_orients_owned, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost)); 581 impl->h_curl_orients = impl->h_curl_orients_owned; 582 } break; 583 } 584 } 585 } 586 587 // Register backend functions 588 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip)); 589 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip)); 590 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip)); 591 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 592 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip)); 593 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip)); 594 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip)); 595 return CEED_ERROR_SUCCESS; 596 } 597 598 //------------------------------------------------------------------------------ 599