1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <hip/hip_runtime.h> 12 #include <stdbool.h> 13 #include <stddef.h> 14 #include <string.h> 15 #include "ceed-hip-ref.h" 16 #include "../hip/ceed-hip-compile.h" 17 18 //------------------------------------------------------------------------------ 19 // Apply restriction 20 //------------------------------------------------------------------------------ 21 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, 22 CeedTransposeMode t_mode, CeedVector u, 23 CeedVector v, CeedRequest *request) { 24 int ierr; 25 CeedElemRestriction_Hip *impl; 26 ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr); 27 Ceed ceed; 28 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 29 Ceed_Hip *data; 30 ierr = CeedGetData(ceed, &data); CeedChkBackend(ierr); 31 const CeedInt block_size = 64; 32 const CeedInt num_nodes = impl->num_nodes; 33 CeedInt num_elem, elem_size; 34 CeedElemRestrictionGetNumElements(r, &num_elem); 35 ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr); 36 hipFunction_t kernel; 37 38 // Get vectors 39 const CeedScalar *d_u; 40 CeedScalar *d_v; 41 ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr); 42 if (t_mode == CEED_TRANSPOSE) { 43 // Sum into for transpose mode, e-vec to l-vec 44 ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 45 } else { 46 // Overwrite for notranspose mode, l-vec to e-vec 47 ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 48 } 49 50 // Restrict 51 if (t_mode == CEED_NOTRANSPOSE) { 52 // L-vector -> E-vector 53 if (impl->d_ind) { 54 // -- Offsets provided 55 kernel = impl->OffsetNoTranspose; 56 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 57 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 58 ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), 59 block_size, args); CeedChkBackend(ierr); 60 } else { 61 // -- Strided restriction 62 kernel = impl->StridedNoTranspose; 63 void *args[] = {&num_elem, &d_u, &d_v}; 64 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 65 ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), 66 block_size, args); CeedChkBackend(ierr); 67 } 68 } else { 69 // E-vector -> L-vector 70 if (impl->d_ind) { 71 // -- Offsets provided 72 kernel = impl->OffsetTranspose; 73 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, 74 &impl->d_t_offsets, &d_u, &d_v 75 }; 76 ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), 77 block_size, args); CeedChkBackend(ierr); 78 } else { 79 // -- Strided restriction 80 kernel = impl->StridedTranspose; 81 void *args[] = {&num_elem, &d_u, &d_v}; 82 ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), 83 block_size, args); CeedChkBackend(ierr); 84 } 85 } 86 87 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) 88 *request = NULL; 89 90 // Restore arrays 91 ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr); 92 ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr); 93 return CEED_ERROR_SUCCESS; 94 } 95 96 //------------------------------------------------------------------------------ 97 // Blocked not supported 98 //------------------------------------------------------------------------------ 99 int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block, 100 CeedTransposeMode t_mode, CeedVector u, 101 CeedVector v, CeedRequest *request) { 102 // LCOV_EXCL_START 103 int ierr; 104 Ceed ceed; 105 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 106 return CeedError(ceed, CEED_ERROR_BACKEND, 107 "Backend does not implement blocked restrictions"); 108 // LCOV_EXCL_STOP 109 } 110 111 //------------------------------------------------------------------------------ 112 // Get offsets 113 //------------------------------------------------------------------------------ 114 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, 115 CeedMemType mtype, const CeedInt **offsets) { 116 int ierr; 117 CeedElemRestriction_Hip *impl; 118 ierr = CeedElemRestrictionGetData(rstr, &impl); CeedChkBackend(ierr); 119 120 switch (mtype) { 121 case CEED_MEM_HOST: 122 *offsets = impl->h_ind; 123 break; 124 case CEED_MEM_DEVICE: 125 *offsets = impl->d_ind; 126 break; 127 } 128 return CEED_ERROR_SUCCESS; 129 } 130 131 //------------------------------------------------------------------------------ 132 // Destroy restriction 133 //------------------------------------------------------------------------------ 134 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) { 135 int ierr; 136 CeedElemRestriction_Hip *impl; 137 ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr); 138 139 Ceed ceed; 140 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 141 ierr = hipModuleUnload(impl->module); CeedChk_Hip(ceed, ierr); 142 ierr = CeedFree(&impl->h_ind_allocated); CeedChkBackend(ierr); 143 ierr = hipFree(impl->d_ind_allocated); CeedChk_Hip(ceed, ierr); 144 ierr = hipFree(impl->d_t_offsets); CeedChk_Hip(ceed, ierr); 145 ierr = hipFree(impl->d_t_indices); CeedChk_Hip(ceed, ierr); 146 ierr = hipFree(impl->d_l_vec_indices); CeedChk_Hip(ceed, ierr); 147 ierr = CeedFree(&impl); CeedChkBackend(ierr); 148 149 return CEED_ERROR_SUCCESS; 150 } 151 152 //------------------------------------------------------------------------------ 153 // Create transpose offsets and indices 154 //------------------------------------------------------------------------------ 155 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, 156 const CeedInt *indices) { 157 int ierr; 158 Ceed ceed; 159 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 160 CeedElemRestriction_Hip *impl; 161 ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr); 162 CeedSize l_size; 163 CeedInt num_elem, elem_size, num_comp; 164 ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr); 165 ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr); 166 ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr); 167 ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr); 168 169 // Count num_nodes 170 bool *is_node; 171 ierr = CeedCalloc(l_size, &is_node); CeedChkBackend(ierr); 172 const CeedInt size_indices = num_elem * elem_size; 173 for (CeedInt i = 0; i < size_indices; i++) 174 is_node[indices[i]] = 1; 175 CeedInt num_nodes = 0; 176 for (CeedInt i = 0; i < l_size; i++) 177 num_nodes += is_node[i]; 178 impl->num_nodes = num_nodes; 179 180 // L-vector offsets array 181 CeedInt *ind_to_offset, *l_vec_indices; 182 ierr = CeedCalloc(l_size, &ind_to_offset); CeedChkBackend(ierr); 183 ierr = CeedCalloc(num_nodes, &l_vec_indices); CeedChkBackend(ierr); 184 CeedInt j = 0; 185 for (CeedInt i = 0; i < l_size; i++) 186 if (is_node[i]) { 187 l_vec_indices[j] = i; 188 ind_to_offset[i] = j++; 189 } 190 ierr = CeedFree(&is_node); CeedChkBackend(ierr); 191 192 // Compute transpose offsets and indices 193 const CeedInt size_offsets = num_nodes + 1; 194 CeedInt *t_offsets; 195 ierr = CeedCalloc(size_offsets, &t_offsets); CeedChkBackend(ierr); 196 CeedInt *t_indices; 197 ierr = CeedMalloc(size_indices, &t_indices); CeedChkBackend(ierr); 198 // Count node multiplicity 199 for (CeedInt e = 0; e < num_elem; ++e) 200 for (CeedInt i = 0; i < elem_size; ++i) 201 ++t_offsets[ind_to_offset[indices[elem_size*e + i]] + 1]; 202 // Convert to running sum 203 for (CeedInt i = 1; i < size_offsets; ++i) 204 t_offsets[i] += t_offsets[i-1]; 205 // List all E-vec indices associated with L-vec node 206 for (CeedInt e = 0; e < num_elem; ++e) { 207 for (CeedInt i = 0; i < elem_size; ++i) { 208 const CeedInt lid = elem_size*e + i; 209 const CeedInt gid = indices[lid]; 210 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 211 } 212 } 213 // Reset running sum 214 for (int i = size_offsets - 1; i > 0; --i) 215 t_offsets[i] = t_offsets[i - 1]; 216 t_offsets[0] = 0; 217 218 // Copy data to device 219 // -- L-vector indices 220 ierr = hipMalloc((void **)&impl->d_l_vec_indices, num_nodes*sizeof(CeedInt)); 221 CeedChk_Hip(ceed, ierr); 222 ierr = hipMemcpy(impl->d_l_vec_indices, l_vec_indices, 223 num_nodes*sizeof(CeedInt), hipMemcpyHostToDevice); 224 CeedChk_Hip(ceed, ierr); 225 // -- Transpose offsets 226 ierr = hipMalloc((void **)&impl->d_t_offsets, size_offsets*sizeof(CeedInt)); 227 CeedChk_Hip(ceed, ierr); 228 ierr = hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets*sizeof(CeedInt), 229 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 230 // -- Transpose indices 231 ierr = hipMalloc((void **)&impl->d_t_indices, size_indices*sizeof(CeedInt)); 232 CeedChk_Hip(ceed, ierr); 233 ierr = hipMemcpy(impl->d_t_indices, t_indices, size_indices*sizeof(CeedInt), 234 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 235 236 // Cleanup 237 ierr = CeedFree(&ind_to_offset); CeedChkBackend(ierr); 238 ierr = CeedFree(&l_vec_indices); CeedChkBackend(ierr); 239 ierr = CeedFree(&t_offsets); CeedChkBackend(ierr); 240 ierr = CeedFree(&t_indices); CeedChkBackend(ierr); 241 242 return CEED_ERROR_SUCCESS; 243 } 244 245 //------------------------------------------------------------------------------ 246 // Create restriction 247 //------------------------------------------------------------------------------ 248 int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode, 249 const CeedInt *indices, 250 CeedElemRestriction r) { 251 int ierr; 252 Ceed ceed; 253 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 254 CeedElemRestriction_Hip *impl; 255 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 256 CeedInt num_elem, num_comp, elem_size; 257 ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr); 258 ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr); 259 ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr); 260 CeedInt size = num_elem * elem_size; 261 CeedInt strides[3] = {1, size, elem_size}; 262 CeedInt comp_stride = 1; 263 264 // Stride data 265 bool is_strided; 266 ierr = CeedElemRestrictionIsStrided(r, &is_strided); CeedChkBackend(ierr); 267 if (is_strided) { 268 bool has_backend_strides; 269 ierr = CeedElemRestrictionHasBackendStrides(r, &has_backend_strides); 270 CeedChkBackend(ierr); 271 if (!has_backend_strides) { 272 ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr); 273 } 274 } else { 275 ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr); 276 } 277 278 impl->h_ind = NULL; 279 impl->h_ind_allocated = NULL; 280 impl->d_ind = NULL; 281 impl->d_ind_allocated = NULL; 282 impl->d_t_indices = NULL; 283 impl->d_t_offsets = NULL; 284 impl->num_nodes = size; 285 ierr = CeedElemRestrictionSetData(r, impl); CeedChkBackend(ierr); 286 CeedInt layout[3] = {1, elem_size*num_elem, elem_size}; 287 ierr = CeedElemRestrictionSetELayout(r, layout); CeedChkBackend(ierr); 288 289 // Set up device indices/offset arrays 290 if (mtype == CEED_MEM_HOST) { 291 switch (cmode) { 292 case CEED_OWN_POINTER: 293 impl->h_ind_allocated = (CeedInt *)indices; 294 impl->h_ind = (CeedInt *)indices; 295 break; 296 case CEED_USE_POINTER: 297 impl->h_ind = (CeedInt *)indices; 298 break; 299 case CEED_COPY_VALUES: 300 if (indices != NULL) { 301 ierr = CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated); 302 CeedChkBackend(ierr); 303 memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 304 impl->h_ind = impl->h_ind_allocated; 305 } 306 break; 307 } 308 if (indices != NULL) { 309 ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt)); 310 CeedChk_Hip(ceed, ierr); 311 impl->d_ind_allocated = impl->d_ind; // We own the device memory 312 ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), 313 hipMemcpyHostToDevice); 314 CeedChk_Hip(ceed, ierr); 315 ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr); 316 } 317 } else if (mtype == CEED_MEM_DEVICE) { 318 switch (cmode) { 319 case CEED_COPY_VALUES: 320 if (indices != NULL) { 321 ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt)); 322 CeedChk_Hip(ceed, ierr); 323 impl->d_ind_allocated = impl->d_ind; // We own the device memory 324 ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), 325 hipMemcpyDeviceToDevice); 326 CeedChk_Hip(ceed, ierr); 327 } 328 break; 329 case CEED_OWN_POINTER: 330 impl->d_ind = (CeedInt *)indices; 331 impl->d_ind_allocated = impl->d_ind; 332 break; 333 case CEED_USE_POINTER: 334 impl->d_ind = (CeedInt *)indices; 335 } 336 if (indices != NULL) { 337 ierr = CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated); 338 CeedChkBackend(ierr); 339 ierr = hipMemcpy(impl->h_ind_allocated, impl->d_ind, 340 elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost); 341 CeedChk_Hip(ceed, ierr); 342 impl->h_ind = impl->h_ind_allocated; 343 ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr); 344 } 345 } else { 346 // LCOV_EXCL_START 347 return CeedError(ceed, CEED_ERROR_BACKEND, 348 "Only MemType = HOST or DEVICE supported"); 349 // LCOV_EXCL_STOP 350 } 351 352 // Compile HIP kernels 353 CeedInt num_nodes = impl->num_nodes; 354 char *restriction_kernel_path, *restriction_kernel_source; 355 ierr = CeedGetJitAbsolutePath(ceed, 356 "ceed/jit-source/hip/hip-ref-restriction.h", 357 &restriction_kernel_path); CeedChkBackend(ierr); 358 CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n"); 359 ierr = CeedLoadSourceToBuffer(ceed, restriction_kernel_path, 360 &restriction_kernel_source); 361 CeedChkBackend(ierr); 362 CeedDebug256(ceed, 2, 363 "----- Loading Restriction Kernel Source Complete! -----\n"); 364 ierr = CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8, 365 "RESTR_ELEM_SIZE", elem_size, 366 "RESTR_NUM_ELEM", num_elem, 367 "RESTR_NUM_COMP", num_comp, 368 "RESTR_NUM_NODES", num_nodes, 369 "RESTR_COMP_STRIDE", comp_stride, 370 "RESTR_STRIDE_NODES", strides[0], 371 "RESTR_STRIDE_COMP", strides[1], 372 "RESTR_STRIDE_ELEM", strides[2]); CeedChkBackend(ierr); 373 ierr = CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose", 374 &impl->StridedNoTranspose); CeedChkBackend(ierr); 375 ierr = CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose", 376 &impl->OffsetNoTranspose); CeedChkBackend(ierr); 377 ierr = CeedGetKernelHip(ceed, impl->module, "StridedTranspose", 378 &impl->StridedTranspose); CeedChkBackend(ierr); 379 ierr = CeedGetKernelHip(ceed, impl->module, "OffsetTranspose", 380 &impl->OffsetTranspose); CeedChkBackend(ierr); 381 ierr = CeedFree(&restriction_kernel_path); CeedChkBackend(ierr); 382 ierr = CeedFree(&restriction_kernel_source); CeedChkBackend(ierr); 383 384 // Register backend functions 385 ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", 386 CeedElemRestrictionApply_Hip); 387 CeedChkBackend(ierr); 388 ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock", 389 CeedElemRestrictionApplyBlock_Hip); 390 CeedChkBackend(ierr); 391 ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", 392 CeedElemRestrictionGetOffsets_Hip); 393 CeedChkBackend(ierr); 394 ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", 395 CeedElemRestrictionDestroy_Hip); 396 CeedChkBackend(ierr); 397 return CEED_ERROR_SUCCESS; 398 } 399 400 //------------------------------------------------------------------------------ 401 // Blocked not supported 402 //------------------------------------------------------------------------------ 403 int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype, 404 const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) { 405 int ierr; 406 Ceed ceed; 407 ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr); 408 return CeedError(ceed, CEED_ERROR_BACKEND, 409 "Backend does not implement blocked restrictions"); 410 } 411 //------------------------------------------------------------------------------ 412