1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Compile restriction kernels 22 //------------------------------------------------------------------------------ 23 static inline int CeedElemRestrictionSetupCompile_Hip(CeedElemRestriction rstr) { 24 Ceed ceed; 25 bool is_deterministic; 26 CeedInt num_elem, num_comp, elem_size, comp_stride; 27 CeedRestrictionType rstr_type; 28 char *restriction_kernel_path, *restriction_kernel_source; 29 CeedElemRestriction_Hip *impl; 30 31 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 32 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 33 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 34 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 35 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 36 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 37 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 38 is_deterministic = impl->d_l_vec_indices != NULL; 39 40 // Compile HIP kernels 41 switch (rstr_type) { 42 case CEED_RESTRICTION_STRIDED: { 43 CeedInt strides[3] = {1, num_elem * elem_size, elem_size}; 44 bool has_backend_strides; 45 46 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 47 if (!has_backend_strides) { 48 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides)); 49 } 50 51 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-strided.h", &restriction_kernel_path)); 52 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 53 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 54 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 55 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 56 "RSTR_NUM_COMP", num_comp, "RSTR_STRIDE_NODES", strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", 57 strides[2])); 58 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->ApplyNoTranspose)); 59 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->ApplyTranspose)); 60 } break; 61 case CEED_RESTRICTION_STANDARD: { 62 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &restriction_kernel_path)); 63 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 64 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 65 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 66 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 67 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 68 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 69 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyNoTranspose)); 70 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyTranspose)); 71 } break; 72 case CEED_RESTRICTION_ORIENTED: { 73 char *offset_kernel_path; 74 75 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-oriented.h", &restriction_kernel_path)); 76 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 77 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 78 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path)); 79 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &restriction_kernel_source)); 80 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 81 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 82 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 83 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 84 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->ApplyNoTranspose)); 85 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnsignedNoTranspose)); 86 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->ApplyTranspose)); 87 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnsignedTranspose)); 88 CeedCallBackend(CeedFree(&offset_kernel_path)); 89 } break; 90 case CEED_RESTRICTION_CURL_ORIENTED: { 91 char *offset_kernel_path; 92 93 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-curl-oriented.h", &restriction_kernel_path)); 94 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 95 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 96 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path)); 97 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &restriction_kernel_source)); 98 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 99 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 100 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 101 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 102 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->ApplyNoTranspose)); 103 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->ApplyUnsignedNoTranspose)); 104 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnorientedNoTranspose)); 105 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->ApplyTranspose)); 106 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->ApplyUnsignedTranspose)); 107 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnorientedTranspose)); 108 CeedCallBackend(CeedFree(&offset_kernel_path)); 109 } break; 110 case CEED_RESTRICTION_POINTS: { 111 // LCOV_EXCL_START 112 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 113 // LCOV_EXCL_STOP 114 } break; 115 } 116 CeedCallBackend(CeedFree(&restriction_kernel_path)); 117 CeedCallBackend(CeedFree(&restriction_kernel_source)); 118 return CEED_ERROR_SUCCESS; 119 } 120 121 //------------------------------------------------------------------------------ 122 // Core apply restriction code 123 //------------------------------------------------------------------------------ 124 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients, 125 CeedVector u, CeedVector v, CeedRequest *request) { 126 Ceed ceed; 127 CeedRestrictionType rstr_type; 128 const CeedScalar *d_u; 129 CeedScalar *d_v; 130 CeedElemRestriction_Hip *impl; 131 132 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 133 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 134 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 135 136 // Assemble kernel if needed 137 if (!impl->module) { 138 CeedCallBackend(CeedElemRestrictionSetupCompile_Hip(rstr)); 139 } 140 141 // Get vectors 142 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 143 if (t_mode == CEED_TRANSPOSE) { 144 // Sum into for transpose mode, e-vec to l-vec 145 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 146 } else { 147 // Overwrite for notranspose mode, l-vec to e-vec 148 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 149 } 150 151 // Restrict 152 if (t_mode == CEED_NOTRANSPOSE) { 153 // L-vector -> E-vector 154 CeedInt elem_size; 155 156 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 157 const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 158 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 159 160 switch (rstr_type) { 161 case CEED_RESTRICTION_STRIDED: { 162 void *args[] = {&d_u, &d_v}; 163 164 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 165 } break; 166 case CEED_RESTRICTION_STANDARD: { 167 void *args[] = {&impl->d_ind, &d_u, &d_v}; 168 169 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 170 } break; 171 case CEED_RESTRICTION_ORIENTED: { 172 if (use_signs) { 173 void *args[] = {&impl->d_ind, &impl->d_orients, &d_u, &d_v}; 174 175 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 176 } else { 177 void *args[] = {&impl->d_ind, &d_u, &d_v}; 178 179 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 180 } 181 } break; 182 case CEED_RESTRICTION_CURL_ORIENTED: { 183 if (use_signs && use_orients) { 184 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 185 186 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 187 } else if (use_orients) { 188 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 189 190 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 191 } else { 192 void *args[] = {&impl->d_ind, &d_u, &d_v}; 193 194 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedNoTranspose, grid, block_size, args)); 195 } 196 } break; 197 case CEED_RESTRICTION_POINTS: { 198 // LCOV_EXCL_START 199 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 200 // LCOV_EXCL_STOP 201 } break; 202 } 203 } else { 204 // E-vector -> L-vector 205 const bool is_deterministic = impl->d_l_vec_indices != NULL; 206 const CeedInt block_size = 64; 207 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 208 209 switch (rstr_type) { 210 case CEED_RESTRICTION_STRIDED: { 211 void *args[] = {&d_u, &d_v}; 212 213 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 214 } break; 215 case CEED_RESTRICTION_STANDARD: { 216 if (!is_deterministic) { 217 void *args[] = {&impl->d_ind, &d_u, &d_v}; 218 219 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 220 } else { 221 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 222 223 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 224 } 225 } break; 226 case CEED_RESTRICTION_ORIENTED: { 227 if (use_signs) { 228 if (!is_deterministic) { 229 void *args[] = {&impl->d_ind, &impl->d_orients, &d_u, &d_v}; 230 231 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 232 } else { 233 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_orients, &d_u, &d_v}; 234 235 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 236 } 237 } else { 238 if (!is_deterministic) { 239 void *args[] = {&impl->d_ind, &d_u, &d_v}; 240 241 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 242 } else { 243 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 244 245 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 246 } 247 } 248 } break; 249 case CEED_RESTRICTION_CURL_ORIENTED: { 250 if (use_signs && use_orients) { 251 if (!is_deterministic) { 252 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 253 254 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 255 } else { 256 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 257 258 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 259 } 260 } else if (use_orients) { 261 if (!is_deterministic) { 262 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 263 264 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 265 } else { 266 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 267 268 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 269 } 270 } else { 271 if (!is_deterministic) { 272 void *args[] = {&impl->d_ind, &d_u, &d_v}; 273 274 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 275 } else { 276 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 277 278 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 279 } 280 } 281 } break; 282 case CEED_RESTRICTION_POINTS: { 283 // LCOV_EXCL_START 284 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 285 // LCOV_EXCL_STOP 286 } break; 287 } 288 } 289 290 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 291 292 // Restore arrays 293 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 294 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 295 return CEED_ERROR_SUCCESS; 296 } 297 298 //------------------------------------------------------------------------------ 299 // Apply restriction 300 //------------------------------------------------------------------------------ 301 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 302 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request); 303 } 304 305 //------------------------------------------------------------------------------ 306 // Apply unsigned restriction 307 //------------------------------------------------------------------------------ 308 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 309 CeedRequest *request) { 310 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request); 311 } 312 313 //------------------------------------------------------------------------------ 314 // Apply unoriented restriction 315 //------------------------------------------------------------------------------ 316 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 317 CeedRequest *request) { 318 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request); 319 } 320 321 //------------------------------------------------------------------------------ 322 // Get offsets 323 //------------------------------------------------------------------------------ 324 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 325 CeedElemRestriction_Hip *impl; 326 327 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 328 switch (mem_type) { 329 case CEED_MEM_HOST: 330 *offsets = impl->h_ind; 331 break; 332 case CEED_MEM_DEVICE: 333 *offsets = impl->d_ind; 334 break; 335 } 336 return CEED_ERROR_SUCCESS; 337 } 338 339 //------------------------------------------------------------------------------ 340 // Get orientations 341 //------------------------------------------------------------------------------ 342 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) { 343 CeedElemRestriction_Hip *impl; 344 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 345 346 switch (mem_type) { 347 case CEED_MEM_HOST: 348 *orients = impl->h_orients; 349 break; 350 case CEED_MEM_DEVICE: 351 *orients = impl->d_orients; 352 break; 353 } 354 return CEED_ERROR_SUCCESS; 355 } 356 357 //------------------------------------------------------------------------------ 358 // Get curl-conforming orientations 359 //------------------------------------------------------------------------------ 360 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) { 361 CeedElemRestriction_Hip *impl; 362 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 363 364 switch (mem_type) { 365 case CEED_MEM_HOST: 366 *curl_orients = impl->h_curl_orients; 367 break; 368 case CEED_MEM_DEVICE: 369 *curl_orients = impl->d_curl_orients; 370 break; 371 } 372 return CEED_ERROR_SUCCESS; 373 } 374 375 //------------------------------------------------------------------------------ 376 // Destroy restriction 377 //------------------------------------------------------------------------------ 378 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) { 379 Ceed ceed; 380 CeedElemRestriction_Hip *impl; 381 382 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 383 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 384 if (impl->module) { 385 CeedCallHip(ceed, hipModuleUnload(impl->module)); 386 } 387 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 388 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 389 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 390 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 391 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 392 CeedCallBackend(CeedFree(&impl->h_orients_allocated)); 393 CeedCallHip(ceed, hipFree(impl->d_orients_allocated)); 394 CeedCallBackend(CeedFree(&impl->h_curl_orients_allocated)); 395 CeedCallHip(ceed, hipFree(impl->d_curl_orients_allocated)); 396 CeedCallBackend(CeedFree(&impl)); 397 return CEED_ERROR_SUCCESS; 398 } 399 400 //------------------------------------------------------------------------------ 401 // Create transpose offsets and indices 402 //------------------------------------------------------------------------------ 403 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt *indices) { 404 Ceed ceed; 405 bool *is_node; 406 CeedSize l_size; 407 CeedInt num_elem, elem_size, num_comp, num_nodes = 0; 408 CeedInt *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 409 CeedElemRestriction_Hip *impl; 410 411 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 412 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 413 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 414 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 415 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 416 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 417 const CeedInt size_indices = num_elem * elem_size; 418 419 // Count num_nodes 420 CeedCallBackend(CeedCalloc(l_size, &is_node)); 421 422 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 423 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 424 impl->num_nodes = num_nodes; 425 426 // L-vector offsets array 427 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 428 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 429 for (CeedInt i = 0, j = 0; i < l_size; i++) { 430 if (is_node[i]) { 431 l_vec_indices[j] = i; 432 ind_to_offset[i] = j++; 433 } 434 } 435 CeedCallBackend(CeedFree(&is_node)); 436 437 // Compute transpose offsets and indices 438 const CeedInt size_offsets = num_nodes + 1; 439 440 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 441 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 442 // Count node multiplicity 443 for (CeedInt e = 0; e < num_elem; ++e) { 444 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 445 } 446 // Convert to running sum 447 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 448 // List all E-vec indices associated with L-vec node 449 for (CeedInt e = 0; e < num_elem; ++e) { 450 for (CeedInt i = 0; i < elem_size; ++i) { 451 const CeedInt lid = elem_size * e + i; 452 const CeedInt gid = indices[lid]; 453 454 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 455 } 456 } 457 // Reset running sum 458 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 459 t_offsets[0] = 0; 460 461 // Copy data to device 462 // -- L-vector indices 463 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 464 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 465 // -- Transpose offsets 466 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 467 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 468 // -- Transpose indices 469 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 470 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 471 472 // Cleanup 473 CeedCallBackend(CeedFree(&ind_to_offset)); 474 CeedCallBackend(CeedFree(&l_vec_indices)); 475 CeedCallBackend(CeedFree(&t_offsets)); 476 CeedCallBackend(CeedFree(&t_indices)); 477 return CEED_ERROR_SUCCESS; 478 } 479 480 //------------------------------------------------------------------------------ 481 // Create restriction 482 //------------------------------------------------------------------------------ 483 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 484 const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 485 Ceed ceed, ceed_parent; 486 bool is_deterministic; 487 CeedInt num_elem, elem_size; 488 CeedRestrictionType rstr_type; 489 CeedElemRestriction_Hip *impl; 490 491 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 492 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 493 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 494 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 495 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 496 const CeedInt size = num_elem * elem_size; 497 CeedInt layout[3] = {1, size, elem_size}; 498 499 CeedCallBackend(CeedCalloc(1, &impl)); 500 impl->num_nodes = size; 501 impl->h_ind = NULL; 502 impl->h_ind_allocated = NULL; 503 impl->d_ind = NULL; 504 impl->d_ind_allocated = NULL; 505 impl->d_t_indices = NULL; 506 impl->d_t_offsets = NULL; 507 impl->h_orients = NULL; 508 impl->h_orients_allocated = NULL; 509 impl->d_orients = NULL; 510 impl->d_orients_allocated = NULL; 511 impl->h_curl_orients = NULL; 512 impl->h_curl_orients_allocated = NULL; 513 impl->d_curl_orients = NULL; 514 impl->d_curl_orients_allocated = NULL; 515 CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 516 CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 517 518 // Set up device offset/orientation arrays 519 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 520 if (rstr_type != CEED_RESTRICTION_STRIDED) { 521 switch (mem_type) { 522 case CEED_MEM_HOST: { 523 switch (copy_mode) { 524 case CEED_OWN_POINTER: 525 impl->h_ind_allocated = (CeedInt *)indices; 526 impl->h_ind = (CeedInt *)indices; 527 break; 528 case CEED_USE_POINTER: 529 impl->h_ind = (CeedInt *)indices; 530 break; 531 case CEED_COPY_VALUES: 532 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 533 memcpy(impl->h_ind_allocated, indices, size * sizeof(CeedInt)); 534 impl->h_ind = impl->h_ind_allocated; 535 break; 536 } 537 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 538 impl->d_ind_allocated = impl->d_ind; // We own the device memory 539 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 540 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 541 } break; 542 case CEED_MEM_DEVICE: { 543 switch (copy_mode) { 544 case CEED_COPY_VALUES: 545 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 546 impl->d_ind_allocated = impl->d_ind; // We own the device memory 547 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 548 break; 549 case CEED_OWN_POINTER: 550 impl->d_ind = (CeedInt *)indices; 551 impl->d_ind_allocated = impl->d_ind; 552 break; 553 case CEED_USE_POINTER: 554 impl->d_ind = (CeedInt *)indices; 555 break; 556 } 557 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 558 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, size * sizeof(CeedInt), hipMemcpyDeviceToHost)); 559 impl->h_ind = impl->h_ind_allocated; 560 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 561 } break; 562 } 563 564 // Orientation data 565 if (rstr_type == CEED_RESTRICTION_ORIENTED) { 566 switch (mem_type) { 567 case CEED_MEM_HOST: { 568 switch (copy_mode) { 569 case CEED_OWN_POINTER: 570 impl->h_orients_allocated = (bool *)orients; 571 impl->h_orients = (bool *)orients; 572 break; 573 case CEED_USE_POINTER: 574 impl->h_orients = (bool *)orients; 575 break; 576 case CEED_COPY_VALUES: 577 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 578 memcpy(impl->h_orients_allocated, orients, size * sizeof(bool)); 579 impl->h_orients = impl->h_orients_allocated; 580 break; 581 } 582 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 583 impl->d_orients_allocated = impl->d_orients; // We own the device memory 584 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyHostToDevice)); 585 } break; 586 case CEED_MEM_DEVICE: { 587 switch (copy_mode) { 588 case CEED_COPY_VALUES: 589 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 590 impl->d_orients_allocated = impl->d_orients; // We own the device memory 591 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyDeviceToDevice)); 592 break; 593 case CEED_OWN_POINTER: 594 impl->d_orients = (bool *)orients; 595 impl->d_orients_allocated = impl->d_orients; 596 break; 597 case CEED_USE_POINTER: 598 impl->d_orients = (bool *)orients; 599 break; 600 } 601 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 602 CeedCallHip(ceed, hipMemcpy(impl->h_orients_allocated, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost)); 603 impl->h_orients = impl->h_orients_allocated; 604 } break; 605 } 606 } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) { 607 switch (mem_type) { 608 case CEED_MEM_HOST: { 609 switch (copy_mode) { 610 case CEED_OWN_POINTER: 611 impl->h_curl_orients_allocated = (CeedInt8 *)curl_orients; 612 impl->h_curl_orients = (CeedInt8 *)curl_orients; 613 break; 614 case CEED_USE_POINTER: 615 impl->h_curl_orients = (CeedInt8 *)curl_orients; 616 break; 617 case CEED_COPY_VALUES: 618 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 619 memcpy(impl->h_curl_orients_allocated, curl_orients, 3 * size * sizeof(CeedInt8)); 620 impl->h_curl_orients = impl->h_curl_orients_allocated; 621 break; 622 } 623 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 624 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 625 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice)); 626 } break; 627 case CEED_MEM_DEVICE: { 628 switch (copy_mode) { 629 case CEED_COPY_VALUES: 630 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 631 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 632 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToDevice)); 633 break; 634 case CEED_OWN_POINTER: 635 impl->d_curl_orients = (CeedInt8 *)curl_orients; 636 impl->d_curl_orients_allocated = impl->d_curl_orients; 637 break; 638 case CEED_USE_POINTER: 639 impl->d_curl_orients = (CeedInt8 *)curl_orients; 640 break; 641 } 642 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 643 CeedCallHip(ceed, hipMemcpy(impl->h_curl_orients_allocated, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost)); 644 impl->h_curl_orients = impl->h_curl_orients_allocated; 645 } break; 646 } 647 } 648 } 649 650 // Register backend functions 651 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip)); 652 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip)); 653 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip)); 654 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 655 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip)); 656 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip)); 657 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip)); 658 return CEED_ERROR_SUCCESS; 659 } 660 661 //------------------------------------------------------------------------------ 662