1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <cuda_runtime.h> 11 #include <stdbool.h> 12 #include <string.h> 13 14 #include "../cuda/ceed-cuda-common.h" 15 #include "ceed-cuda-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 size_t ctx_size; 23 CeedQFunctionContext_Cuda *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 27 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size)); 37 impl->d_data = impl->d_data_owned; 38 } 39 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice)); 40 CeedCallBackend(CeedDestroy(&ceed)); 41 return CEED_ERROR_SUCCESS; 42 } 43 44 //------------------------------------------------------------------------------ 45 // Sync device to host 46 //------------------------------------------------------------------------------ 47 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) { 48 Ceed ceed; 49 size_t ctx_size; 50 CeedQFunctionContext_Cuda *impl; 51 52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 53 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 54 55 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 56 57 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 58 59 if (impl->h_data_borrowed) { 60 impl->h_data = impl->h_data_borrowed; 61 } else if (impl->h_data_owned) { 62 impl->h_data = impl->h_data_owned; 63 } else { 64 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 65 impl->h_data = impl->h_data_owned; 66 } 67 CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost)); 68 CeedCallBackend(CeedDestroy(&ceed)); 69 return CEED_ERROR_SUCCESS; 70 } 71 72 //------------------------------------------------------------------------------ 73 // Sync data of type 74 //------------------------------------------------------------------------------ 75 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) { 76 switch (mem_type) { 77 case CEED_MEM_HOST: 78 return CeedQFunctionContextSyncD2H_Cuda(ctx); 79 case CEED_MEM_DEVICE: 80 return CeedQFunctionContextSyncH2D_Cuda(ctx); 81 } 82 return CEED_ERROR_UNSUPPORTED; 83 } 84 85 //------------------------------------------------------------------------------ 86 // Set all pointers as invalid 87 //------------------------------------------------------------------------------ 88 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) { 89 CeedQFunctionContext_Cuda *impl; 90 91 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 92 impl->h_data = NULL; 93 impl->d_data = NULL; 94 return CEED_ERROR_SUCCESS; 95 } 96 97 //------------------------------------------------------------------------------ 98 // Check if ctx has valid data 99 //------------------------------------------------------------------------------ 100 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) { 101 CeedQFunctionContext_Cuda *impl; 102 103 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 104 *has_valid_data = impl && (impl->h_data || impl->d_data); 105 return CEED_ERROR_SUCCESS; 106 } 107 108 //------------------------------------------------------------------------------ 109 // Check if ctx has borrowed data 110 //------------------------------------------------------------------------------ 111 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, 112 bool *has_borrowed_data_of_type) { 113 CeedQFunctionContext_Cuda *impl; 114 115 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 116 switch (mem_type) { 117 case CEED_MEM_HOST: 118 *has_borrowed_data_of_type = impl->h_data_borrowed; 119 break; 120 case CEED_MEM_DEVICE: 121 *has_borrowed_data_of_type = impl->d_data_borrowed; 122 break; 123 } 124 return CEED_ERROR_SUCCESS; 125 } 126 127 //------------------------------------------------------------------------------ 128 // Check if data of given type needs sync 129 //------------------------------------------------------------------------------ 130 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 131 bool has_valid_data = true; 132 CeedQFunctionContext_Cuda *impl; 133 134 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 135 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 136 switch (mem_type) { 137 case CEED_MEM_HOST: 138 *need_sync = has_valid_data && !impl->h_data; 139 break; 140 case CEED_MEM_DEVICE: 141 *need_sync = has_valid_data && !impl->d_data; 142 break; 143 } 144 return CEED_ERROR_SUCCESS; 145 } 146 147 //------------------------------------------------------------------------------ 148 // Set data from host 149 //------------------------------------------------------------------------------ 150 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 151 CeedQFunctionContext_Cuda *impl; 152 153 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 154 155 CeedCallBackend(CeedFree(&impl->h_data_owned)); 156 switch (copy_mode) { 157 case CEED_COPY_VALUES: { 158 size_t ctx_size; 159 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 160 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 161 impl->h_data_borrowed = NULL; 162 impl->h_data = impl->h_data_owned; 163 memcpy(impl->h_data, data, ctx_size); 164 } break; 165 case CEED_OWN_POINTER: 166 impl->h_data_owned = data; 167 impl->h_data_borrowed = NULL; 168 impl->h_data = data; 169 break; 170 case CEED_USE_POINTER: 171 impl->h_data_borrowed = data; 172 impl->h_data = data; 173 break; 174 } 175 return CEED_ERROR_SUCCESS; 176 } 177 178 //------------------------------------------------------------------------------ 179 // Set data from device 180 //------------------------------------------------------------------------------ 181 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 182 Ceed ceed; 183 CeedQFunctionContext_Cuda *impl; 184 185 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 186 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 187 188 CeedCallCuda(ceed, cudaFree(impl->d_data_owned)); 189 impl->d_data_owned = NULL; 190 switch (copy_mode) { 191 case CEED_COPY_VALUES: { 192 size_t ctx_size; 193 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 194 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size)); 195 impl->d_data_borrowed = NULL; 196 impl->d_data = impl->d_data_owned; 197 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctx_size, cudaMemcpyDeviceToDevice)); 198 } break; 199 case CEED_OWN_POINTER: 200 impl->d_data_owned = data; 201 impl->d_data_borrowed = NULL; 202 impl->d_data = data; 203 break; 204 case CEED_USE_POINTER: 205 impl->d_data_owned = NULL; 206 impl->d_data_borrowed = data; 207 impl->d_data = data; 208 break; 209 } 210 CeedCallBackend(CeedDestroy(&ceed)); 211 return CEED_ERROR_SUCCESS; 212 } 213 214 //------------------------------------------------------------------------------ 215 // Set the data used by a user context, 216 // freeing any previously allocated data if applicable 217 //------------------------------------------------------------------------------ 218 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 219 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 220 switch (mem_type) { 221 case CEED_MEM_HOST: 222 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data); 223 case CEED_MEM_DEVICE: 224 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data); 225 } 226 return CEED_ERROR_UNSUPPORTED; 227 } 228 229 //------------------------------------------------------------------------------ 230 // Take data 231 //------------------------------------------------------------------------------ 232 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 233 CeedQFunctionContext_Cuda *impl; 234 235 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 236 237 // Sync data to requested mem_type 238 bool need_sync = false; 239 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 240 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 241 242 // Update pointer 243 switch (mem_type) { 244 case CEED_MEM_HOST: 245 *(void **)data = impl->h_data_borrowed; 246 impl->h_data_borrowed = NULL; 247 impl->h_data = NULL; 248 break; 249 case CEED_MEM_DEVICE: 250 *(void **)data = impl->d_data_borrowed; 251 impl->d_data_borrowed = NULL; 252 impl->d_data = NULL; 253 break; 254 } 255 return CEED_ERROR_SUCCESS; 256 } 257 258 //------------------------------------------------------------------------------ 259 // Core logic for GetData. 260 // If a different memory type is most up to date, this will perform a copy 261 //------------------------------------------------------------------------------ 262 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 263 bool need_sync = false; 264 CeedQFunctionContext_Cuda *impl; 265 266 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 267 268 // Sync data to requested mem_type 269 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 270 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 271 272 // Update pointer 273 switch (mem_type) { 274 case CEED_MEM_HOST: 275 *(void **)data = impl->h_data; 276 break; 277 case CEED_MEM_DEVICE: 278 *(void **)data = impl->d_data; 279 break; 280 } 281 return CEED_ERROR_SUCCESS; 282 } 283 284 //------------------------------------------------------------------------------ 285 // Get read-only access to the data 286 //------------------------------------------------------------------------------ 287 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 288 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 289 } 290 291 //------------------------------------------------------------------------------ 292 // Get read/write access to the data 293 //------------------------------------------------------------------------------ 294 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 295 CeedQFunctionContext_Cuda *impl; 296 297 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 298 CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data)); 299 300 // Mark only pointer for requested memory as valid 301 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 302 switch (mem_type) { 303 case CEED_MEM_HOST: 304 impl->h_data = *(void **)data; 305 break; 306 case CEED_MEM_DEVICE: 307 impl->d_data = *(void **)data; 308 break; 309 } 310 return CEED_ERROR_SUCCESS; 311 } 312 313 //------------------------------------------------------------------------------ 314 // Destroy the user context 315 //------------------------------------------------------------------------------ 316 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) { 317 CeedQFunctionContext_Cuda *impl; 318 319 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 320 CeedCallCuda(CeedQFunctionContextReturnCeed(ctx), cudaFree(impl->d_data_owned)); 321 CeedCallBackend(CeedFree(&impl->h_data_owned)); 322 CeedCallBackend(CeedFree(&impl)); 323 return CEED_ERROR_SUCCESS; 324 } 325 326 //------------------------------------------------------------------------------ 327 // QFunctionContext Create 328 //------------------------------------------------------------------------------ 329 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) { 330 CeedQFunctionContext_Cuda *impl; 331 Ceed ceed; 332 333 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 334 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda)); 335 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda)); 336 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda)); 337 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda)); 338 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda)); 339 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda)); 340 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda)); 341 CeedCallBackend(CeedDestroy(&ceed)); 342 CeedCallBackend(CeedCalloc(1, &impl)); 343 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 344 return CEED_ERROR_SUCCESS; 345 } 346 347 //------------------------------------------------------------------------------ 348