1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <stdbool.h> 12 #include <stddef.h> 13 #include <string.h> 14 #include <hip/hip_runtime.h> 15 16 #include "../hip/ceed-hip-common.h" 17 #include "../hip/ceed-hip-compile.h" 18 #include "ceed-hip-ref.h" 19 20 //------------------------------------------------------------------------------ 21 // Apply restriction 22 //------------------------------------------------------------------------------ 23 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 24 CeedElemRestriction_Hip *impl; 25 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 26 Ceed ceed; 27 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 28 Ceed_Hip *data; 29 CeedCallBackend(CeedGetData(ceed, &data)); 30 const CeedInt block_size = 64; 31 const CeedInt num_nodes = impl->num_nodes; 32 CeedInt num_elem, elem_size; 33 CeedElemRestrictionGetNumElements(r, &num_elem); 34 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 35 hipFunction_t kernel; 36 37 // Get vectors 38 const CeedScalar *d_u; 39 CeedScalar *d_v; 40 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 41 if (t_mode == CEED_TRANSPOSE) { 42 // Sum into for transpose mode, e-vec to l-vec 43 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 44 } else { 45 // Overwrite for notranspose mode, l-vec to e-vec 46 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 47 } 48 49 // Restrict 50 if (t_mode == CEED_NOTRANSPOSE) { 51 // L-vector -> E-vector 52 if (impl->d_ind) { 53 // -- Offsets provided 54 kernel = impl->OffsetNoTranspose; 55 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 56 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 57 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 58 } else { 59 // -- Strided restriction 60 kernel = impl->StridedNoTranspose; 61 void *args[] = {&num_elem, &d_u, &d_v}; 62 CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256; 63 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 64 } 65 } else { 66 // E-vector -> L-vector 67 if (impl->d_ind) { 68 // -- Offsets provided 69 kernel = impl->OffsetTranspose; 70 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 71 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 72 } else { 73 // -- Strided restriction 74 kernel = impl->StridedTranspose; 75 void *args[] = {&num_elem, &d_u, &d_v}; 76 CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 77 } 78 } 79 80 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 81 82 // Restore arrays 83 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 84 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 85 return CEED_ERROR_SUCCESS; 86 } 87 88 //------------------------------------------------------------------------------ 89 // Get offsets 90 //------------------------------------------------------------------------------ 91 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 92 CeedElemRestriction_Hip *impl; 93 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 94 95 switch (mem_type) { 96 case CEED_MEM_HOST: 97 *offsets = impl->h_ind; 98 break; 99 case CEED_MEM_DEVICE: 100 *offsets = impl->d_ind; 101 break; 102 } 103 return CEED_ERROR_SUCCESS; 104 } 105 106 //------------------------------------------------------------------------------ 107 // Destroy restriction 108 //------------------------------------------------------------------------------ 109 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) { 110 CeedElemRestriction_Hip *impl; 111 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 112 113 Ceed ceed; 114 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 115 CeedCallHip(ceed, hipModuleUnload(impl->module)); 116 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 117 CeedCallHip(ceed, hipFree(impl->d_ind_allocated)); 118 CeedCallHip(ceed, hipFree(impl->d_t_offsets)); 119 CeedCallHip(ceed, hipFree(impl->d_t_indices)); 120 CeedCallHip(ceed, hipFree(impl->d_l_vec_indices)); 121 CeedCallBackend(CeedFree(&impl)); 122 123 return CEED_ERROR_SUCCESS; 124 } 125 126 //------------------------------------------------------------------------------ 127 // Create transpose offsets and indices 128 //------------------------------------------------------------------------------ 129 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) { 130 Ceed ceed; 131 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 132 CeedElemRestriction_Hip *impl; 133 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 134 CeedSize l_size; 135 CeedInt num_elem, elem_size, num_comp; 136 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 137 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 138 CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 139 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 140 141 // Count num_nodes 142 bool *is_node; 143 CeedCallBackend(CeedCalloc(l_size, &is_node)); 144 const CeedInt size_indices = num_elem * elem_size; 145 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 146 CeedInt num_nodes = 0; 147 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 148 impl->num_nodes = num_nodes; 149 150 // L-vector offsets array 151 CeedInt *ind_to_offset, *l_vec_indices; 152 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 153 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 154 CeedInt j = 0; 155 for (CeedInt i = 0; i < l_size; i++) { 156 if (is_node[i]) { 157 l_vec_indices[j] = i; 158 ind_to_offset[i] = j++; 159 } 160 } 161 CeedCallBackend(CeedFree(&is_node)); 162 163 // Compute transpose offsets and indices 164 const CeedInt size_offsets = num_nodes + 1; 165 CeedInt *t_offsets; 166 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 167 CeedInt *t_indices; 168 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 169 // Count node multiplicity 170 for (CeedInt e = 0; e < num_elem; ++e) { 171 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 172 } 173 // Convert to running sum 174 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 175 // List all E-vec indices associated with L-vec node 176 for (CeedInt e = 0; e < num_elem; ++e) { 177 for (CeedInt i = 0; i < elem_size; ++i) { 178 const CeedInt lid = elem_size * e + i; 179 const CeedInt gid = indices[lid]; 180 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 181 } 182 } 183 // Reset running sum 184 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 185 t_offsets[0] = 0; 186 187 // Copy data to device 188 // -- L-vector indices 189 CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 190 CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice)); 191 // -- Transpose offsets 192 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 193 CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice)); 194 // -- Transpose indices 195 CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 196 CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice)); 197 198 // Cleanup 199 CeedCallBackend(CeedFree(&ind_to_offset)); 200 CeedCallBackend(CeedFree(&l_vec_indices)); 201 CeedCallBackend(CeedFree(&t_offsets)); 202 CeedCallBackend(CeedFree(&t_indices)); 203 204 return CEED_ERROR_SUCCESS; 205 } 206 207 //------------------------------------------------------------------------------ 208 // Create restriction 209 //------------------------------------------------------------------------------ 210 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, CeedElemRestriction r) { 211 Ceed ceed; 212 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 213 CeedElemRestriction_Hip *impl; 214 CeedCallBackend(CeedCalloc(1, &impl)); 215 CeedInt num_elem, num_comp, elem_size; 216 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 217 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 218 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 219 CeedInt size = num_elem * elem_size; 220 CeedInt strides[3] = {1, size, elem_size}; 221 CeedInt comp_stride = 1; 222 223 // Stride data 224 bool is_strided; 225 CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 226 if (is_strided) { 227 bool has_backend_strides; 228 CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 229 if (!has_backend_strides) { 230 CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 231 } 232 } else { 233 CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 234 } 235 236 impl->h_ind = NULL; 237 impl->h_ind_allocated = NULL; 238 impl->d_ind = NULL; 239 impl->d_ind_allocated = NULL; 240 impl->d_t_indices = NULL; 241 impl->d_t_offsets = NULL; 242 impl->num_nodes = size; 243 CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 244 CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 245 CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 246 247 // Set up device indices/offset arrays 248 switch (mem_type) { 249 case CEED_MEM_HOST: { 250 switch (copy_mode) { 251 case CEED_OWN_POINTER: 252 impl->h_ind_allocated = (CeedInt *)indices; 253 impl->h_ind = (CeedInt *)indices; 254 break; 255 case CEED_USE_POINTER: 256 impl->h_ind = (CeedInt *)indices; 257 break; 258 case CEED_COPY_VALUES: 259 if (indices != NULL) { 260 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 261 memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 262 impl->h_ind = impl->h_ind_allocated; 263 } 264 break; 265 } 266 if (indices != NULL) { 267 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 268 impl->d_ind_allocated = impl->d_ind; // We own the device memory 269 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice)); 270 CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 271 } 272 break; 273 } 274 case CEED_MEM_DEVICE: { 275 switch (copy_mode) { 276 case CEED_COPY_VALUES: 277 if (indices != NULL) { 278 CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 279 impl->d_ind_allocated = impl->d_ind; // We own the device memory 280 CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice)); 281 } 282 break; 283 case CEED_OWN_POINTER: 284 impl->d_ind = (CeedInt *)indices; 285 impl->d_ind_allocated = impl->d_ind; 286 break; 287 case CEED_USE_POINTER: 288 impl->d_ind = (CeedInt *)indices; 289 } 290 if (indices != NULL) { 291 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 292 CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost)); 293 impl->h_ind = impl->h_ind_allocated; 294 CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices)); 295 } 296 break; 297 } 298 // LCOV_EXCL_START 299 default: 300 return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 301 // LCOV_EXCL_STOP 302 } 303 304 // Compile HIP kernels 305 CeedInt num_nodes = impl->num_nodes; 306 char *restriction_kernel_path, *restriction_kernel_source; 307 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path)); 308 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 309 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 310 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 311 CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem, 312 "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES", 313 strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2])); 314 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 315 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 316 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 317 CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 318 CeedCallBackend(CeedFree(&restriction_kernel_path)); 319 CeedCallBackend(CeedFree(&restriction_kernel_source)); 320 321 // Register backend functions 322 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip)); 323 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip)); 324 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip)); 325 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip)); 326 return CEED_ERROR_SUCCESS; 327 } 328 329 //------------------------------------------------------------------------------ 330