1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Apply restriction 22 //------------------------------------------------------------------------------ 23 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 24 CeedElemRestriction_Hip *impl; 25 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 26 Ceed ceed; 27 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 28 Ceed_Hip *data; 29 CeedCallBackend(CeedGetData(ceed, &data)); 30 const CeedInt num_nodes = impl->num_nodes; 31 CeedInt num_elem, elem_size; 32 CeedElemRestrictionGetNumElements(r, &num_elem); 33 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 34 hipFunction_t kernel; 35 36 // Get vectors 37 const CeedScalar *d_u; 38 CeedScalar *d_v; 39 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 40 if (t_mode == CEED_TRANSPOSE) { 41 // Sum into for transpose mode, e-vec to l-vec 42 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 43 } else { 44 // Overwrite for notranspose mode, l-vec to e-vec 45 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 46 } 47 48 // Restrict 49 if (t_mode == CEED_NOTRANSPOSE) { 50 // L-vector -> E-vector 51 if (impl->d_ind) { 52 // -- Offsets provided 53 kernel = impl->OffsetNoTranspose; 54 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 55 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 56 57 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 58 } else { 59 // -- Strided restriction 60 kernel = impl->StridedNoTranspose; 61 void *args[] = {&num_elem, &d_u, &d_v}; 62 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 63 64 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 65 } 66 } else { 67 // E-vector -> L-vector 68 if (impl->d_ind) { 69 // -- Offsets provided 70 CeedInt block_size = 64; 71 if (impl->OffsetTranspose) { 72 kernel = impl->OffsetTranspose; 73 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 74 75 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 76 } else { 77 kernel = impl->OffsetTransposeDet; 78 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 79 80 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 81 } 82 } else { 83 // -- Strided restriction 84 kernel = impl->StridedTranspose; 85 void *args[] = {&num_elem, &d_u, &d_v}; 86 CeedInt block_size = 64; 87 88 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 89 } 90 } 91 92 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 93 94 // Restore arrays 95 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 96 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 97 return CEED_ERROR_SUCCESS; 98 } 99 100 //------------------------------------------------------------------------------ 101 // Get offsets 102 //------------------------------------------------------------------------------ 103 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 104 CeedElemRestriction_Hip *impl; 105 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 106 107 switch (mem_type) { 108 case CEED_MEM_HOST: 109 *offsets = impl->h_ind; 110 break; 111 case CEED_MEM_DEVICE: 112 *offsets = impl->d_ind; 113 break; 114 } 115 return CEED_ERROR_SUCCESS; 116 } 117 118 //------------------------------------------------------------------------------ 119 // Destroy restriction 120 //------------------------------------------------------------------------------ 121 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) { 122 CeedElemRestriction_Hip *impl; 123 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 124 125 Ceed ceed; 126 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 127 CeedCallHip(ceed, hipModuleUnload(impl->module)); 128 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 129 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 130 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 131 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 132 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 133 CeedCallBackend(CeedFree(&impl)); 134 135 return CEED_ERROR_SUCCESS; 136 } 137 138 //------------------------------------------------------------------------------ 139 // Create transpose offsets and indices 140 //------------------------------------------------------------------------------ 141 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) { 142 Ceed ceed; 143 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 144 CeedElemRestriction_Hip *impl; 145 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 146 CeedSize l_size; 147 CeedInt num_elem, elem_size, num_comp; 148 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 149 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 150 CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 151 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 152 153 // Count num_nodes 154 bool *is_node; 155 CeedCallBackend(CeedCalloc(l_size, &is_node)); 156 const CeedInt size_indices = num_elem * elem_size; 157 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 158 CeedInt num_nodes = 0; 159 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 160 impl->num_nodes = num_nodes; 161 162 // L-vector offsets array 163 CeedInt *ind_to_offset, *l_vec_indices; 164 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 165 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 166 CeedInt j = 0; 167 for (CeedInt i = 0; i < l_size; i++) { 168 if (is_node[i]) { 169 l_vec_indices[j] = i; 170 ind_to_offset[i] = j++; 171 } 172 } 173 CeedCallBackend(CeedFree(&is_node)); 174 175 // Compute transpose offsets and indices 176 const CeedInt size_offsets = num_nodes + 1; 177 CeedInt *t_offsets; 178 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 179 CeedInt *t_indices; 180 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 181 // Count node multiplicity 182 for (CeedInt e = 0; e < num_elem; ++e) { 183 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 184 } 185 // Convert to running sum 186 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 187 // List all E-vec indices associated with L-vec node 188 for (CeedInt e = 0; e < num_elem; ++e) { 189 for (CeedInt i = 0; i < elem_size; ++i) { 190 const CeedInt lid = elem_size * e + i; 191 const CeedInt gid = indices[lid]; 192 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 193 } 194 } 195 // Reset running sum 196 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 197 t_offsets[0] = 0; 198 199 // Copy data to device 200 // -- L-vector indices 201 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 202 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 203 // -- Transpose offsets 204 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 205 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 206 // -- Transpose indices 207 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 208 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 209 210 // Cleanup 211 CeedCallBackend(CeedFree(&ind_to_offset)); 212 CeedCallBackend(CeedFree(&l_vec_indices)); 213 CeedCallBackend(CeedFree(&t_offsets)); 214 CeedCallBackend(CeedFree(&t_indices)); 215 216 return CEED_ERROR_SUCCESS; 217 } 218 219 //------------------------------------------------------------------------------ 220 // Create restriction 221 //------------------------------------------------------------------------------ 222 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 223 const CeedInt8 *curl_orients, CeedElemRestriction r) { 224 Ceed ceed; 225 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 226 CeedElemRestriction_Hip *impl; 227 CeedCallBackend(CeedCalloc(1, &impl)); 228 Ceed parent; 229 CeedCallBackend(CeedGetParent(ceed, &parent)); 230 bool is_deterministic; 231 CeedCallBackend(CeedIsDeterministic(parent, &is_deterministic)); 232 CeedInt num_elem, num_comp, elem_size; 233 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 234 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 235 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 236 CeedInt size = num_elem * elem_size; 237 CeedInt strides[3] = {1, size, elem_size}; 238 CeedInt comp_stride = 1; 239 240 CeedRestrictionType rstr_type; 241 CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type)); 242 CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND, 243 "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented"); 244 245 // Stride data 246 bool is_strided; 247 CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 248 if (is_strided) { 249 bool has_backend_strides; 250 CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 251 if (!has_backend_strides) { 252 CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 253 } 254 } else { 255 CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 256 } 257 258 impl->h_ind = NULL; 259 impl->h_ind_allocated = NULL; 260 impl->d_ind = NULL; 261 impl->d_ind_allocated = NULL; 262 impl->d_t_indices = NULL; 263 impl->d_t_offsets = NULL; 264 impl->num_nodes = size; 265 CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 266 CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 267 CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 268 269 // Set up device indices/offset arrays 270 switch (mem_type) { 271 case CEED_MEM_HOST: { 272 switch (copy_mode) { 273 case CEED_OWN_POINTER: 274 impl->h_ind_allocated = (CeedInt *)indices; 275 impl->h_ind = (CeedInt *)indices; 276 break; 277 case CEED_USE_POINTER: 278 impl->h_ind = (CeedInt *)indices; 279 break; 280 case CEED_COPY_VALUES: 281 if (indices != NULL) { 282 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 283 memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 284 impl->h_ind = impl->h_ind_allocated; 285 } 286 break; 287 } 288 if (indices != NULL) { 289 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 290 impl->d_ind_allocated = impl->d_ind; // We own the device memory 291 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 292 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 293 } 294 break; 295 } 296 case CEED_MEM_DEVICE: { 297 switch (copy_mode) { 298 case CEED_COPY_VALUES: 299 if (indices != NULL) { 300 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 301 impl->d_ind_allocated = impl->d_ind; // We own the device memory 302 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 303 } 304 break; 305 case CEED_OWN_POINTER: 306 impl->d_ind = (CeedInt *)indices; 307 impl->d_ind_allocated = impl->d_ind; 308 break; 309 case CEED_USE_POINTER: 310 impl->d_ind = (CeedInt *)indices; 311 } 312 if (indices != NULL) { 313 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 314 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost)); 315 impl->h_ind = impl->h_ind_allocated; 316 if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 317 } 318 break; 319 } 320 // LCOV_EXCL_START 321 default: 322 return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 323 // LCOV_EXCL_STOP 324 } 325 326 // Compile HIP kernels 327 CeedInt num_nodes = impl->num_nodes; 328 char *restriction_kernel_path, *restriction_kernel_source; 329 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path)); 330 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 331 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 332 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 333 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem, 334 "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES", 335 strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2])); 336 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 337 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 338 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 339 if (!is_deterministic) { 340 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 341 } else { 342 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet)); 343 } 344 CeedCallBackend(CeedFree(&restriction_kernel_path)); 345 CeedCallBackend(CeedFree(&restriction_kernel_source)); 346 347 // Register backend functions 348 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip)); 349 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip)); 350 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip)); 351 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 352 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip)); 353 return CEED_ERROR_SUCCESS; 354 } 355 356 //------------------------------------------------------------------------------ 357