1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <cuda.h> 12 #include <cuda_runtime.h> 13 #include <stdbool.h> 14 #include <stddef.h> 15 #include <string.h> 16 17 #include "../cuda/ceed-cuda-common.h" 18 #include "../cuda/ceed-cuda-compile.h" 19 #include "ceed-cuda-ref.h" 20 21 //------------------------------------------------------------------------------ 22 // Apply restriction 23 //------------------------------------------------------------------------------ 24 static int CeedElemRestrictionApply_Cuda(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) { 25 CeedElemRestriction_Cuda *impl; 26 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 27 Ceed ceed; 28 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 29 Ceed_Cuda *data; 30 CeedCallBackend(CeedGetData(ceed, &data)); 31 const CeedInt warp_size = 32; 32 const CeedInt block_size = warp_size; 33 const CeedInt num_nodes = impl->num_nodes; 34 CeedInt num_elem, elem_size; 35 CeedElemRestrictionGetNumElements(r, &num_elem); 36 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 37 CUfunction kernel; 38 39 // Get vectors 40 const CeedScalar *d_u; 41 CeedScalar *d_v; 42 CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 43 if (t_mode == CEED_TRANSPOSE) { 44 // Sum into for transpose mode, e-vec to l-vec 45 CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v)); 46 } else { 47 // Overwrite for notranspose mode, l-vec to e-vec 48 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 49 } 50 51 // Restrict 52 if (t_mode == CEED_NOTRANSPOSE) { 53 // L-vector -> E-vector 54 if (impl->d_ind) { 55 // -- Offsets provided 56 kernel = impl->OffsetNoTranspose; 57 void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v}; 58 CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024; 59 CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 60 } else { 61 // -- Strided restriction 62 kernel = impl->StridedNoTranspose; 63 void *args[] = {&num_elem, &d_u, &d_v}; 64 CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024; 65 CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 66 } 67 } else { 68 // E-vector -> L-vector 69 if (impl->d_ind) { 70 // -- Offsets provided 71 kernel = impl->OffsetTranspose; 72 void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v}; 73 CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 74 } else { 75 // -- Strided restriction 76 kernel = impl->StridedTranspose; 77 void *args[] = {&num_elem, &d_u, &d_v}; 78 CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args)); 79 } 80 } 81 82 if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL; 83 84 // Restore arrays 85 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 86 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 87 return CEED_ERROR_SUCCESS; 88 } 89 90 //------------------------------------------------------------------------------ 91 // Get offsets 92 //------------------------------------------------------------------------------ 93 static int CeedElemRestrictionGetOffsets_Cuda(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) { 94 CeedElemRestriction_Cuda *impl; 95 CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl)); 96 97 switch (mem_type) { 98 case CEED_MEM_HOST: 99 *offsets = impl->h_ind; 100 break; 101 case CEED_MEM_DEVICE: 102 *offsets = impl->d_ind; 103 break; 104 } 105 return CEED_ERROR_SUCCESS; 106 } 107 108 //------------------------------------------------------------------------------ 109 // Destroy restriction 110 //------------------------------------------------------------------------------ 111 static int CeedElemRestrictionDestroy_Cuda(CeedElemRestriction r) { 112 CeedElemRestriction_Cuda *impl; 113 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 114 115 Ceed ceed; 116 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 117 CeedCallCuda(ceed, cuModuleUnload(impl->module)); 118 CeedCallBackend(CeedFree(&impl->h_ind_allocated)); 119 CeedCallCuda(ceed, cudaFree(impl->d_ind_allocated)); 120 CeedCallCuda(ceed, cudaFree(impl->d_t_offsets)); 121 CeedCallCuda(ceed, cudaFree(impl->d_t_indices)); 122 CeedCallCuda(ceed, cudaFree(impl->d_l_vec_indices)); 123 CeedCallBackend(CeedFree(&impl)); 124 125 return CEED_ERROR_SUCCESS; 126 } 127 128 //------------------------------------------------------------------------------ 129 // Create transpose offsets and indices 130 //------------------------------------------------------------------------------ 131 static int CeedElemRestrictionOffset_Cuda(const CeedElemRestriction r, const CeedInt *indices) { 132 Ceed ceed; 133 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 134 CeedElemRestriction_Cuda *impl; 135 CeedCallBackend(CeedElemRestrictionGetData(r, &impl)); 136 CeedSize l_size; 137 CeedInt num_elem, elem_size, num_comp; 138 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 139 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 140 CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size)); 141 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 142 143 // Count num_nodes 144 bool *is_node; 145 CeedCallBackend(CeedCalloc(l_size, &is_node)); 146 const CeedInt size_indices = num_elem * elem_size; 147 for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1; 148 CeedInt num_nodes = 0; 149 for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i]; 150 impl->num_nodes = num_nodes; 151 152 // L-vector offsets array 153 CeedInt *ind_to_offset, *l_vec_indices; 154 CeedCallBackend(CeedCalloc(l_size, &ind_to_offset)); 155 CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices)); 156 CeedInt j = 0; 157 for (CeedInt i = 0; i < l_size; i++) { 158 if (is_node[i]) { 159 l_vec_indices[j] = i; 160 ind_to_offset[i] = j++; 161 } 162 } 163 CeedCallBackend(CeedFree(&is_node)); 164 165 // Compute transpose offsets and indices 166 const CeedInt size_offsets = num_nodes + 1; 167 CeedInt *t_offsets; 168 CeedCallBackend(CeedCalloc(size_offsets, &t_offsets)); 169 CeedInt *t_indices; 170 CeedCallBackend(CeedMalloc(size_indices, &t_indices)); 171 // Count node multiplicity 172 for (CeedInt e = 0; e < num_elem; ++e) { 173 for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1]; 174 } 175 // Convert to running sum 176 for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1]; 177 // List all E-vec indices associated with L-vec node 178 for (CeedInt e = 0; e < num_elem; ++e) { 179 for (CeedInt i = 0; i < elem_size; ++i) { 180 const CeedInt lid = elem_size * e + i; 181 const CeedInt gid = indices[lid]; 182 t_indices[t_offsets[ind_to_offset[gid]]++] = lid; 183 } 184 } 185 // Reset running sum 186 for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1]; 187 t_offsets[0] = 0; 188 189 // Copy data to device 190 // -- L-vector indices 191 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt))); 192 CeedCallCuda(ceed, cudaMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), cudaMemcpyHostToDevice)); 193 // -- Transpose offsets 194 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt))); 195 CeedCallCuda(ceed, cudaMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), cudaMemcpyHostToDevice)); 196 // -- Transpose indices 197 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt))); 198 CeedCallCuda(ceed, cudaMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), cudaMemcpyHostToDevice)); 199 200 // Cleanup 201 CeedCallBackend(CeedFree(&ind_to_offset)); 202 CeedCallBackend(CeedFree(&l_vec_indices)); 203 CeedCallBackend(CeedFree(&t_offsets)); 204 CeedCallBackend(CeedFree(&t_indices)); 205 206 return CEED_ERROR_SUCCESS; 207 } 208 209 //------------------------------------------------------------------------------ 210 // Create restriction 211 //------------------------------------------------------------------------------ 212 int CeedElemRestrictionCreate_Cuda(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients, 213 const CeedInt *curl_orients, CeedElemRestriction r) { 214 Ceed ceed; 215 CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed)); 216 CeedElemRestriction_Cuda *impl; 217 CeedCallBackend(CeedCalloc(1, &impl)); 218 CeedInt num_elem, num_comp, elem_size; 219 CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem)); 220 CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp)); 221 CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size)); 222 CeedInt size = num_elem * elem_size; 223 CeedInt strides[3] = {1, size, elem_size}; 224 CeedInt comp_stride = 1; 225 226 // Stride data 227 bool is_strided; 228 CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided)); 229 if (is_strided) { 230 bool has_backend_strides; 231 CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides)); 232 if (!has_backend_strides) { 233 CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides)); 234 } 235 } else { 236 CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride)); 237 } 238 239 impl->h_ind = NULL; 240 impl->h_ind_allocated = NULL; 241 impl->d_ind = NULL; 242 impl->d_ind_allocated = NULL; 243 impl->d_t_indices = NULL; 244 impl->d_t_offsets = NULL; 245 impl->num_nodes = size; 246 CeedCallBackend(CeedElemRestrictionSetData(r, impl)); 247 CeedInt layout[3] = {1, elem_size * num_elem, elem_size}; 248 CeedCallBackend(CeedElemRestrictionSetELayout(r, layout)); 249 250 // Set up device indices/offset arrays 251 switch (mem_type) { 252 case CEED_MEM_HOST: { 253 switch (copy_mode) { 254 case CEED_OWN_POINTER: 255 impl->h_ind_allocated = (CeedInt *)indices; 256 impl->h_ind = (CeedInt *)indices; 257 break; 258 case CEED_USE_POINTER: 259 impl->h_ind = (CeedInt *)indices; 260 break; 261 case CEED_COPY_VALUES: 262 if (indices != NULL) { 263 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 264 memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt)); 265 impl->h_ind = impl->h_ind_allocated; 266 } 267 break; 268 } 269 if (indices != NULL) { 270 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 271 impl->d_ind_allocated = impl->d_ind; // We own the device memory 272 CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyHostToDevice)); 273 CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices)); 274 } 275 break; 276 } 277 case CEED_MEM_DEVICE: { 278 switch (copy_mode) { 279 case CEED_COPY_VALUES: 280 if (indices != NULL) { 281 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt))); 282 impl->d_ind_allocated = impl->d_ind; // We own the device memory 283 CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyDeviceToDevice)); 284 } 285 break; 286 case CEED_OWN_POINTER: 287 impl->d_ind = (CeedInt *)indices; 288 impl->d_ind_allocated = impl->d_ind; 289 break; 290 case CEED_USE_POINTER: 291 impl->d_ind = (CeedInt *)indices; 292 } 293 if (indices != NULL) { 294 CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated)); 295 CeedCallCuda(ceed, cudaMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), cudaMemcpyDeviceToHost)); 296 impl->h_ind = impl->h_ind_allocated; 297 CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices)); 298 } 299 break; 300 } 301 // LCOV_EXCL_START 302 default: 303 return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported"); 304 // LCOV_EXCL_STOP 305 } 306 307 // Compile CUDA kernels 308 CeedInt num_nodes = impl->num_nodes; 309 char *restriction_kernel_path, *restriction_kernel_source; 310 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-restriction.h", &restriction_kernel_path)); 311 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n"); 312 CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source)); 313 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n"); 314 CeedCallBackend(CeedCompile_Cuda(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem, 315 "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES", 316 strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2])); 317 CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose)); 318 CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose)); 319 CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose)); 320 CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose)); 321 CeedCallBackend(CeedFree(&restriction_kernel_path)); 322 CeedCallBackend(CeedFree(&restriction_kernel_source)); 323 324 // Register backend functions 325 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Cuda)); 326 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Cuda)); 327 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Cuda)); 328 CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Cuda)); 329 return CEED_ERROR_SUCCESS; 330 } 331 332 //------------------------------------------------------------------------------ 333