1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <cuda_runtime.h> 11 #include <string.h> 12 #include "ceed-cuda-ref.h" 13 14 //------------------------------------------------------------------------------ 15 // Sync host to device 16 //------------------------------------------------------------------------------ 17 static inline int CeedQFunctionContextSyncH2D_Cuda( 18 const CeedQFunctionContext ctx) { 19 int ierr; 20 Ceed ceed; 21 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 22 CeedQFunctionContext_Cuda *impl; 23 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 24 25 if (!impl->h_data) 26 // LCOV_EXCL_START 27 return CeedError(ceed, CEED_ERROR_BACKEND, 28 "No valid host data to sync to device"); 29 // LCOV_EXCL_STOP 30 31 size_t ctxsize; 32 ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr); 33 34 if (impl->d_data_borrowed) { 35 impl->d_data = impl->d_data_borrowed; 36 } else if (impl->d_data_owned) { 37 impl->d_data = impl->d_data_owned; 38 } else { 39 ierr = cudaMalloc((void **)&impl->d_data_owned, ctxsize); 40 CeedChk_Cu(ceed, ierr); 41 impl->d_data = impl->d_data_owned; 42 } 43 44 ierr = cudaMemcpy(impl->d_data, impl->h_data, ctxsize, 45 cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 46 47 return CEED_ERROR_SUCCESS; 48 } 49 50 //------------------------------------------------------------------------------ 51 // Sync device to host 52 //------------------------------------------------------------------------------ 53 static inline int CeedQFunctionContextSyncD2H_Cuda( 54 const CeedQFunctionContext ctx) { 55 int ierr; 56 Ceed ceed; 57 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 58 CeedQFunctionContext_Cuda *impl; 59 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 60 61 if (!impl->d_data) 62 // LCOV_EXCL_START 63 return CeedError(ceed, CEED_ERROR_BACKEND, 64 "No valid device data to sync to host"); 65 // LCOV_EXCL_STOP 66 67 size_t ctxsize; 68 ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr); 69 70 if (impl->h_data_borrowed) { 71 impl->h_data = impl->h_data_borrowed; 72 } else if (impl->h_data_owned) { 73 impl->h_data = impl->h_data_owned; 74 } else { 75 ierr = CeedMalloc(ctxsize, &impl->h_data_owned); 76 CeedChkBackend(ierr); 77 impl->h_data = impl->h_data_owned; 78 } 79 80 ierr = cudaMemcpy(impl->h_data, impl->d_data, ctxsize, 81 cudaMemcpyDeviceToHost); CeedChk_Cu(ceed, ierr); 82 83 return CEED_ERROR_SUCCESS; 84 } 85 86 //------------------------------------------------------------------------------ 87 // Sync data of type 88 //------------------------------------------------------------------------------ 89 static inline int CeedQFunctionContextSync_Cuda( 90 const CeedQFunctionContext ctx, CeedMemType mem_type) { 91 switch (mem_type) { 92 case CEED_MEM_HOST: return CeedQFunctionContextSyncD2H_Cuda(ctx); 93 case CEED_MEM_DEVICE: return CeedQFunctionContextSyncH2D_Cuda(ctx); 94 } 95 return CEED_ERROR_UNSUPPORTED; 96 } 97 98 //------------------------------------------------------------------------------ 99 // Set all pointers as invalid 100 //------------------------------------------------------------------------------ 101 static inline int CeedQFunctionContextSetAllInvalid_Cuda( 102 const CeedQFunctionContext ctx) { 103 int ierr; 104 CeedQFunctionContext_Cuda *impl; 105 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 106 107 impl->h_data = NULL; 108 impl->d_data = NULL; 109 110 return CEED_ERROR_SUCCESS; 111 } 112 113 //------------------------------------------------------------------------------ 114 // Check if ctx has valid data 115 //------------------------------------------------------------------------------ 116 static inline int CeedQFunctionContextHasValidData_Cuda( 117 const CeedQFunctionContext ctx, bool *has_valid_data) { 118 int ierr; 119 CeedQFunctionContext_Cuda *impl; 120 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 121 122 *has_valid_data = !!impl->h_data || !!impl->d_data; 123 124 return CEED_ERROR_SUCCESS; 125 } 126 127 //------------------------------------------------------------------------------ 128 // Check if ctx has borrowed data 129 //------------------------------------------------------------------------------ 130 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda( 131 const CeedQFunctionContext ctx, CeedMemType mem_type, 132 bool *has_borrowed_data_of_type) { 133 int ierr; 134 CeedQFunctionContext_Cuda *impl; 135 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 136 137 switch (mem_type) { 138 case CEED_MEM_HOST: 139 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 140 break; 141 case CEED_MEM_DEVICE: 142 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 143 break; 144 } 145 146 return CEED_ERROR_SUCCESS; 147 } 148 149 //------------------------------------------------------------------------------ 150 // Check if data of given type needs sync 151 //------------------------------------------------------------------------------ 152 static inline int CeedQFunctionContextNeedSync_Cuda( 153 const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 154 int ierr; 155 CeedQFunctionContext_Cuda *impl; 156 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 157 158 bool has_valid_data = true; 159 ierr = CeedQFunctionContextHasValidData(ctx, &has_valid_data); 160 CeedChkBackend(ierr); 161 switch (mem_type) { 162 case CEED_MEM_HOST: 163 *need_sync = has_valid_data && !impl->h_data; 164 break; 165 case CEED_MEM_DEVICE: 166 *need_sync = has_valid_data && !impl->d_data; 167 break; 168 } 169 170 return CEED_ERROR_SUCCESS; 171 } 172 173 //------------------------------------------------------------------------------ 174 // Set data from host 175 //------------------------------------------------------------------------------ 176 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, 177 const CeedCopyMode copy_mode, void *data) { 178 int ierr; 179 CeedQFunctionContext_Cuda *impl; 180 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 181 182 ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr); 183 switch (copy_mode) { 184 case CEED_COPY_VALUES: { 185 size_t ctxsize; 186 ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr); 187 ierr = CeedMalloc(ctxsize, &impl->h_data_owned); CeedChkBackend(ierr); 188 impl->h_data_borrowed = NULL; 189 impl->h_data = impl->h_data_owned; 190 memcpy(impl->h_data, data, ctxsize); 191 } break; 192 case CEED_OWN_POINTER: 193 impl->h_data_owned = data; 194 impl->h_data_borrowed = NULL; 195 impl->h_data = data; 196 break; 197 case CEED_USE_POINTER: 198 impl->h_data_borrowed = data; 199 impl->h_data = data; 200 break; 201 } 202 203 return CEED_ERROR_SUCCESS; 204 } 205 206 //------------------------------------------------------------------------------ 207 // Set data from device 208 //------------------------------------------------------------------------------ 209 static int CeedQFunctionContextSetDataDevice_Cuda( 210 const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 211 int ierr; 212 Ceed ceed; 213 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 214 CeedQFunctionContext_Cuda *impl; 215 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 216 217 ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr); 218 impl->d_data_owned = NULL; 219 switch (copy_mode) { 220 case CEED_COPY_VALUES: { 221 size_t ctxsize; 222 ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr); 223 ierr = cudaMalloc((void **)&impl->d_data_owned, ctxsize); 224 CeedChk_Cu(ceed, ierr); 225 impl->d_data_borrowed = NULL; 226 impl->d_data = impl->d_data_owned; 227 ierr = cudaMemcpy(impl->d_data, data, ctxsize, 228 cudaMemcpyDeviceToDevice); CeedChk_Cu(ceed, ierr); 229 } break; 230 case CEED_OWN_POINTER: 231 impl->d_data_owned = data; 232 impl->d_data_borrowed = NULL; 233 impl->d_data = data; 234 break; 235 case CEED_USE_POINTER: 236 impl->d_data_owned = NULL; 237 impl->d_data_borrowed = data; 238 impl->d_data = data; 239 break; 240 } 241 242 return CEED_ERROR_SUCCESS; 243 } 244 245 //------------------------------------------------------------------------------ 246 // Set the data used by a user context, 247 // freeing any previously allocated data if applicable 248 //------------------------------------------------------------------------------ 249 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, 250 const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 251 int ierr; 252 Ceed ceed; 253 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 254 255 ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr); 256 switch (mem_type) { 257 case CEED_MEM_HOST: 258 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data); 259 case CEED_MEM_DEVICE: 260 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data); 261 } 262 263 return CEED_ERROR_UNSUPPORTED; 264 } 265 266 //------------------------------------------------------------------------------ 267 // Take data 268 //------------------------------------------------------------------------------ 269 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, 270 const CeedMemType mem_type, void *data) { 271 int ierr; 272 Ceed ceed; 273 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 274 CeedQFunctionContext_Cuda *impl; 275 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 276 277 // Sync data to requested mem_type 278 bool need_sync = false; 279 ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync); 280 CeedChkBackend(ierr); 281 if (need_sync) { 282 ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr); 283 } 284 285 // Update pointer 286 switch (mem_type) { 287 case CEED_MEM_HOST: 288 *(void **)data = impl->h_data_borrowed; 289 impl->h_data_borrowed = NULL; 290 impl->h_data = NULL; 291 break; 292 case CEED_MEM_DEVICE: 293 *(void **)data = impl->d_data_borrowed; 294 impl->d_data_borrowed = NULL; 295 impl->d_data = NULL; 296 break; 297 } 298 299 return CEED_ERROR_SUCCESS; 300 } 301 302 //------------------------------------------------------------------------------ 303 // Core logic for GetData. 304 // If a different memory type is most up to date, this will perform a copy 305 //------------------------------------------------------------------------------ 306 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, 307 const CeedMemType mem_type, void *data) { 308 int ierr; 309 Ceed ceed; 310 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 311 CeedQFunctionContext_Cuda *impl; 312 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 313 314 // Sync data to requested mem_type 315 bool need_sync = false; 316 ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync); 317 CeedChkBackend(ierr); 318 if (need_sync) { 319 ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr); 320 } 321 322 // Update pointer 323 switch (mem_type) { 324 case CEED_MEM_HOST: 325 *(void **)data = impl->h_data; 326 break; 327 case CEED_MEM_DEVICE: 328 *(void **)data = impl->d_data; 329 break; 330 } 331 332 return CEED_ERROR_SUCCESS; 333 } 334 335 //------------------------------------------------------------------------------ 336 // Get read-only access to the data 337 //------------------------------------------------------------------------------ 338 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, 339 const CeedMemType mem_type, void *data) { 340 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 341 } 342 343 //------------------------------------------------------------------------------ 344 // Get read/write access to the data 345 //------------------------------------------------------------------------------ 346 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, 347 const CeedMemType mem_type, void *data) { 348 int ierr; 349 CeedQFunctionContext_Cuda *impl; 350 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 351 352 ierr = CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 353 CeedChkBackend(ierr); 354 355 // Mark only pointer for requested memory as valid 356 ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr); 357 switch (mem_type) { 358 case CEED_MEM_HOST: 359 impl->h_data = *(void **)data; 360 break; 361 case CEED_MEM_DEVICE: 362 impl->d_data = *(void **)data; 363 break; 364 } 365 366 return CEED_ERROR_SUCCESS; 367 } 368 369 //------------------------------------------------------------------------------ 370 // Destroy the user context 371 //------------------------------------------------------------------------------ 372 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) { 373 int ierr; 374 Ceed ceed; 375 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 376 CeedQFunctionContext_Cuda *impl; 377 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 378 379 ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr); 380 ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr); 381 ierr = CeedFree(&impl); CeedChkBackend(ierr); 382 383 return CEED_ERROR_SUCCESS; 384 } 385 386 //------------------------------------------------------------------------------ 387 // QFunctionContext Create 388 //------------------------------------------------------------------------------ 389 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) { 390 int ierr; 391 CeedQFunctionContext_Cuda *impl; 392 Ceed ceed; 393 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 394 395 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", 396 CeedQFunctionContextHasValidData_Cuda); 397 CeedChkBackend(ierr); 398 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, 399 "HasBorrowedDataOfType", 400 CeedQFunctionContextHasBorrowedDataOfType_Cuda); 401 CeedChkBackend(ierr); 402 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", 403 CeedQFunctionContextSetData_Cuda); CeedChkBackend(ierr); 404 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", 405 CeedQFunctionContextTakeData_Cuda); CeedChkBackend(ierr); 406 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", 407 CeedQFunctionContextGetData_Cuda); CeedChkBackend(ierr); 408 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", 409 CeedQFunctionContextGetDataRead_Cuda); CeedChkBackend(ierr); 410 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", 411 CeedQFunctionContextDestroy_Cuda); CeedChkBackend(ierr); 412 413 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 414 ierr = CeedQFunctionContextSetBackendData(ctx, impl); CeedChkBackend(ierr); 415 416 return CEED_ERROR_SUCCESS; 417 } 418 //------------------------------------------------------------------------------ 419