1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Core apply restriction code 22 //------------------------------------------------------------------------------ 23 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients, 24 CeedVector u, CeedVector v, CeedRequest *request) { 25 Ceed ceed; 26 CeedInt num_elem, elem_size; 27 CeedRestrictionType rstr_type; 28 const CeedScalar *d_u; 29 CeedScalar *d_v; 30 CeedElemRestriction_Hip *impl; 31 hipFunction_t kernel; 32 33 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 34 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 35 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 36 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 37 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 38 const CeedInt num_nodes = impl->num_nodes; 39 40 // Get vectors 41 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 42 if (t_mode == CEED_TRANSPOSE) { 43 // Sum into for transpose mode, e-vec to l-vec 44 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 45 } else { 46 // Overwrite for notranspose mode, l-vec to e-vec 47 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 48 } 49 50 // Restrict 51 if (t_mode == CEED_NOTRANSPOSE) { 52 // L-vector -> E-vector 53 const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 54 const CeedInt grid = CeedDivUpInt(num_nodes, block_size); 55 56 switch (rstr_type) { 57 case CEED_RESTRICTION_STRIDED: { 58 kernel = impl->StridedNoTranspose; 59 void *args[] = {&num_elem, &d_u, &d_v}; 60 61 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 62 } break; 63 case CEED_RESTRICTION_STANDARD: { 64 kernel = impl->OffsetNoTranspose; 65 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 66 67 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 68 } break; 69 case CEED_RESTRICTION_ORIENTED: { 70 if (use_signs) { 71 kernel = impl->OrientedNoTranspose; 72 void *args[] = {&num_elem, &impl->d_ind, &impl->d_orients, &d_u, &d_v}; 73 74 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 75 } else { 76 kernel = impl->OffsetNoTranspose; 77 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 78 79 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 80 } 81 } break; 82 case CEED_RESTRICTION_CURL_ORIENTED: { 83 if (use_signs && use_orients) { 84 kernel = impl->CurlOrientedNoTranspose; 85 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 86 87 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 88 } else if (use_orients) { 89 kernel = impl->CurlOrientedUnsignedNoTranspose; 90 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 91 92 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 93 } else { 94 kernel = impl->OffsetNoTranspose; 95 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 96 97 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 98 } 99 } break; 100 case CEED_RESTRICTION_POINTS: { 101 // LCOV_EXCL_START 102 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 103 // LCOV_EXCL_STOP 104 } break; 105 } 106 } else { 107 // E-vector -> L-vector 108 const CeedInt block_size = 64; 109 const CeedInt grid = CeedDivUpInt(num_nodes, block_size); 110 111 switch (rstr_type) { 112 case CEED_RESTRICTION_STRIDED: { 113 kernel = impl->StridedTranspose; 114 void *args[] = {&num_elem, &d_u, &d_v}; 115 116 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 117 } break; 118 case CEED_RESTRICTION_STANDARD: { 119 if (impl->OffsetTranspose) { 120 kernel = impl->OffsetTranspose; 121 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 122 123 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 124 } else { 125 kernel = impl->OffsetTransposeDet; 126 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 127 128 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 129 } 130 } break; 131 case CEED_RESTRICTION_ORIENTED: { 132 if (use_signs) { 133 kernel = impl->OrientedTranspose; 134 void *args[] = {&num_elem, &impl->d_ind, &impl->d_orients, &d_u, &d_v}; 135 136 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 137 } else { 138 if (impl->OffsetTranspose) { 139 kernel = impl->OffsetTranspose; 140 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 141 142 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 143 } else { 144 kernel = impl->OffsetTransposeDet; 145 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 146 147 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 148 } 149 } 150 } break; 151 case CEED_RESTRICTION_CURL_ORIENTED: { 152 if (use_signs && use_orients) { 153 kernel = impl->CurlOrientedTranspose; 154 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 155 156 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 157 } else if (use_orients) { 158 kernel = impl->CurlOrientedUnsignedTranspose; 159 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 160 161 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 162 } else { 163 if (impl->OffsetTranspose) { 164 kernel = impl->OffsetTranspose; 165 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 166 167 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 168 } else { 169 kernel = impl->OffsetTransposeDet; 170 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 171 172 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 173 } 174 } 175 } break; 176 case CEED_RESTRICTION_POINTS: { 177 // LCOV_EXCL_START 178 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 179 // LCOV_EXCL_STOP 180 } break; 181 } 182 } 183 184 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 185 186 // Restore arrays 187 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 188 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 189 return CEED_ERROR_SUCCESS; 190 } 191 192 //------------------------------------------------------------------------------ 193 // Apply restriction 194 //------------------------------------------------------------------------------ 195 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 196 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request); 197 } 198 199 //------------------------------------------------------------------------------ 200 // Apply unsigned restriction 201 //------------------------------------------------------------------------------ 202 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 203 CeedRequest *request) { 204 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request); 205 } 206 207 //------------------------------------------------------------------------------ 208 // Apply unoriented restriction 209 //------------------------------------------------------------------------------ 210 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 211 CeedRequest *request) { 212 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request); 213 } 214 215 //------------------------------------------------------------------------------ 216 // Get offsets 217 //------------------------------------------------------------------------------ 218 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 219 CeedElemRestriction_Hip *impl; 220 221 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 222 switch (mem_type) { 223 case CEED_MEM_HOST: 224 *offsets = impl->h_ind; 225 break; 226 case CEED_MEM_DEVICE: 227 *offsets = impl->d_ind; 228 break; 229 } 230 return CEED_ERROR_SUCCESS; 231 } 232 233 //------------------------------------------------------------------------------ 234 // Get orientations 235 //------------------------------------------------------------------------------ 236 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) { 237 CeedElemRestriction_Hip *impl; 238 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 239 240 switch (mem_type) { 241 case CEED_MEM_HOST: 242 *orients = impl->h_orients; 243 break; 244 case CEED_MEM_DEVICE: 245 *orients = impl->d_orients; 246 break; 247 } 248 return CEED_ERROR_SUCCESS; 249 } 250 251 //------------------------------------------------------------------------------ 252 // Get curl-conforming orientations 253 //------------------------------------------------------------------------------ 254 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) { 255 CeedElemRestriction_Hip *impl; 256 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 257 258 switch (mem_type) { 259 case CEED_MEM_HOST: 260 *curl_orients = impl->h_curl_orients; 261 break; 262 case CEED_MEM_DEVICE: 263 *curl_orients = impl->d_curl_orients; 264 break; 265 } 266 return CEED_ERROR_SUCCESS; 267 } 268 269 //------------------------------------------------------------------------------ 270 // Destroy restriction 271 //------------------------------------------------------------------------------ 272 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) { 273 Ceed ceed; 274 CeedElemRestriction_Hip *impl; 275 276 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 277 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 278 CeedCallHip(ceed, hipModuleUnload(impl->module)); 279 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 280 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 281 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 282 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 283 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 284 CeedCallBackend(CeedFree(&impl->h_orients_allocated)); 285 CeedCallHip(ceed, hipFree(impl->d_orients_allocated)); 286 CeedCallBackend(CeedFree(&impl->h_curl_orients_allocated)); 287 CeedCallHip(ceed, hipFree(impl->d_curl_orients_allocated)); 288 CeedCallBackend(CeedFree(&impl)); 289 return CEED_ERROR_SUCCESS; 290 } 291 292 //------------------------------------------------------------------------------ 293 // Create transpose offsets and indices 294 //------------------------------------------------------------------------------ 295 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt *indices) { 296 Ceed ceed; 297 bool *is_node; 298 CeedSize l_size; 299 CeedInt num_elem, elem_size, num_comp, num_nodes = 0; 300 CeedInt *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 301 CeedElemRestriction_Hip *impl; 302 303 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 304 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 305 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 306 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 307 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 308 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 309 const CeedInt size_indices = num_elem * elem_size; 310 311 // Count num_nodes 312 CeedCallBackend(CeedCalloc(l_size, &is_node)); 313 314 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 315 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 316 impl->num_nodes = num_nodes; 317 318 // L-vector offsets array 319 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 320 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 321 for (CeedInt i = 0, j = 0; i < l_size; i++) { 322 if (is_node[i]) { 323 l_vec_indices[j] = i; 324 ind_to_offset[i] = j++; 325 } 326 } 327 CeedCallBackend(CeedFree(&is_node)); 328 329 // Compute transpose offsets and indices 330 const CeedInt size_offsets = num_nodes + 1; 331 332 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 333 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 334 // Count node multiplicity 335 for (CeedInt e = 0; e < num_elem; ++e) { 336 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 337 } 338 // Convert to running sum 339 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 340 // List all E-vec indices associated with L-vec node 341 for (CeedInt e = 0; e < num_elem; ++e) { 342 for (CeedInt i = 0; i < elem_size; ++i) { 343 const CeedInt lid = elem_size * e + i; 344 const CeedInt gid = indices[lid]; 345 346 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 347 } 348 } 349 // Reset running sum 350 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 351 t_offsets[0] = 0; 352 353 // Copy data to device 354 // -- L-vector indices 355 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 356 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 357 // -- Transpose offsets 358 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 359 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 360 // -- Transpose indices 361 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 362 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 363 364 // Cleanup 365 CeedCallBackend(CeedFree(&ind_to_offset)); 366 CeedCallBackend(CeedFree(&l_vec_indices)); 367 CeedCallBackend(CeedFree(&t_offsets)); 368 CeedCallBackend(CeedFree(&t_indices)); 369 return CEED_ERROR_SUCCESS; 370 } 371 372 //------------------------------------------------------------------------------ 373 // Create restriction 374 //------------------------------------------------------------------------------ 375 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 376 const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 377 Ceed ceed, ceed_parent; 378 bool is_deterministic; 379 CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 380 CeedRestrictionType rstr_type; 381 char *restriction_kernel_path, *restriction_kernel_source; 382 CeedElemRestriction_Hip *impl; 383 384 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 385 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 386 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 387 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 388 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 389 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 390 const CeedInt size = num_elem * elem_size; 391 CeedInt strides[3] = {1, size, elem_size}; 392 CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 393 394 // Stride data 395 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 396 if (rstr_type == CEED_RESTRICTION_STRIDED) { 397 bool has_backend_strides; 398 399 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 400 if (!has_backend_strides) { 401 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides)); 402 } 403 } else { 404 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 405 } 406 407 CeedCallBackend(CeedCalloc(1, &impl)); 408 impl->num_nodes = size; 409 impl->h_ind = NULL; 410 impl->h_ind_allocated = NULL; 411 impl->d_ind = NULL; 412 impl->d_ind_allocated = NULL; 413 impl->d_t_indices = NULL; 414 impl->d_t_offsets = NULL; 415 impl->h_orients = NULL; 416 impl->h_orients_allocated = NULL; 417 impl->d_orients = NULL; 418 impl->d_orients_allocated = NULL; 419 impl->h_curl_orients = NULL; 420 impl->h_curl_orients_allocated = NULL; 421 impl->d_curl_orients = NULL; 422 impl->d_curl_orients_allocated = NULL; 423 CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 424 CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 425 426 // Set up device offset/orientation arrays 427 if (rstr_type != CEED_RESTRICTION_STRIDED) { 428 switch (mem_type) { 429 case CEED_MEM_HOST: { 430 switch (copy_mode) { 431 case CEED_OWN_POINTER: 432 impl->h_ind_allocated = (CeedInt *)indices; 433 impl->h_ind = (CeedInt *)indices; 434 break; 435 case CEED_USE_POINTER: 436 impl->h_ind = (CeedInt *)indices; 437 break; 438 case CEED_COPY_VALUES: 439 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 440 memcpy(impl->h_ind_allocated, indices, size * sizeof(CeedInt)); 441 impl->h_ind = impl->h_ind_allocated; 442 break; 443 } 444 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 445 impl->d_ind_allocated = impl->d_ind; // We own the device memory 446 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 447 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 448 } break; 449 case CEED_MEM_DEVICE: { 450 switch (copy_mode) { 451 case CEED_COPY_VALUES: 452 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 453 impl->d_ind_allocated = impl->d_ind; // We own the device memory 454 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 455 break; 456 case CEED_OWN_POINTER: 457 impl->d_ind = (CeedInt *)indices; 458 impl->d_ind_allocated = impl->d_ind; 459 break; 460 case CEED_USE_POINTER: 461 impl->d_ind = (CeedInt *)indices; 462 break; 463 } 464 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 465 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, size * sizeof(CeedInt), hipMemcpyDeviceToHost)); 466 impl->h_ind = impl->h_ind_allocated; 467 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 468 } break; 469 } 470 471 // Orientation data 472 if (rstr_type == CEED_RESTRICTION_ORIENTED) { 473 switch (mem_type) { 474 case CEED_MEM_HOST: { 475 switch (copy_mode) { 476 case CEED_OWN_POINTER: 477 impl->h_orients_allocated = (bool *)orients; 478 impl->h_orients = (bool *)orients; 479 break; 480 case CEED_USE_POINTER: 481 impl->h_orients = (bool *)orients; 482 break; 483 case CEED_COPY_VALUES: 484 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 485 memcpy(impl->h_orients_allocated, orients, size * sizeof(bool)); 486 impl->h_orients = impl->h_orients_allocated; 487 break; 488 } 489 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 490 impl->d_orients_allocated = impl->d_orients; // We own the device memory 491 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyHostToDevice)); 492 } break; 493 case CEED_MEM_DEVICE: { 494 switch (copy_mode) { 495 case CEED_COPY_VALUES: 496 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 497 impl->d_orients_allocated = impl->d_orients; // We own the device memory 498 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyDeviceToDevice)); 499 break; 500 case CEED_OWN_POINTER: 501 impl->d_orients = (bool *)orients; 502 impl->d_orients_allocated = impl->d_orients; 503 break; 504 case CEED_USE_POINTER: 505 impl->d_orients = (bool *)orients; 506 break; 507 } 508 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 509 CeedCallHip(ceed, hipMemcpy(impl->h_orients_allocated, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost)); 510 impl->h_orients = impl->h_orients_allocated; 511 } break; 512 } 513 } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) { 514 switch (mem_type) { 515 case CEED_MEM_HOST: { 516 switch (copy_mode) { 517 case CEED_OWN_POINTER: 518 impl->h_curl_orients_allocated = (CeedInt8 *)curl_orients; 519 impl->h_curl_orients = (CeedInt8 *)curl_orients; 520 break; 521 case CEED_USE_POINTER: 522 impl->h_curl_orients = (CeedInt8 *)curl_orients; 523 break; 524 case CEED_COPY_VALUES: 525 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 526 memcpy(impl->h_curl_orients_allocated, curl_orients, 3 * size * sizeof(CeedInt8)); 527 impl->h_curl_orients = impl->h_curl_orients_allocated; 528 break; 529 } 530 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 531 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 532 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice)); 533 } break; 534 case CEED_MEM_DEVICE: { 535 switch (copy_mode) { 536 case CEED_COPY_VALUES: 537 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 538 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 539 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToDevice)); 540 break; 541 case CEED_OWN_POINTER: 542 impl->d_curl_orients = (CeedInt8 *)curl_orients; 543 impl->d_curl_orients_allocated = impl->d_curl_orients; 544 break; 545 case CEED_USE_POINTER: 546 impl->d_curl_orients = (CeedInt8 *)curl_orients; 547 break; 548 } 549 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 550 CeedCallHip(ceed, hipMemcpy(impl->h_curl_orients_allocated, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost)); 551 impl->h_curl_orients = impl->h_curl_orients_allocated; 552 } break; 553 } 554 } 555 } 556 557 // Compile HIP kernels 558 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path)); 559 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 560 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 561 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 562 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 563 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, "RSTR_STRIDE_NODES", 564 strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", strides[2])); 565 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 566 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 567 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 568 if (!is_deterministic) { 569 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 570 } else { 571 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet)); 572 } 573 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->OrientedNoTranspose)); 574 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->OrientedTranspose)); 575 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->CurlOrientedNoTranspose)); 576 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->CurlOrientedUnsignedNoTranspose)); 577 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->CurlOrientedTranspose)); 578 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->CurlOrientedUnsignedTranspose)); 579 CeedCallBackend(CeedFree(&restriction_kernel_path)); 580 CeedCallBackend(CeedFree(&restriction_kernel_source)); 581 582 // Register backend functions 583 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip)); 584 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip)); 585 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip)); 586 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 587 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip)); 588 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip)); 589 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip)); 590 return CEED_ERROR_SUCCESS; 591 } 592 593 //------------------------------------------------------------------------------ 594