1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Compile restriction kernels 22 //------------------------------------------------------------------------------ 23 static inline int CeedElemRestrictionSetupCompile_Hip(CeedElemRestriction rstr) { 24 Ceed ceed; 25 bool is_deterministic; 26 CeedInt num_elem, num_comp, elem_size, comp_stride; 27 CeedRestrictionType rstr_type; 28 CeedElemRestriction_Hip *impl; 29 30 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 31 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 32 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 33 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 34 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 35 CeedCallBackend(CeedElemRestrictionGetCompStride(rstr, &comp_stride)); 36 if (rstr_type == CEED_RESTRICTION_POINTS) { 37 CeedCallBackend(CeedElemRestrictionGetMaxPointsInElement(rstr, &elem_size)); 38 } else { 39 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 40 } 41 is_deterministic = impl->d_l_vec_indices != NULL; 42 43 // Compile HIP kernels 44 switch (rstr_type) { 45 case CEED_RESTRICTION_STRIDED: { 46 const char restriction_kernel_source[] = "// Strided restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-strided.h>\n"; 47 bool has_backend_strides; 48 CeedInt strides[3] = {1, num_elem * elem_size, elem_size}; 49 50 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 51 if (!has_backend_strides) { 52 CeedCallBackend(CeedElemRestrictionGetStrides(rstr, strides)); 53 } 54 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 55 "RSTR_NUM_COMP", num_comp, "RSTR_STRIDE_NODES", strides[0], "RSTR_STRIDE_COMP", strides[1], "RSTR_STRIDE_ELEM", 56 strides[2])); 57 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->ApplyNoTranspose)); 58 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->ApplyTranspose)); 59 } break; 60 case CEED_RESTRICTION_STANDARD: { 61 const char restriction_kernel_source[] = "// Standard restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-offset.h>\n"; 62 63 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 64 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 65 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 66 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyNoTranspose)); 67 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyTranspose)); 68 } break; 69 case CEED_RESTRICTION_POINTS: { 70 const char restriction_kernel_source[] = 71 "// AtPoints restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-at-points.h>\n\n" 72 "// Standard restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-offset.h>\n"; 73 74 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 75 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 76 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 77 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyNoTranspose)); 78 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "AtPointsTranspose", &impl->ApplyTranspose)); 79 } break; 80 case CEED_RESTRICTION_ORIENTED: { 81 const char restriction_kernel_source[] = 82 "// Oriented restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-oriented.h>\n\n" 83 "// Standard restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-offset.h>\n"; 84 85 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 86 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 87 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 88 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedNoTranspose", &impl->ApplyNoTranspose)); 89 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnsignedNoTranspose)); 90 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OrientedTranspose", &impl->ApplyTranspose)); 91 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnsignedTranspose)); 92 } break; 93 case CEED_RESTRICTION_CURL_ORIENTED: { 94 const char restriction_kernel_source[] = 95 "// Curl oriented restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-curl-oriented.h>\n\n" 96 "// Standard restriction source\n#include <ceed/jit-source/hip/hip-ref-restriction-offset.h>\n"; 97 98 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 6, "RSTR_ELEM_SIZE", elem_size, "RSTR_NUM_ELEM", num_elem, 99 "RSTR_NUM_COMP", num_comp, "RSTR_NUM_NODES", impl->num_nodes, "RSTR_COMP_STRIDE", comp_stride, 100 "USE_DETERMINISTIC", is_deterministic ? 1 : 0)); 101 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedNoTranspose", &impl->ApplyNoTranspose)); 102 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedNoTranspose", &impl->ApplyUnsignedNoTranspose)); 103 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->ApplyUnorientedNoTranspose)); 104 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedTranspose", &impl->ApplyTranspose)); 105 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "CurlOrientedUnsignedTranspose", &impl->ApplyUnsignedTranspose)); 106 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->ApplyUnorientedTranspose)); 107 108 } break; 109 } 110 return CEED_ERROR_SUCCESS; 111 } 112 113 //------------------------------------------------------------------------------ 114 // Core apply restriction code 115 //------------------------------------------------------------------------------ 116 static inline int CeedElemRestrictionApply_Hip_Core(CeedElemRestriction rstr, CeedTransposeMode t_mode, bool use_signs, bool use_orients, 117 CeedVector u, CeedVector v, CeedRequest *request) { 118 Ceed ceed; 119 CeedRestrictionType rstr_type; 120 const CeedScalar *d_u; 121 CeedScalar *d_v; 122 CeedElemRestriction_Hip *impl; 123 124 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 125 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 126 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 127 128 // Assemble kernel if needed 129 if (!impl->module) { 130 CeedCallBackend(CeedElemRestrictionSetupCompile_Hip(rstr)); 131 } 132 133 // Get vectors 134 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 135 if (t_mode == CEED_TRANSPOSE) { 136 // Sum into for transpose mode, e-vec to l-vec 137 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 138 } else { 139 // Overwrite for notranspose mode, l-vec to e-vec 140 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 141 } 142 143 // Restrict 144 if (t_mode == CEED_NOTRANSPOSE) { 145 // L-vector -> E-vector 146 CeedInt elem_size; 147 148 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 149 const CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 150 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 151 152 switch (rstr_type) { 153 case CEED_RESTRICTION_STRIDED: { 154 void *args[] = {&d_u, &d_v}; 155 156 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 157 } break; 158 case CEED_RESTRICTION_POINTS: 159 case CEED_RESTRICTION_STANDARD: { 160 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 161 162 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 163 } break; 164 case CEED_RESTRICTION_ORIENTED: { 165 if (use_signs) { 166 void *args[] = {&impl->d_offsets, &impl->d_orients, &d_u, &d_v}; 167 168 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 169 } else { 170 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 171 172 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 173 } 174 } break; 175 case CEED_RESTRICTION_CURL_ORIENTED: { 176 if (use_signs && use_orients) { 177 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 178 179 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyNoTranspose, grid, block_size, args)); 180 } else if (use_orients) { 181 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 182 183 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedNoTranspose, grid, block_size, args)); 184 } else { 185 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 186 187 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedNoTranspose, grid, block_size, args)); 188 } 189 } break; 190 } 191 } else { 192 // E-vector -> L-vector 193 const bool is_deterministic = impl->d_l_vec_indices != NULL; 194 const CeedInt block_size = 64; 195 const CeedInt grid = CeedDivUpInt(impl->num_nodes, block_size); 196 197 switch (rstr_type) { 198 case CEED_RESTRICTION_STRIDED: { 199 void *args[] = {&d_u, &d_v}; 200 201 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 202 } break; 203 case CEED_RESTRICTION_POINTS: { 204 if (!is_deterministic) { 205 void *args[] = {&impl->d_offsets, &impl->d_points_per_elem, &d_u, &d_v}; 206 207 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 208 } else { 209 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_points_per_elem, &impl->d_t_offsets, &d_u, &d_v}; 210 211 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 212 } 213 } break; 214 case CEED_RESTRICTION_STANDARD: { 215 if (!is_deterministic) { 216 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 217 218 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 219 } else { 220 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 221 222 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 223 } 224 } break; 225 case CEED_RESTRICTION_ORIENTED: { 226 if (use_signs) { 227 if (!is_deterministic) { 228 void *args[] = {&impl->d_offsets, &impl->d_orients, &d_u, &d_v}; 229 230 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 231 } else { 232 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_orients, &d_u, &d_v}; 233 234 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 235 } 236 } else { 237 if (!is_deterministic) { 238 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 239 240 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 241 } else { 242 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 243 244 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 245 } 246 } 247 } break; 248 case CEED_RESTRICTION_CURL_ORIENTED: { 249 if (use_signs && use_orients) { 250 if (!is_deterministic) { 251 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 252 253 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 254 } else { 255 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 256 257 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyTranspose, grid, block_size, args)); 258 } 259 } else if (use_orients) { 260 if (!is_deterministic) { 261 void *args[] = {&impl->d_offsets, &impl->d_curl_orients, &d_u, &d_v}; 262 263 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 264 } else { 265 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &impl->d_curl_orients, &d_u, &d_v}; 266 267 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnsignedTranspose, grid, block_size, args)); 268 } 269 } else { 270 if (!is_deterministic) { 271 void *args[] = {&impl->d_offsets, &d_u, &d_v}; 272 273 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 274 } else { 275 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 276 277 CeedCallBackend(CeedRunKernel_Hip(ceed, impl->ApplyUnorientedTranspose, grid, block_size, args)); 278 } 279 } 280 } break; 281 } 282 } 283 284 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 285 286 // Restore arrays 287 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 288 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 289 return CEED_ERROR_SUCCESS; 290 } 291 292 //------------------------------------------------------------------------------ 293 // Apply restriction 294 //------------------------------------------------------------------------------ 295 static int CeedElemRestrictionApply_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 296 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, true, true, u, v, request); 297 } 298 299 //------------------------------------------------------------------------------ 300 // Apply unsigned restriction 301 //------------------------------------------------------------------------------ 302 static int CeedElemRestrictionApplyUnsigned_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 303 CeedRequest *request) { 304 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, true, u, v, request); 305 } 306 307 //------------------------------------------------------------------------------ 308 // Apply unoriented restriction 309 //------------------------------------------------------------------------------ 310 static int CeedElemRestrictionApplyUnoriented_Hip(CeedElemRestriction rstr, CeedTransposeMode t_mode, CeedVector u, CeedVector v, 311 CeedRequest *request) { 312 return CeedElemRestrictionApply_Hip_Core(rstr, t_mode, false, false, u, v, request); 313 } 314 315 //------------------------------------------------------------------------------ 316 // Get offsets 317 //------------------------------------------------------------------------------ 318 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 319 CeedElemRestriction_Hip *impl; 320 CeedRestrictionType rstr_type; 321 322 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 323 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 324 switch (mem_type) { 325 case CEED_MEM_HOST: 326 *offsets = rstr_type == CEED_RESTRICTION_POINTS ? impl->h_offsets_at_points : impl->h_offsets; 327 break; 328 case CEED_MEM_DEVICE: 329 *offsets = rstr_type == CEED_RESTRICTION_POINTS ? impl->d_offsets_at_points : impl->d_offsets; 330 break; 331 } 332 return CEED_ERROR_SUCCESS; 333 } 334 335 //------------------------------------------------------------------------------ 336 // Get orientations 337 //------------------------------------------------------------------------------ 338 static int CeedElemRestrictionGetOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const bool **orients) { 339 CeedElemRestriction_Hip *impl; 340 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 341 342 switch (mem_type) { 343 case CEED_MEM_HOST: 344 *orients = impl->h_orients; 345 break; 346 case CEED_MEM_DEVICE: 347 *orients = impl->d_orients; 348 break; 349 } 350 return CEED_ERROR_SUCCESS; 351 } 352 353 //------------------------------------------------------------------------------ 354 // Get curl-conforming orientations 355 //------------------------------------------------------------------------------ 356 static int CeedElemRestrictionGetCurlOrientations_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt8 **curl_orients) { 357 CeedElemRestriction_Hip *impl; 358 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 359 360 switch (mem_type) { 361 case CEED_MEM_HOST: 362 *curl_orients = impl->h_curl_orients; 363 break; 364 case CEED_MEM_DEVICE: 365 *curl_orients = impl->d_curl_orients; 366 break; 367 } 368 return CEED_ERROR_SUCCESS; 369 } 370 371 //------------------------------------------------------------------------------ 372 // Get offset for padded AtPoints E-layout 373 //------------------------------------------------------------------------------ 374 static int CeedElemRestrictionGetAtPointsElementOffset_Hip(CeedElemRestriction rstr, CeedInt elem, CeedSize *elem_offset) { 375 CeedInt layout[3]; 376 377 CeedCallBackend(CeedElemRestrictionGetELayout(rstr, layout)); 378 *elem_offset = 0 * layout[0] + 0 * layout[1] + elem * layout[2]; 379 return CEED_ERROR_SUCCESS; 380 } 381 382 //------------------------------------------------------------------------------ 383 // Destroy restriction 384 //------------------------------------------------------------------------------ 385 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction rstr) { 386 Ceed ceed; 387 CeedElemRestriction_Hip *impl; 388 389 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 390 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 391 if (impl->module) { 392 CeedCallHip(ceed, hipModuleUnload(impl->module)); 393 } 394 CeedCallBackend(CeedFree(&impl->h_offsets_owned)); 395 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_offsets_owned)); 396 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_t_offsets)); 397 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_t_indices)); 398 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_l_vec_indices)); 399 CeedCallBackend(CeedFree(&impl->h_orients_owned)); 400 CeedCallHip(ceed, hipFree((bool *)impl->d_orients_owned)); 401 CeedCallBackend(CeedFree(&impl->h_curl_orients_owned)); 402 CeedCallHip(ceed, hipFree((CeedInt8 *)impl->d_curl_orients_owned)); 403 CeedCallBackend(CeedFree(&impl->h_offsets_at_points_owned)); 404 CeedCallHip(ceed, hipFree((CeedInt8 *)impl->d_offsets_at_points_owned)); 405 CeedCallBackend(CeedFree(&impl->h_points_per_elem_owned)); 406 CeedCallHip(ceed, hipFree((CeedInt *)impl->d_points_per_elem_owned)); 407 CeedCallBackend(CeedFree(&impl)); 408 return CEED_ERROR_SUCCESS; 409 } 410 411 //------------------------------------------------------------------------------ 412 // Create transpose offsets and indices 413 //------------------------------------------------------------------------------ 414 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction rstr, const CeedInt elem_size, const CeedInt *indices) { 415 Ceed ceed; 416 bool *is_node; 417 CeedSize l_size; 418 CeedInt num_elem, num_comp, num_nodes = 0; 419 CeedInt *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices; 420 CeedRestrictionType rstr_type; 421 CeedElemRestriction_Hip *impl; 422 423 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 424 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 425 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 426 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 427 CeedCallBackend(CeedElemRestrictionGetLVectorSize(rstr, &l_size)); 428 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 429 const CeedInt size_indices = num_elem * elem_size; 430 431 // Count num_nodes 432 CeedCallBackend(CeedCalloc(l_size, &is_node)); 433 434 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 435 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 436 impl->num_nodes = num_nodes; 437 438 // L-vector offsets array 439 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 440 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 441 for (CeedInt i = 0, j = 0; i < l_size; i++) { 442 if (is_node[i]) { 443 l_vec_indices[j] = i; 444 ind_to_offset[i] = j++; 445 } 446 } 447 CeedCallBackend(CeedFree(&is_node)); 448 449 // Compute transpose offsets and indices 450 const CeedInt size_offsets = num_nodes + 1; 451 452 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 453 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 454 // Count node multiplicity 455 for (CeedInt e = 0; e < num_elem; ++e) { 456 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 457 } 458 // Convert to running sum 459 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 460 // List all E-vec indices associated with L-vec node 461 for (CeedInt e = 0; e < num_elem; ++e) { 462 for (CeedInt i = 0; i < elem_size; ++i) { 463 const CeedInt lid = elem_size * e + i; 464 const CeedInt gid = indices[lid]; 465 466 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 467 } 468 } 469 // Reset running sum 470 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 471 t_offsets[0] = 0; 472 473 // Copy data to device 474 // -- L-vector indices 475 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 476 CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 477 // -- Transpose offsets 478 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 479 CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 480 // -- Transpose indices 481 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 482 CeedCallHip(ceed, hipMemcpy((CeedInt *)impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 483 484 // Cleanup 485 CeedCallBackend(CeedFree(&ind_to_offset)); 486 CeedCallBackend(CeedFree(&l_vec_indices)); 487 CeedCallBackend(CeedFree(&t_offsets)); 488 CeedCallBackend(CeedFree(&t_indices)); 489 return CEED_ERROR_SUCCESS; 490 } 491 492 //------------------------------------------------------------------------------ 493 // Create restriction 494 //------------------------------------------------------------------------------ 495 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *offsets, const bool *orients, 496 const CeedInt8 *curl_orients, CeedElemRestriction rstr) { 497 Ceed ceed, ceed_parent; 498 bool is_deterministic; 499 CeedInt num_elem, num_comp, elem_size; 500 CeedRestrictionType rstr_type; 501 CeedElemRestriction_Hip *impl; 502 503 CeedCallBackend(CeedElemRestrictionGetCeed(rstr, &ceed)); 504 CeedCallBackend(CeedGetParent(ceed, &ceed_parent)); 505 CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic)); 506 CeedCallBackend(CeedElemRestrictionGetNumElements(rstr, &num_elem)); 507 CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr, &num_comp)); 508 CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size)); 509 CeedCallBackend(CeedElemRestrictionGetType(rstr, &rstr_type)); 510 // Use max number of points as elem size for AtPoints restrictions 511 if (rstr_type == CEED_RESTRICTION_POINTS) { 512 CeedInt max_points = 0; 513 514 for (CeedInt i = 0; i < num_elem; i++) { 515 max_points = CeedIntMax(max_points, offsets[i + 1] - offsets[i]); 516 } 517 elem_size = max_points; 518 } 519 const CeedInt size = num_elem * elem_size; 520 521 CeedCallBackend(CeedCalloc(1, &impl)); 522 impl->num_nodes = size; 523 CeedCallBackend(CeedElemRestrictionSetData(rstr, impl)); 524 525 // Set layouts 526 { 527 bool has_backend_strides; 528 CeedInt layout[3] = {1, size, elem_size}; 529 530 CeedCallBackend(CeedElemRestrictionSetELayout(rstr, layout)); 531 if (rstr_type == CEED_RESTRICTION_STRIDED) { 532 CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &has_backend_strides)); 533 if (has_backend_strides) { 534 CeedCallBackend(CeedElemRestrictionSetLLayout(rstr, layout)); 535 } 536 } 537 } 538 539 // Pad AtPoints indices 540 if (rstr_type == CEED_RESTRICTION_POINTS) { 541 CeedSize offsets_len = elem_size * num_elem, at_points_size = num_elem + 1; 542 CeedInt max_points = elem_size, *offsets_padded, *points_per_elem; 543 544 CeedCheck(mem_type == CEED_MEM_HOST, ceed, CEED_ERROR_BACKEND, "only MemType Host supported when creating AtPoints restriction"); 545 CeedCallBackend(CeedMalloc(offsets_len, &offsets_padded)); 546 CeedCallBackend(CeedMalloc(num_elem, &points_per_elem)); 547 for (CeedInt i = 0; i < num_elem; i++) { 548 CeedInt num_points = offsets[i + 1] - offsets[i]; 549 550 points_per_elem[i] = num_points; 551 at_points_size += num_points; 552 // -- Copy all points in element 553 for (CeedInt j = 0; j < num_points; j++) { 554 offsets_padded[i * max_points + j] = offsets[offsets[i] + j] * num_comp; 555 } 556 // -- Replicate out last point in element 557 for (CeedInt j = num_points; j < max_points; j++) { 558 offsets_padded[i * max_points + j] = offsets[offsets[i] + num_points - 1] * num_comp; 559 } 560 } 561 CeedCallBackend(CeedSetHostCeedIntArray(offsets, copy_mode, at_points_size, &impl->h_offsets_at_points_owned, &impl->h_offsets_at_points_borrowed, 562 &impl->h_offsets_at_points)); 563 CeedCallHip(ceed, hipMalloc((void **)&impl->d_offsets_at_points_owned, at_points_size * sizeof(CeedInt))); 564 CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->d_offsets_at_points_owned, impl->h_offsets_at_points, at_points_size * sizeof(CeedInt), 565 hipMemcpyHostToDevice)); 566 impl->d_offsets_at_points = (CeedInt *)impl->d_offsets_at_points_owned; 567 568 // -- Use padded offsets for the rest of the setup 569 offsets = (const CeedInt *)offsets_padded; 570 copy_mode = CEED_OWN_POINTER; 571 CeedCallBackend(CeedElemRestrictionSetAtPointsEVectorSize(rstr, elem_size * num_elem * num_comp)); 572 573 // -- Points per element 574 CeedCallBackend(CeedSetHostCeedIntArray(points_per_elem, CEED_OWN_POINTER, num_elem, &impl->h_points_per_elem_owned, 575 &impl->h_points_per_elem_borrowed, &impl->h_points_per_elem)); 576 CeedCallHip(ceed, hipMalloc((void **)&impl->d_points_per_elem_owned, num_elem * sizeof(CeedInt))); 577 CeedCallHip(ceed, 578 hipMemcpy((CeedInt **)impl->d_points_per_elem_owned, impl->h_points_per_elem, num_elem * sizeof(CeedInt), hipMemcpyHostToDevice)); 579 impl->d_points_per_elem = (CeedInt *)impl->d_points_per_elem_owned; 580 } 581 582 // Set up device offset/orientation arrays 583 if (rstr_type != CEED_RESTRICTION_STRIDED) { 584 switch (mem_type) { 585 case CEED_MEM_HOST: { 586 CeedCallBackend(CeedSetHostCeedIntArray(offsets, copy_mode, size, &impl->h_offsets_owned, &impl->h_offsets_borrowed, &impl->h_offsets)); 587 CeedCallHip(ceed, hipMalloc((void **)&impl->d_offsets_owned, size * sizeof(CeedInt))); 588 CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->d_offsets_owned, impl->h_offsets, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 589 impl->d_offsets = (CeedInt *)impl->d_offsets_owned; 590 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, elem_size, offsets)); 591 } break; 592 case CEED_MEM_DEVICE: { 593 CeedCallBackend(CeedSetDeviceCeedIntArray_Hip(ceed, offsets, copy_mode, size, &impl->d_offsets_owned, &impl->d_offsets_borrowed, 594 (const CeedInt **)&impl->d_offsets)); 595 CeedCallBackend(CeedMalloc(size, &impl->h_offsets_owned)); 596 CeedCallHip(ceed, hipMemcpy((CeedInt **)impl->h_offsets_owned, impl->d_offsets, size * sizeof(CeedInt), hipMemcpyDeviceToHost)); 597 impl->h_offsets = impl->h_offsets_owned; 598 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(rstr, elem_size, offsets)); 599 } break; 600 } 601 602 // Orientation data 603 if (rstr_type == CEED_RESTRICTION_ORIENTED) { 604 switch (mem_type) { 605 case CEED_MEM_HOST: { 606 CeedCallBackend(CeedSetHostBoolArray(orients, copy_mode, size, &impl->h_orients_owned, &impl->h_orients_borrowed, &impl->h_orients)); 607 CeedCallHip(ceed, hipMalloc((void **)&impl->d_orients_owned, size * sizeof(bool))); 608 CeedCallHip(ceed, hipMemcpy((bool *)impl->d_orients_owned, impl->h_orients, size * sizeof(bool), hipMemcpyHostToDevice)); 609 impl->d_orients = impl->d_orients_owned; 610 } break; 611 case CEED_MEM_DEVICE: { 612 CeedCallBackend(CeedSetDeviceBoolArray_Hip(ceed, orients, copy_mode, size, &impl->d_orients_owned, &impl->d_orients_borrowed, 613 (const bool **)&impl->d_orients)); 614 CeedCallBackend(CeedMalloc(size, &impl->h_orients_owned)); 615 CeedCallHip(ceed, hipMemcpy((bool *)impl->h_orients_owned, impl->d_orients, size * sizeof(bool), hipMemcpyDeviceToHost)); 616 impl->h_orients = impl->h_orients_owned; 617 } break; 618 } 619 } else if (rstr_type == CEED_RESTRICTION_CURL_ORIENTED) { 620 switch (mem_type) { 621 case CEED_MEM_HOST: { 622 CeedCallBackend(CeedSetHostCeedInt8Array(curl_orients, copy_mode, 3 * size, &impl->h_curl_orients_owned, &impl->h_curl_orients_borrowed, 623 &impl->h_curl_orients)); 624 CeedCallHip(ceed, hipMalloc((void **)&impl->d_curl_orients_owned, 3 * size * sizeof(CeedInt8))); 625 CeedCallHip(ceed, 626 hipMemcpy((CeedInt8 *)impl->d_curl_orients_owned, impl->h_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyHostToDevice)); 627 impl->d_curl_orients = impl->d_curl_orients_owned; 628 } break; 629 case CEED_MEM_DEVICE: { 630 CeedCallBackend(CeedSetDeviceCeedInt8Array_Hip(ceed, curl_orients, copy_mode, 3 * size, &impl->d_curl_orients_owned, 631 &impl->d_curl_orients_borrowed, (const CeedInt8 **)&impl->d_curl_orients)); 632 CeedCallBackend(CeedMalloc(3 * size, &impl->h_curl_orients_owned)); 633 CeedCallHip(ceed, 634 hipMemcpy((CeedInt8 *)impl->h_curl_orients_owned, impl->d_curl_orients, 3 * size * sizeof(CeedInt8), hipMemcpyDeviceToHost)); 635 impl->h_curl_orients = impl->h_curl_orients_owned; 636 } break; 637 } 638 } 639 } 640 641 // Register backend functions 642 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Apply", CeedElemRestrictionApply_Hip)); 643 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnsigned", CeedElemRestrictionApplyUnsigned_Hip)); 644 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "ApplyUnoriented", CeedElemRestrictionApplyUnoriented_Hip)); 645 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 646 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetOrientations", CeedElemRestrictionGetOrientations_Hip)); 647 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetCurlOrientations", CeedElemRestrictionGetCurlOrientations_Hip)); 648 if (rstr_type == CEED_RESTRICTION_POINTS) { 649 CeedCallBackend( 650 CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "GetAtPointsElementOffset", CeedElemRestrictionGetAtPointsElementOffset_Hip)); 651 } 652 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", rstr, "Destroy", CeedElemRestrictionDestroy_Hip)); 653 return CEED_ERROR_SUCCESS; 654 } 655 656 //------------------------------------------------------------------------------ 657