1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Compile restriction kernels 22 //------------------------------------------------------------------------------ 23 static inline int CeedElemRestrictionSetupCompile_Hip(CeedElemRestriction rstr) { 24 Ceed ceed; 25 bool is_deterministic; 26 CeedInt num_elem, num_comp, elem_size, comp_stride; 27 CeedRestrictionType rstr_type; 28 char *restriction_kernel_path, *restriction_kernel_source; 29 CeedElemRestriction_Hip *impl; 30 31 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 32 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 33 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 34 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 35 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 36 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 37 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 38 is_deterministic = impl->d_l_vec_indices != NULL; 39 40 // Compile HIP kernels 41 switch (rstr_type) { 42 case CEED_RESTRICTION_STRIDED: { 43 CeedInt strides[3] = {1, num_elem * elem_size, elem_size}; 44 bool has_backend_strides; 45 46 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 47 if (!has_backend_strides) { 48 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides)); 49 } 50 51 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-strided.h", &restriction_kernel_path)); 52 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 53 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 54 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 55 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 56 "RSTR_NUM_COMP", num_comp, "RSTR_STRIDE_NODES", strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", 57 strides[2])); 58 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->ApplyNoTranspose)); 59 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->ApplyTranspose)); 60 } break; 61 case CEED_RESTRICTION_STANDARD: { 62 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &restriction_kernel_path)); 63 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 64 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 65 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 66 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 67 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 68 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 69 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyNoTranspose)); 70 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyTranspose)); 71 } break; 72 case CEED_RESTRICTION_ORIENTED: { 73 char *offset_kernel_path; 74 75 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-oriented.h", &restriction_kernel_path)); 76 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 77 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 78 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path)); 79 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &restriction_kernel_source)); 80 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 81 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 82 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 83 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 84 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->ApplyNoTranspose)); 85 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnsignedNoTranspose)); 86 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->ApplyTranspose)); 87 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnsignedTranspose)); 88 CeedCallBackend(CeedFree(&offset_kernel_path)); 89 } break; 90 case CEED_RESTRICTION_CURL_ORIENTED: { 91 char *offset_kernel_path; 92 93 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-curl-oriented.h", &restriction_kernel_path)); 94 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 95 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 96 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction-offset.h", &offset_kernel_path)); 97 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, offset_kernel_path, &restriction_kernel_source)); 98 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 99 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 100 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 101 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 102 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->ApplyNoTranspose)); 103 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->ApplyUnsignedNoTranspose)); 104 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnorientedNoTranspose)); 105 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->ApplyTranspose)); 106 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->ApplyUnsignedTranspose)); 107 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnorientedTranspose)); 108 CeedCallBackend(CeedFree(&offset_kernel_path)); 109 } break; 110 case CEED_RESTRICTION_POINTS: { 111 // LCOV_EXCL_START 112 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 113 // LCOV_EXCL_STOP 114 } break; 115 } 116 CeedCallBackend(CeedFree(&restriction_kernel_path)); 117 CeedCallBackend(CeedFree(&restriction_kernel_source)); 118 return CEED_ERROR_SUCCESS; 119 } 120 121 //------------------------------------------------------------------------------ 122 // Core apply restriction code 123 //------------------------------------------------------------------------------ 124 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients, 125 CeedVector u, CeedVector v, CeedRequest *request) { 126 Ceed ceed; 127 CeedRestrictionType rstr_type; 128 const CeedScalar *d_u; 129 CeedScalar *d_v; 130 CeedElemRestriction_Hip *impl; 131 132 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 133 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 134 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 135 136 // Assemble kernel if needed 137 if (!impl->module) { 138 CeedCallBackend(CeedElemRestrictionSetupCompile_Hip(rstr)); 139 } 140 141 // Get vectors 142 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 143 if (t_mode == CEED_TRANSPOSE) { 144 // Sum into for transpose mode, e-vec to l-vec 145 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 146 } else { 147 // Overwrite for notranspose mode, l-vec to e-vec 148 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 149 } 150 151 // Restrict 152 if (t_mode == CEED_NOTRANSPOSE) { 153 // L-vector -> E-vector 154 CeedInt elem_size; 155 156 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 157 const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 158 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 159 160 switch (rstr_type) { 161 case CEED_RESTRICTION_STRIDED: { 162 void *args[] = {&d_u, &d_v}; 163 164 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 165 } break; 166 case CEED_RESTRICTION_STANDARD: { 167 void *args[] = {&impl->d_ind, &d_u, &d_v}; 168 169 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 170 } break; 171 case CEED_RESTRICTION_ORIENTED: { 172 if (use_signs) { 173 void *args[] = {&impl->d_ind, &impl->d_orients, &d_u, &d_v}; 174 175 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 176 } else { 177 void *args[] = {&impl->d_ind, &d_u, &d_v}; 178 179 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 180 } 181 } break; 182 case CEED_RESTRICTION_CURL_ORIENTED: { 183 if (use_signs && use_orients) { 184 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 185 186 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 187 } else if (use_orients) { 188 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 189 190 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 191 } else { 192 void *args[] = {&impl->d_ind, &d_u, &d_v}; 193 194 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedNoTranspose, grid, block_size, args)); 195 } 196 } break; 197 case CEED_RESTRICTION_POINTS: { 198 // LCOV_EXCL_START 199 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 200 // LCOV_EXCL_STOP 201 } break; 202 } 203 } else { 204 // E-vector -> L-vector 205 const bool is_deterministic = impl->d_l_vec_indices != NULL; 206 const CeedInt block_size = 64; 207 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 208 209 switch (rstr_type) { 210 case CEED_RESTRICTION_STRIDED: { 211 void *args[] = {&d_u, &d_v}; 212 213 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 214 } break; 215 case CEED_RESTRICTION_STANDARD: { 216 if (!is_deterministic) { 217 void *args[] = {&impl->d_ind, &d_u, &d_v}; 218 219 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 220 } else { 221 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 222 223 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 224 } 225 } break; 226 case CEED_RESTRICTION_ORIENTED: { 227 if (use_signs) { 228 if (!is_deterministic) { 229 void *args[] = {&impl->d_ind, &impl->d_orients, &d_u, &d_v}; 230 231 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 232 } else { 233 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_orients, &d_u, &d_v}; 234 235 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 236 } 237 } else { 238 if (!is_deterministic) { 239 void *args[] = {&impl->d_ind, &d_u, &d_v}; 240 241 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 242 } else { 243 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 244 245 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 246 } 247 } 248 } break; 249 case CEED_RESTRICTION_CURL_ORIENTED: { 250 if (use_signs && use_orients) { 251 if (!is_deterministic) { 252 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 253 254 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 255 } else { 256 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 257 258 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 259 } 260 } else if (use_orients) { 261 if (!is_deterministic) { 262 void *args[] = {&impl->d_ind, &impl->d_curl_orients, &d_u, &d_v}; 263 264 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 265 } else { 266 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 267 268 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 269 } 270 } else { 271 if (!is_deterministic) { 272 void *args[] = {&impl->d_ind, &d_u, &d_v}; 273 274 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 275 } else { 276 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 277 278 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 279 } 280 } 281 } break; 282 case CEED_RESTRICTION_POINTS: { 283 // LCOV_EXCL_START 284 return CeedError(ceed, CEED_ERROR_UNSUPPORTED, "Backend does not implement restriction CeedElemRestrictionAtPoints"); 285 // LCOV_EXCL_STOP 286 } break; 287 } 288 } 289 290 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 291 292 // Restore arrays 293 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 294 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 295 return CEED_ERROR_SUCCESS; 296 } 297 298 //------------------------------------------------------------------------------ 299 // Apply restriction 300 //------------------------------------------------------------------------------ 301 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 302 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request); 303 } 304 305 //------------------------------------------------------------------------------ 306 // Apply unsigned restriction 307 //------------------------------------------------------------------------------ 308 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 309 CeedRequest *request) { 310 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request); 311 } 312 313 //------------------------------------------------------------------------------ 314 // Apply unoriented restriction 315 //------------------------------------------------------------------------------ 316 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 317 CeedRequest *request) { 318 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request); 319 } 320 321 //------------------------------------------------------------------------------ 322 // Get offsets 323 //------------------------------------------------------------------------------ 324 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 325 CeedElemRestriction_Hip *impl; 326 327 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 328 switch (mem_type) { 329 case CEED_MEM_HOST: 330 *offsets = impl->h_ind; 331 break; 332 case CEED_MEM_DEVICE: 333 *offsets = impl->d_ind; 334 break; 335 } 336 return CEED_ERROR_SUCCESS; 337 } 338 339 //------------------------------------------------------------------------------ 340 // Get orientations 341 //------------------------------------------------------------------------------ 342 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) { 343 CeedElemRestriction_Hip *impl; 344 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 345 346 switch (mem_type) { 347 case CEED_MEM_HOST: 348 *orients = impl->h_orients; 349 break; 350 case CEED_MEM_DEVICE: 351 *orients = impl->d_orients; 352 break; 353 } 354 return CEED_ERROR_SUCCESS; 355 } 356 357 //------------------------------------------------------------------------------ 358 // Get curl-conforming orientations 359 //------------------------------------------------------------------------------ 360 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) { 361 CeedElemRestriction_Hip *impl; 362 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 363 364 switch (mem_type) { 365 case CEED_MEM_HOST: 366 *curl_orients = impl->h_curl_orients; 367 break; 368 case CEED_MEM_DEVICE: 369 *curl_orients = impl->d_curl_orients; 370 break; 371 } 372 return CEED_ERROR_SUCCESS; 373 } 374 375 //------------------------------------------------------------------------------ 376 // Destroy restriction 377 //------------------------------------------------------------------------------ 378 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) { 379 Ceed ceed; 380 CeedElemRestriction_Hip *impl; 381 382 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 383 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 384 if (impl->module) { 385 CeedCallHip(ceed, hipModuleUnload(impl->module)); 386 } 387 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 388 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 389 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 390 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 391 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 392 CeedCallBackend(CeedFree(&impl->h_orients_allocated)); 393 CeedCallHip(ceed, hipFree(impl->d_orients_allocated)); 394 CeedCallBackend(CeedFree(&impl->h_curl_orients_allocated)); 395 CeedCallHip(ceed, hipFree(impl->d_curl_orients_allocated)); 396 CeedCallBackend(CeedFree(&impl)); 397 return CEED_ERROR_SUCCESS; 398 } 399 400 //------------------------------------------------------------------------------ 401 // Create transpose offsets and indices 402 //------------------------------------------------------------------------------ 403 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt *indices) { 404 Ceed ceed; 405 bool *is_node; 406 CeedSize l_size; 407 CeedInt num_elem, elem_size, num_comp, num_nodes = 0; 408 CeedInt *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 409 CeedElemRestriction_Hip *impl; 410 411 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 412 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 413 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 414 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 415 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 416 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 417 const CeedInt size_indices = num_elem * elem_size; 418 419 // Count num_nodes 420 CeedCallBackend(CeedCalloc(l_size, &is_node)); 421 422 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 423 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 424 impl->num_nodes = num_nodes; 425 426 // L-vector offsets array 427 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 428 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 429 for (CeedInt i = 0, j = 0; i < l_size; i++) { 430 if (is_node[i]) { 431 l_vec_indices[j] = i; 432 ind_to_offset[i] = j++; 433 } 434 } 435 CeedCallBackend(CeedFree(&is_node)); 436 437 // Compute transpose offsets and indices 438 const CeedInt size_offsets = num_nodes + 1; 439 440 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 441 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 442 // Count node multiplicity 443 for (CeedInt e = 0; e < num_elem; ++e) { 444 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 445 } 446 // Convert to running sum 447 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 448 // List all E-vec indices associated with L-vec node 449 for (CeedInt e = 0; e < num_elem; ++e) { 450 for (CeedInt i = 0; i < elem_size; ++i) { 451 const CeedInt lid = elem_size * e + i; 452 const CeedInt gid = indices[lid]; 453 454 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 455 } 456 } 457 // Reset running sum 458 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 459 t_offsets[0] = 0; 460 461 // Copy data to device 462 // -- L-vector indices 463 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 464 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 465 // -- Transpose offsets 466 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 467 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 468 // -- Transpose indices 469 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 470 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 471 472 // Cleanup 473 CeedCallBackend(CeedFree(&ind_to_offset)); 474 CeedCallBackend(CeedFree(&l_vec_indices)); 475 CeedCallBackend(CeedFree(&t_offsets)); 476 CeedCallBackend(CeedFree(&t_indices)); 477 return CEED_ERROR_SUCCESS; 478 } 479 480 //------------------------------------------------------------------------------ 481 // Create restriction 482 //------------------------------------------------------------------------------ 483 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 484 const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 485 Ceed ceed, ceed_parent; 486 bool is_deterministic; 487 CeedInt num_elem, elem_size; 488 CeedRestrictionType rstr_type; 489 CeedElemRestriction_Hip *impl; 490 491 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 492 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 493 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 494 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 495 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 496 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 497 const CeedInt size = num_elem * elem_size; 498 499 CeedCallBackend(CeedCalloc(1, &impl)); 500 impl->num_nodes = size; 501 impl->h_ind = NULL; 502 impl->h_ind_allocated = NULL; 503 impl->d_ind = NULL; 504 impl->d_ind_allocated = NULL; 505 impl->d_t_indices = NULL; 506 impl->d_t_offsets = NULL; 507 impl->h_orients = NULL; 508 impl->h_orients_allocated = NULL; 509 impl->d_orients = NULL; 510 impl->d_orients_allocated = NULL; 511 impl->h_curl_orients = NULL; 512 impl->h_curl_orients_allocated = NULL; 513 impl->d_curl_orients = NULL; 514 impl->d_curl_orients_allocated = NULL; 515 CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 516 517 // Set layouts 518 { 519 bool has_backend_strides; 520 CeedInt layout[3] = {1, size, elem_size}; 521 522 CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 523 if (rstr_type == CEED_RESTRICTION_STRIDED) { 524 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 525 if (has_backend_strides) { 526 CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout)); 527 } 528 } 529 } 530 531 // Set up device offset/orientation arrays 532 if (rstr_type != CEED_RESTRICTION_STRIDED) { 533 switch (mem_type) { 534 case CEED_MEM_HOST: { 535 switch (copy_mode) { 536 case CEED_OWN_POINTER: 537 impl->h_ind_allocated = (CeedInt *)indices; 538 impl->h_ind = (CeedInt *)indices; 539 break; 540 case CEED_USE_POINTER: 541 impl->h_ind = (CeedInt *)indices; 542 break; 543 case CEED_COPY_VALUES: 544 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 545 memcpy(impl->h_ind_allocated, indices, size * sizeof(CeedInt)); 546 impl->h_ind = impl->h_ind_allocated; 547 break; 548 } 549 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 550 impl->d_ind_allocated = impl->d_ind; // We own the device memory 551 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 552 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 553 } break; 554 case CEED_MEM_DEVICE: { 555 switch (copy_mode) { 556 case CEED_COPY_VALUES: 557 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 558 impl->d_ind_allocated = impl->d_ind; // We own the device memory 559 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 560 break; 561 case CEED_OWN_POINTER: 562 impl->d_ind = (CeedInt *)indices; 563 impl->d_ind_allocated = impl->d_ind; 564 break; 565 case CEED_USE_POINTER: 566 impl->d_ind = (CeedInt *)indices; 567 break; 568 } 569 CeedCallBackend(CeedMalloc(size, &impl->h_ind_allocated)); 570 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, size * sizeof(CeedInt), hipMemcpyDeviceToHost)); 571 impl->h_ind = impl->h_ind_allocated; 572 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, indices)); 573 } break; 574 } 575 576 // Orientation data 577 if (rstr_type == CEED_RESTRICTION_ORIENTED) { 578 switch (mem_type) { 579 case CEED_MEM_HOST: { 580 switch (copy_mode) { 581 case CEED_OWN_POINTER: 582 impl->h_orients_allocated = (bool *)orients; 583 impl->h_orients = (bool *)orients; 584 break; 585 case CEED_USE_POINTER: 586 impl->h_orients = (bool *)orients; 587 break; 588 case CEED_COPY_VALUES: 589 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 590 memcpy(impl->h_orients_allocated, orients, size * sizeof(bool)); 591 impl->h_orients = impl->h_orients_allocated; 592 break; 593 } 594 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 595 impl->d_orients_allocated = impl->d_orients; // We own the device memory 596 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyHostToDevice)); 597 } break; 598 case CEED_MEM_DEVICE: { 599 switch (copy_mode) { 600 case CEED_COPY_VALUES: 601 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients, size * sizeof(bool))); 602 impl->d_orients_allocated = impl->d_orients; // We own the device memory 603 CeedCallHip(ceed, hipMemcpy(impl->d_orients, orients, size * sizeof(bool), hipMemcpyDeviceToDevice)); 604 break; 605 case CEED_OWN_POINTER: 606 impl->d_orients = (bool *)orients; 607 impl->d_orients_allocated = impl->d_orients; 608 break; 609 case CEED_USE_POINTER: 610 impl->d_orients = (bool *)orients; 611 break; 612 } 613 CeedCallBackend(CeedMalloc(size, &impl->h_orients_allocated)); 614 CeedCallHip(ceed, hipMemcpy(impl->h_orients_allocated, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost)); 615 impl->h_orients = impl->h_orients_allocated; 616 } break; 617 } 618 } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) { 619 switch (mem_type) { 620 case CEED_MEM_HOST: { 621 switch (copy_mode) { 622 case CEED_OWN_POINTER: 623 impl->h_curl_orients_allocated = (CeedInt8 *)curl_orients; 624 impl->h_curl_orients = (CeedInt8 *)curl_orients; 625 break; 626 case CEED_USE_POINTER: 627 impl->h_curl_orients = (CeedInt8 *)curl_orients; 628 break; 629 case CEED_COPY_VALUES: 630 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 631 memcpy(impl->h_curl_orients_allocated, curl_orients, 3 * size * sizeof(CeedInt8)); 632 impl->h_curl_orients = impl->h_curl_orients_allocated; 633 break; 634 } 635 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 636 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 637 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice)); 638 } break; 639 case CEED_MEM_DEVICE: { 640 switch (copy_mode) { 641 case CEED_COPY_VALUES: 642 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients, 3 * size * sizeof(CeedInt8))); 643 impl->d_curl_orients_allocated = impl->d_curl_orients; // We own the device memory 644 CeedCallHip(ceed, hipMemcpy(impl->d_curl_orients, curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToDevice)); 645 break; 646 case CEED_OWN_POINTER: 647 impl->d_curl_orients = (CeedInt8 *)curl_orients; 648 impl->d_curl_orients_allocated = impl->d_curl_orients; 649 break; 650 case CEED_USE_POINTER: 651 impl->d_curl_orients = (CeedInt8 *)curl_orients; 652 break; 653 } 654 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_allocated)); 655 CeedCallHip(ceed, hipMemcpy(impl->h_curl_orients_allocated, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost)); 656 impl->h_curl_orients = impl->h_curl_orients_allocated; 657 } break; 658 } 659 } 660 } 661 662 // Register backend functions 663 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip)); 664 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip)); 665 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip)); 666 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 667 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip)); 668 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip)); 669 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip)); 670 return CEED_ERROR_SUCCESS; 671 } 672 673 //------------------------------------------------------------------------------ 674