1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <cuda_runtime.h> 11 #include <stdbool.h> 12 #include <string.h> 13 14 #include "../cuda/ceed-cuda-common.h" 15 #include "ceed-cuda-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 23 CeedQFunctionContext_Cuda *impl; 24 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 25 26 if (!impl->h_data) { 27 // LCOV_EXCL_START 28 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 // LCOV_EXCL_STOP 30 } 31 32 size_t ctxsize; 33 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 34 35 if (impl->d_data_borrowed) { 36 impl->d_data = impl->d_data_borrowed; 37 } else if (impl->d_data_owned) { 38 impl->d_data = impl->d_data_owned; 39 } else { 40 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize)); 41 impl->d_data = impl->d_data_owned; 42 } 43 44 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctxsize, cudaMemcpyHostToDevice)); 45 46 return CEED_ERROR_SUCCESS; 47 } 48 49 //------------------------------------------------------------------------------ 50 // Sync device to host 51 //------------------------------------------------------------------------------ 52 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) { 53 Ceed ceed; 54 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 55 CeedQFunctionContext_Cuda *impl; 56 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 57 58 if (!impl->d_data) { 59 // LCOV_EXCL_START 60 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 61 // LCOV_EXCL_STOP 62 } 63 64 size_t ctxsize; 65 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 66 67 if (impl->h_data_borrowed) { 68 impl->h_data = impl->h_data_borrowed; 69 } else if (impl->h_data_owned) { 70 impl->h_data = impl->h_data_owned; 71 } else { 72 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 73 impl->h_data = impl->h_data_owned; 74 } 75 76 CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctxsize, cudaMemcpyDeviceToHost)); 77 78 return CEED_ERROR_SUCCESS; 79 } 80 81 //------------------------------------------------------------------------------ 82 // Sync data of type 83 //------------------------------------------------------------------------------ 84 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) { 85 switch (mem_type) { 86 case CEED_MEM_HOST: 87 return CeedQFunctionContextSyncD2H_Cuda(ctx); 88 case CEED_MEM_DEVICE: 89 return CeedQFunctionContextSyncH2D_Cuda(ctx); 90 } 91 return CEED_ERROR_UNSUPPORTED; 92 } 93 94 //------------------------------------------------------------------------------ 95 // Set all pointers as invalid 96 //------------------------------------------------------------------------------ 97 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) { 98 CeedQFunctionContext_Cuda *impl; 99 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 100 101 impl->h_data = NULL; 102 impl->d_data = NULL; 103 104 return CEED_ERROR_SUCCESS; 105 } 106 107 //------------------------------------------------------------------------------ 108 // Check if ctx has valid data 109 //------------------------------------------------------------------------------ 110 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) { 111 CeedQFunctionContext_Cuda *impl; 112 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 113 114 *has_valid_data = impl && (!!impl->h_data || !!impl->d_data); 115 116 return CEED_ERROR_SUCCESS; 117 } 118 119 //------------------------------------------------------------------------------ 120 // Check if ctx has borrowed data 121 //------------------------------------------------------------------------------ 122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, 123 bool *has_borrowed_data_of_type) { 124 CeedQFunctionContext_Cuda *impl; 125 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 126 127 switch (mem_type) { 128 case CEED_MEM_HOST: 129 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 130 break; 131 case CEED_MEM_DEVICE: 132 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 133 break; 134 } 135 136 return CEED_ERROR_SUCCESS; 137 } 138 139 //------------------------------------------------------------------------------ 140 // Check if data of given type needs sync 141 //------------------------------------------------------------------------------ 142 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 143 CeedQFunctionContext_Cuda *impl; 144 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 145 146 bool has_valid_data = true; 147 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 148 switch (mem_type) { 149 case CEED_MEM_HOST: 150 *need_sync = has_valid_data && !impl->h_data; 151 break; 152 case CEED_MEM_DEVICE: 153 *need_sync = has_valid_data && !impl->d_data; 154 break; 155 } 156 157 return CEED_ERROR_SUCCESS; 158 } 159 160 //------------------------------------------------------------------------------ 161 // Set data from host 162 //------------------------------------------------------------------------------ 163 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 164 CeedQFunctionContext_Cuda *impl; 165 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 166 167 CeedCallBackend(CeedFree(&impl->h_data_owned)); 168 switch (copy_mode) { 169 case CEED_COPY_VALUES: { 170 size_t ctxsize; 171 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 172 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 173 impl->h_data_borrowed = NULL; 174 impl->h_data = impl->h_data_owned; 175 memcpy(impl->h_data, data, ctxsize); 176 } break; 177 case CEED_OWN_POINTER: 178 impl->h_data_owned = data; 179 impl->h_data_borrowed = NULL; 180 impl->h_data = data; 181 break; 182 case CEED_USE_POINTER: 183 impl->h_data_borrowed = data; 184 impl->h_data = data; 185 break; 186 } 187 188 return CEED_ERROR_SUCCESS; 189 } 190 191 //------------------------------------------------------------------------------ 192 // Set data from device 193 //------------------------------------------------------------------------------ 194 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 195 Ceed ceed; 196 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 197 CeedQFunctionContext_Cuda *impl; 198 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 199 200 CeedCallCuda(ceed, cudaFree(impl->d_data_owned)); 201 impl->d_data_owned = NULL; 202 switch (copy_mode) { 203 case CEED_COPY_VALUES: { 204 size_t ctxsize; 205 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 206 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize)); 207 impl->d_data_borrowed = NULL; 208 impl->d_data = impl->d_data_owned; 209 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctxsize, cudaMemcpyDeviceToDevice)); 210 } break; 211 case CEED_OWN_POINTER: 212 impl->d_data_owned = data; 213 impl->d_data_borrowed = NULL; 214 impl->d_data = data; 215 break; 216 case CEED_USE_POINTER: 217 impl->d_data_owned = NULL; 218 impl->d_data_borrowed = data; 219 impl->d_data = data; 220 break; 221 } 222 223 return CEED_ERROR_SUCCESS; 224 } 225 226 //------------------------------------------------------------------------------ 227 // Set the data used by a user context, 228 // freeing any previously allocated data if applicable 229 //------------------------------------------------------------------------------ 230 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 231 Ceed ceed; 232 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 233 234 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 235 switch (mem_type) { 236 case CEED_MEM_HOST: 237 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data); 238 case CEED_MEM_DEVICE: 239 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data); 240 } 241 242 return CEED_ERROR_UNSUPPORTED; 243 } 244 245 //------------------------------------------------------------------------------ 246 // Take data 247 //------------------------------------------------------------------------------ 248 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 249 Ceed ceed; 250 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 251 CeedQFunctionContext_Cuda *impl; 252 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 253 254 // Sync data to requested mem_type 255 bool need_sync = false; 256 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 257 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 258 259 // Update pointer 260 switch (mem_type) { 261 case CEED_MEM_HOST: 262 *(void **)data = impl->h_data_borrowed; 263 impl->h_data_borrowed = NULL; 264 impl->h_data = NULL; 265 break; 266 case CEED_MEM_DEVICE: 267 *(void **)data = impl->d_data_borrowed; 268 impl->d_data_borrowed = NULL; 269 impl->d_data = NULL; 270 break; 271 } 272 273 return CEED_ERROR_SUCCESS; 274 } 275 276 //------------------------------------------------------------------------------ 277 // Core logic for GetData. 278 // If a different memory type is most up to date, this will perform a copy 279 //------------------------------------------------------------------------------ 280 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 281 Ceed ceed; 282 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 283 CeedQFunctionContext_Cuda *impl; 284 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 285 286 // Sync data to requested mem_type 287 bool need_sync = false; 288 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 289 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 290 291 // Update pointer 292 switch (mem_type) { 293 case CEED_MEM_HOST: 294 *(void **)data = impl->h_data; 295 break; 296 case CEED_MEM_DEVICE: 297 *(void **)data = impl->d_data; 298 break; 299 } 300 301 return CEED_ERROR_SUCCESS; 302 } 303 304 //------------------------------------------------------------------------------ 305 // Get read-only access to the data 306 //------------------------------------------------------------------------------ 307 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 308 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 309 } 310 311 //------------------------------------------------------------------------------ 312 // Get read/write access to the data 313 //------------------------------------------------------------------------------ 314 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 315 CeedQFunctionContext_Cuda *impl; 316 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 317 318 CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data)); 319 320 // Mark only pointer for requested memory as valid 321 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 322 switch (mem_type) { 323 case CEED_MEM_HOST: 324 impl->h_data = *(void **)data; 325 break; 326 case CEED_MEM_DEVICE: 327 impl->d_data = *(void **)data; 328 break; 329 } 330 331 return CEED_ERROR_SUCCESS; 332 } 333 334 //------------------------------------------------------------------------------ 335 // Destroy the user context 336 //------------------------------------------------------------------------------ 337 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) { 338 Ceed ceed; 339 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 340 CeedQFunctionContext_Cuda *impl; 341 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 342 343 CeedCallCuda(ceed, cudaFree(impl->d_data_owned)); 344 CeedCallBackend(CeedFree(&impl->h_data_owned)); 345 CeedCallBackend(CeedFree(&impl)); 346 347 return CEED_ERROR_SUCCESS; 348 } 349 350 //------------------------------------------------------------------------------ 351 // QFunctionContext Create 352 //------------------------------------------------------------------------------ 353 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) { 354 CeedQFunctionContext_Cuda *impl; 355 Ceed ceed; 356 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 357 358 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda)); 359 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda)); 360 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda)); 361 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda)); 362 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda)); 363 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda)); 364 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda)); 365 366 CeedCallBackend(CeedCalloc(1, &impl)); 367 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 368 369 return CEED_ERROR_SUCCESS; 370 } 371 //------------------------------------------------------------------------------ 372