1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Core apply restriction code 22 //------------------------------------------------------------------------------ 23 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients, 24 CeedVector u, CeedVector v, CeedRequest *request) { 25 Ceed ceed; 26 CeedInt num_elem, elem_size; 27 CeedRestrictionType rstr_type; 28 const CeedScalar *d_u; 29 CeedScalar *d_v; 30 CeedElemRestriction_Hip *impl; 31 hipFunction_t kernel; 32 33 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 34 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 35 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 36 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 37 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 38 const CeedInt num_nodes = impl->num_nodes; 39 40 // Get vectors 41 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 42 if (t_mode == CEED_TRANSPOSE) { 43 // Sum into for transpose mode, e-vec to l-vec 44 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 45 } else { 46 // Overwrite for notranspose mode, l-vec to e-vec 47 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 48 } 49 50 // Restrict 51 if (t_mode == CEED_NOTRANSPOSE) { 52 // L-vector -> E-vector 53 const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 54 const CeedInt grid = CeedDivUpInt(num_nodes, block_size); 55 56 switch (rstr_type) { 57 case CEED_RESTRICTION_STRIDED: { 58 kernel = impl->StridedNoTranspose; 59 void *args[] = {&num_elem, &d_u, &d_v}; 60 61 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 62 } break; 63 case CEED_RESTRICTION_STANDARD: { 64 kernel = impl->OffsetNoTranspose; 65 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 66 67 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 68 } break; 69 case CEED_RESTRICTION_ORIENTED: { 70 if (use_signs) { 71 kernel = impl->OrientedNoTranspose; 72 void *args[] = {&num_elem, &impl->d_ind, &impl->d_orients, &d_u, &d_v}; 73 74 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 75 } else { 76 kernel = impl->OffsetNoTranspose; 77 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 78 79 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 80 } 81 } break; 82 case CEED_RESTRICTION_CURL_ORIENTED: { 83 if (use_signs && use_orients) { 84 kernel = impl->CurlOrientedNoTranspose; 85 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 86 87 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 88 } else if (use_orients) { 89 kernel = impl->CurlOrientedUnsignedNoTranspose; 90 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 91 92 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 93 } else { 94 kernel = impl->OffsetNoTranspose; 95 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 96 97 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 98 } 99 } break; 100 } 101 } else { 102 // E-vector -> L-vector 103 const CeedInt block_size = 64; 104 const CeedInt grid = CeedDivUpInt(num_nodes, block_size); 105 106 switch (rstr_type) { 107 case CEED_RESTRICTION_STRIDED: { 108 kernel = impl->StridedTranspose; 109 void *args[] = {&num_elem, &d_u, &d_v}; 110 111 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 112 } break; 113 case CEED_RESTRICTION_STANDARD: { 114 if (impl->OffsetTranspose) { 115 kernel = impl->OffsetTranspose; 116 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 117 118 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 119 } else { 120 kernel = impl->OffsetTransposeDet; 121 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 122 123 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 124 } 125 } break; 126 case CEED_RESTRICTION_ORIENTED: { 127 if (use_signs) { 128 kernel = impl->OrientedTranspose; 129 void *args[] = {&num_elem, &impl->d_ind, &impl->d_orients, &d_u, &d_v}; 130 131 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 132 } else { 133 if (impl->OffsetTranspose) { 134 kernel = impl->OffsetTranspose; 135 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 136 137 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 138 } else { 139 kernel = impl->OffsetTransposeDet; 140 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 141 142 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 143 } 144 } 145 } break; 146 case CEED_RESTRICTION_CURL_ORIENTED: { 147 if (use_signs && use_orients) { 148 kernel = impl->CurlOrientedTranspose; 149 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 150 151 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 152 } else if (use_orients) { 153 kernel = impl->CurlOrientedUnsignedTranspose; 154 void *args[] = {&num_elem, &impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 155 156 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 157 } else { 158 if (impl->OffsetTranspose) { 159 kernel = impl->OffsetTranspose; 160 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 161 162 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 163 } else { 164 kernel = impl->OffsetTransposeDet; 165 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 166 167 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, grid, block_size, args)); 168 } 169 } 170 } break; 171 } 172 } 173 174 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 175 176 // Restore arrays 177 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 178 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 179 return CEED_ERROR_SUCCESS; 180 } 181 182 //------------------------------------------------------------------------------ 183 // Apply restriction 184 //------------------------------------------------------------------------------ 185 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 186 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request); 187 } 188 189 //------------------------------------------------------------------------------ 190 // Apply unsigned restriction 191 //------------------------------------------------------------------------------ 192 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 193 CeedRequest *request) { 194 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request); 195 } 196 197 //------------------------------------------------------------------------------ 198 // Apply unoriented restriction 199 //------------------------------------------------------------------------------ 200 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 201 CeedRequest *request) { 202 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request); 203 } 204 205 //------------------------------------------------------------------------------ 206 // Get offsets 207 //------------------------------------------------------------------------------ 208 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 209 CeedElemRestriction_Hip *impl; 210 211 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 212 switch (mem_type) { 213 case CEED_MEM_HOST: 214 *offsets = impl->h_ind; 215 break; 216 case CEED_MEM_DEVICE: 217 *offsets = impl->d_ind; 218 break; 219 } 220 return CEED_ERROR_SUCCESS; 221 } 222 223 //------------------------------------------------------------------------------ 224 // Get orientations 225 //------------------------------------------------------------------------------ 226 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) { 227 CeedElemRestriction_Hip *impl; 228 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 229 230 switch (mem_type) { 231 case CEED_MEM_HOST: 232 *orients = impl->h_orients; 233 break; 234 case CEED_MEM_DEVICE: 235 *orients = impl->d_orients; 236 break; 237 } 238 return CEED_ERROR_SUCCESS; 239 } 240 241 //------------------------------------------------------------------------------ 242 // Get curl-conforming orientations 243 //------------------------------------------------------------------------------ 244 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) { 245 CeedElemRestriction_Hip *impl; 246 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 247 248 switch (mem_type) { 249 case CEED_MEM_HOST: 250 *curl_orients = impl->h_curl_orients; 251 break; 252 case CEED_MEM_DEVICE: 253 *curl_orients = impl->d_curl_orients; 254 break; 255 } 256 return CEED_ERROR_SUCCESS; 257 } 258 259 //------------------------------------------------------------------------------ 260 // Destroy restriction 261 //------------------------------------------------------------------------------ 262 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) { 263 Ceed ceed; 264 CeedElemRestriction_Hip *impl; 265 266 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 267 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 268 CeedCallHip(ceed, hipModuleUnload(impl->module)); 269 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 270 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 271 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 272 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 273 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 274 CeedCallBackend(CeedFree(&impl->h_orients_allocated)); 275 CeedCallHip(ceed, hipFree(impl->d_orients_allocated)); 276 CeedCallBackend(CeedFree(&impl->h_curl_orients_allocated)); 277 CeedCallHip(ceed, hipFree(impl->d_curl_orients_allocated)); 278 CeedCallBackend(CeedFree(&impl)); 279 return CEED_ERROR_SUCCESS; 280 } 281 282 //------------------------------------------------------------------------------ 283 // Create transpose offsets and indices 284 //------------------------------------------------------------------------------ 285 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt *indices) { 286 Ceed ceed; 287 bool *is_node; 288 CeedSize l_size; 289 CeedInt num_elem, elem_size, num_comp, num_nodes = 0; 290 CeedInt *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 291 CeedElemRestriction_Hip *impl; 292 293 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 294 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 295 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 296 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 297 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 298 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 299 const CeedInt size_indices = num_elem * elem_size; 300 301 // Count num_nodes 302 CeedCallBackend(CeedCalloc(l_size, &is_node)); 303 304 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 305 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 306 impl->num_nodes = num_nodes; 307 308 // L-vector offsets array 309 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 310 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 311 for (CeedInt i = 0, j = 0; i < l_size; i++) { 312 if (is_node[i]) { 313 l_vec_indices[j] = i; 314 ind_to_offset[i] = j++; 315 } 316 } 317 CeedCallBackend(CeedFree(&is_node)); 318 319 // Compute transpose offsets and indices 320 const CeedInt size_offsets = num_nodes + 1; 321 322 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 323 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 324 // Count node multiplicity 325 for (CeedInt e = 0; e < num_elem; ++e) { 326 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 327 } 328 // Convert to running sum 329 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 330 // List all E-vec indices associated with L-vec node 331 for (CeedInt e = 0; e < num_elem; ++e) { 332 for (CeedInt i = 0; i < elem_size; ++i) { 333 const CeedInt lid = elem_size * e + i; 334 const CeedInt gid = indices[lid]; 335 336 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 337 } 338 } 339 // Reset running sum 340 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 341 t_offsets[0] = 0; 342 343 // Copy data to device 344 // -- L-vector indices 345 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 346 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 347 // -- Transpose offsets 348 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 349 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 350 // -- Transpose indices 351 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 352 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 353 354 // Cleanup 355 CeedCallBackend(CeedFree(&ind_to_offset)); 356 CeedCallBackend(CeedFree(&l_vec_indices)); 357 CeedCallBackend(CeedFree(&t_offsets)); 358 CeedCallBackend(CeedFree(&t_indices)); 359 return CEED_ERROR_SUCCESS; 360 } 361 362 //------------------------------------------------------------------------------ 363 // Create restriction 364 //------------------------------------------------------------------------------ 365 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 366 const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 367 Ceed ceed, ceed_parent; 368 bool is_deterministic; 369 CeedInt num_elem, num_comp, elem_size, comp_stride = 1; 370 CeedRestrictionType rstr_type; 371 char *restriction_kernel_path, *restriction_kernel_source; 372 CeedElemRestriction_Hip *impl; 373 374 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 375 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 376 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 377 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 378 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 379 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 380 const CeedInt size = num_elem * elem_size; 381 CeedInt strides[3] = {1, size, elem_size}; 382 CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 383 384 // Stride data 385 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 386 if (rstr_type == CEED_RESTRICTION_STRIDED) { 387 bool has_backend_strides; 388 389 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 390 if (!has_backend_strides) { 391 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, &strides)); 392 } 393 } else { 394 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 395 } 396 397 CeedCallBackend(CeedCalloc(1, &impl)); 398 impl->num_nodes = size; 399 impl->h_ind = NULL; 400 impl->h_ind_allocated = NULL; 401 impl->d_ind = NULL; 402 impl->d_ind_allocated = NULL; 403 impl->d_t_indices = NULL; 404 impl->d_t_offsets = NULL; 405 impl->h_orients = NULL; 406 impl->h_orients_allocated = NULL; 407 impl->d_orients = NULL; 408 impl->d_orients_allocated = NULL; 409 impl->h_curl_orients = NULL; 410 impl->h_curl_orients_allocated = NULL; 411 impl->d_curl_orients = NULL; 412 impl->d_curl_orients_allocated = NULL; 413 CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 414 CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 415 416 // Set up device offset/orientation arrays 417 if (rstr_type != CEED_RESTRICTION_STRIDED) { 418 switch (mem_type) { 419 case CEED_MEM_HOST: { 420 switch (copy_mode) { 421 case CEED_OWN_POINTER: 422 impl->h_ind_allocated = (CeedInt *)indices; 423 impl->h_ind = (CeedInt *)indices; 424 break; 425 case CEED_USE_POINTER: 426 impl->h_ind = (CeedInt *)indices; 427 break; 428 case CEED_COPY_VALUES: 429 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 430 memcpy(impl->h_ind_allocated, indices, size * sizeof(CeedInt)); 431 impl->h_ind = impl->h_ind_allocated; 432 break; 433 } 434 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 435 impl->d_ind_allocated = impl->d_ind; // We own the device memory 436 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 437 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 438 } break; 439 case CEED_MEM_DEVICE: { 440 switch (copy_mode) { 441 case CEED_COPY_VALUES: 442 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 443 impl->d_ind_allocated = impl->d_ind; // We own the device memory 444 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 445 break; 446 case CEED_OWN_POINTER: 447 impl->d_ind = (CeedInt *)indices; 448 impl->d_ind_allocated = impl->d_ind; 449 break; 450 case CEED_USE_POINTER: 451 impl->d_ind = (CeedInt *)indices; 452 break; 453 } 454 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 455 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, size * sizeof(CeedInt), hipMemcpyDeviceToHost)); 456 impl->h_ind = impl->h_ind_allocated; 457 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 458 } break; 459 } 460 461 // Orientation data 462 if (rstr_type == CEED_RESTRICTION_ORIENTED) { 463 switch (mem_type) { 464 case CEED_MEM_HOST: { 465 switch (copy_mode) { 466 case CEED_OWN_POINTER: 467 impl->h_orients_allocated = (bool *)orients; 468 impl->h_orients = (bool *)orients; 469 break; 470 case CEED_USE_POINTER: 471 impl->h_orients = (bool *)orients; 472 break; 473 case CEED_COPY_VALUES: 474 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 475 memcpy(impl->h_orients_allocated, orients, size * sizeof(bool)); 476 impl->h_orients = impl->h_orients_allocated; 477 break; 478 } 479 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 480 impl->d_orients_allocated = impl->d_orients; // We own the device memory 481 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyHostToDevice)); 482 } break; 483 case CEED_MEM_DEVICE: { 484 switch (copy_mode) { 485 case CEED_COPY_VALUES: 486 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 487 impl->d_orients_allocated = impl->d_orients; // We own the device memory 488 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyDeviceToDevice)); 489 break; 490 case CEED_OWN_POINTER: 491 impl->d_orients = (bool *)orients; 492 impl->d_orients_allocated = impl->d_orients; 493 break; 494 case CEED_USE_POINTER: 495 impl->d_orients = (bool *)orients; 496 break; 497 } 498 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 499 CeedCallHip(ceed, hipMemcpy(impl->h_orients_allocated, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost)); 500 impl->h_orients = impl->h_orients_allocated; 501 } break; 502 } 503 } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) { 504 switch (mem_type) { 505 case CEED_MEM_HOST: { 506 switch (copy_mode) { 507 case CEED_OWN_POINTER: 508 impl->h_curl_orients_allocated = (CeedInt8 *)curl_orients; 509 impl->h_curl_orients = (CeedInt8 *)curl_orients; 510 break; 511 case CEED_USE_POINTER: 512 impl->h_curl_orients = (CeedInt8 *)curl_orients; 513 break; 514 case CEED_COPY_VALUES: 515 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 516 memcpy(impl->h_curl_orients_allocated, curl_orients, 3 * size * sizeof(CeedInt8)); 517 impl->h_curl_orients = impl->h_curl_orients_allocated; 518 break; 519 } 520 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 521 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 522 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice)); 523 } break; 524 case CEED_MEM_DEVICE: { 525 switch (copy_mode) { 526 case CEED_COPY_VALUES: 527 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 528 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 529 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToDevice)); 530 break; 531 case CEED_OWN_POINTER: 532 impl->d_curl_orients = (CeedInt8 *)curl_orients; 533 impl->d_curl_orients_allocated = impl->d_curl_orients; 534 break; 535 case CEED_USE_POINTER: 536 impl->d_curl_orients = (CeedInt8 *)curl_orients; 537 break; 538 } 539 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 540 CeedCallHip(ceed, hipMemcpy(impl->h_curl_orients_allocated, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost)); 541 impl->h_curl_orients = impl->h_curl_orients_allocated; 542 } break; 543 } 544 } 545 } 546 547 // Compile HIP kernels 548 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path)); 549 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 550 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 551 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 552 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 553 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, "RSTR_STRIDE_NODES", 554 strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", strides[2])); 555 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 556 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 557 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 558 if (!is_deterministic) { 559 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 560 } else { 561 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet)); 562 } 563 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->OrientedNoTranspose)); 564 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->OrientedTranspose)); 565 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->CurlOrientedNoTranspose)); 566 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->CurlOrientedUnsignedNoTranspose)); 567 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->CurlOrientedTranspose)); 568 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->CurlOrientedUnsignedTranspose)); 569 CeedCallBackend(CeedFree(&restriction_kernel_path)); 570 CeedCallBackend(CeedFree(&restriction_kernel_source)); 571 572 // Register backend functions 573 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip)); 574 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip)); 575 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip)); 576 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 577 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip)); 578 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip)); 579 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip)); 580 return CEED_ERROR_SUCCESS; 581 } 582 583 //------------------------------------------------------------------------------ 584