1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/ceed.h> 9 #include <ceed/backend.h> 10 #include <cuda_runtime.h> 11 #include <string.h> 12 #include "ceed-cuda-ref.h" 13 14 //------------------------------------------------------------------------------ 15 // * Bytes used 16 //------------------------------------------------------------------------------ 17 static inline size_t bytes(const CeedQFunctionContext ctx) { 18 int ierr; 19 size_t ctxsize; 20 ierr = CeedQFunctionContextGetContextSize(ctx, &ctxsize); CeedChkBackend(ierr); 21 return ctxsize; 22 } 23 24 //------------------------------------------------------------------------------ 25 // Sync host to device 26 //------------------------------------------------------------------------------ 27 static inline int CeedQFunctionContextSyncH2D_Cuda( 28 const CeedQFunctionContext ctx) { 29 int ierr; 30 Ceed ceed; 31 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 32 CeedQFunctionContext_Cuda *impl; 33 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 34 35 if (!impl->h_data) 36 // LCOV_EXCL_START 37 return CeedError(ceed, CEED_ERROR_BACKEND, 38 "No valid host data to sync to device"); 39 // LCOV_EXCL_STOP 40 41 if (impl->d_data_borrowed) { 42 impl->d_data = impl->d_data_borrowed; 43 } else if (impl->d_data_owned) { 44 impl->d_data = impl->d_data_owned; 45 } else { 46 ierr = cudaMalloc((void **)&impl->d_data_owned, bytes(ctx)); 47 CeedChk_Cu(ceed, ierr); 48 impl->d_data = impl->d_data_owned; 49 } 50 51 ierr = cudaMemcpy(impl->d_data, impl->h_data, bytes(ctx), 52 cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 53 54 return CEED_ERROR_SUCCESS; 55 } 56 57 //------------------------------------------------------------------------------ 58 // Sync device to host 59 //------------------------------------------------------------------------------ 60 static inline int CeedQFunctionContextSyncD2H_Cuda( 61 const CeedQFunctionContext ctx) { 62 int ierr; 63 Ceed ceed; 64 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 65 CeedQFunctionContext_Cuda *impl; 66 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 67 68 if (!impl->d_data) 69 // LCOV_EXCL_START 70 return CeedError(ceed, CEED_ERROR_BACKEND, 71 "No valid device data to sync to host"); 72 // LCOV_EXCL_STOP 73 74 if (impl->h_data_borrowed) { 75 impl->h_data = impl->h_data_borrowed; 76 } else if (impl->h_data_owned) { 77 impl->h_data = impl->h_data_owned; 78 } else { 79 ierr = CeedMalloc(bytes(ctx), &impl->h_data_owned); 80 CeedChkBackend(ierr); 81 impl->h_data = impl->h_data_owned; 82 } 83 84 ierr = cudaMemcpy(impl->h_data, impl->d_data, bytes(ctx), 85 cudaMemcpyDeviceToHost); CeedChk_Cu(ceed, ierr); 86 87 return CEED_ERROR_SUCCESS; 88 } 89 90 //------------------------------------------------------------------------------ 91 // Sync data of type 92 //------------------------------------------------------------------------------ 93 static inline int CeedQFunctionContextSync_Cuda( 94 const CeedQFunctionContext ctx, CeedMemType mem_type) { 95 switch (mem_type) { 96 case CEED_MEM_HOST: return CeedQFunctionContextSyncD2H_Cuda(ctx); 97 case CEED_MEM_DEVICE: return CeedQFunctionContextSyncH2D_Cuda(ctx); 98 } 99 return CEED_ERROR_UNSUPPORTED; 100 } 101 102 //------------------------------------------------------------------------------ 103 // Set all pointers as invalid 104 //------------------------------------------------------------------------------ 105 static inline int CeedQFunctionContextSetAllInvalid_Cuda( 106 const CeedQFunctionContext ctx) { 107 int ierr; 108 CeedQFunctionContext_Cuda *impl; 109 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 110 111 impl->h_data = NULL; 112 impl->d_data = NULL; 113 114 return CEED_ERROR_SUCCESS; 115 } 116 117 //------------------------------------------------------------------------------ 118 // Check if ctx has valid data 119 //------------------------------------------------------------------------------ 120 static inline int CeedQFunctionContextHasValidData_Cuda( 121 const CeedQFunctionContext ctx, bool *has_valid_data) { 122 int ierr; 123 CeedQFunctionContext_Cuda *impl; 124 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 125 126 *has_valid_data = !!impl->h_data || !!impl->d_data; 127 128 return CEED_ERROR_SUCCESS; 129 } 130 131 //------------------------------------------------------------------------------ 132 // Check if ctx has borrowed data 133 //------------------------------------------------------------------------------ 134 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda( 135 const CeedQFunctionContext ctx, CeedMemType mem_type, 136 bool *has_borrowed_data_of_type) { 137 int ierr; 138 CeedQFunctionContext_Cuda *impl; 139 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 140 141 switch (mem_type) { 142 case CEED_MEM_HOST: 143 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 144 break; 145 case CEED_MEM_DEVICE: 146 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 147 break; 148 } 149 150 return CEED_ERROR_SUCCESS; 151 } 152 153 //------------------------------------------------------------------------------ 154 // Check if data of given type needs sync 155 //------------------------------------------------------------------------------ 156 static inline int CeedQFunctionContextNeedSync_Cuda( 157 const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 158 int ierr; 159 CeedQFunctionContext_Cuda *impl; 160 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 161 162 bool has_valid_data = true; 163 ierr = CeedQFunctionContextHasValidData(ctx, &has_valid_data); 164 CeedChkBackend(ierr); 165 switch (mem_type) { 166 case CEED_MEM_HOST: 167 *need_sync = has_valid_data && !impl->h_data; 168 break; 169 case CEED_MEM_DEVICE: 170 *need_sync = has_valid_data && !impl->d_data; 171 break; 172 } 173 174 return CEED_ERROR_SUCCESS; 175 } 176 177 //------------------------------------------------------------------------------ 178 // Set data from host 179 //------------------------------------------------------------------------------ 180 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, 181 const CeedCopyMode copy_mode, void *data) { 182 int ierr; 183 CeedQFunctionContext_Cuda *impl; 184 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 185 186 ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr); 187 switch (copy_mode) { 188 case CEED_COPY_VALUES: { 189 ierr = CeedMalloc(bytes(ctx), &impl->h_data_owned); CeedChkBackend(ierr); 190 impl->h_data_borrowed = NULL; 191 impl->h_data = impl->h_data_owned; 192 memcpy(impl->h_data, data, bytes(ctx)); 193 } break; 194 case CEED_OWN_POINTER: 195 impl->h_data_owned = data; 196 impl->h_data_borrowed = NULL; 197 impl->h_data = data; 198 break; 199 case CEED_USE_POINTER: 200 impl->h_data_borrowed = data; 201 impl->h_data = data; 202 break; 203 } 204 205 return CEED_ERROR_SUCCESS; 206 } 207 208 //------------------------------------------------------------------------------ 209 // Set data from device 210 //------------------------------------------------------------------------------ 211 static int CeedQFunctionContextSetDataDevice_Cuda( 212 const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 213 int ierr; 214 Ceed ceed; 215 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 216 CeedQFunctionContext_Cuda *impl; 217 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 218 219 ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr); 220 impl->d_data_owned = NULL; 221 switch (copy_mode) { 222 case CEED_COPY_VALUES: 223 ierr = cudaMalloc((void **)&impl->d_data_owned, bytes(ctx)); 224 CeedChk_Cu(ceed, ierr); 225 impl->d_data_borrowed = NULL; 226 impl->d_data = impl->d_data_owned; 227 ierr = cudaMemcpy(impl->d_data, data, bytes(ctx), 228 cudaMemcpyDeviceToDevice); CeedChk_Cu(ceed, ierr); 229 break; 230 case CEED_OWN_POINTER: 231 impl->d_data_owned = data; 232 impl->d_data_borrowed = NULL; 233 impl->d_data = data; 234 break; 235 case CEED_USE_POINTER: 236 impl->d_data_owned = NULL; 237 impl->d_data_borrowed = data; 238 impl->d_data = data; 239 break; 240 } 241 242 return CEED_ERROR_SUCCESS; 243 } 244 245 //------------------------------------------------------------------------------ 246 // Set the data used by a user context, 247 // freeing any previously allocated data if applicable 248 //------------------------------------------------------------------------------ 249 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, 250 const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 251 int ierr; 252 Ceed ceed; 253 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 254 255 ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr); 256 switch (mem_type) { 257 case CEED_MEM_HOST: 258 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data); 259 case CEED_MEM_DEVICE: 260 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data); 261 } 262 263 return CEED_ERROR_UNSUPPORTED; 264 } 265 266 //------------------------------------------------------------------------------ 267 // Take data 268 //------------------------------------------------------------------------------ 269 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, 270 const CeedMemType mem_type, void *data) { 271 int ierr; 272 Ceed ceed; 273 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 274 CeedQFunctionContext_Cuda *impl; 275 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 276 277 // Sync data to requested mem_type 278 bool need_sync = false; 279 ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync); 280 CeedChkBackend(ierr); 281 if (need_sync) { 282 ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr); 283 } 284 285 // Update pointer 286 switch (mem_type) { 287 case CEED_MEM_HOST: 288 *(void **)data = impl->h_data_borrowed; 289 impl->h_data_borrowed = NULL; 290 impl->h_data = NULL; 291 break; 292 case CEED_MEM_DEVICE: 293 *(void **)data = impl->d_data_borrowed; 294 impl->d_data_borrowed = NULL; 295 impl->d_data = NULL; 296 break; 297 } 298 299 return CEED_ERROR_SUCCESS; 300 } 301 302 //------------------------------------------------------------------------------ 303 // Core logic for GetData. 304 // If a different memory type is most up to date, this will perform a copy 305 //------------------------------------------------------------------------------ 306 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, 307 const CeedMemType mem_type, void *data) { 308 int ierr; 309 Ceed ceed; 310 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 311 CeedQFunctionContext_Cuda *impl; 312 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 313 314 // Sync data to requested mem_type 315 bool need_sync = false; 316 ierr = CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync); 317 CeedChkBackend(ierr); 318 if (need_sync) { 319 ierr = CeedQFunctionContextSync_Cuda(ctx, mem_type); CeedChkBackend(ierr); 320 } 321 322 // Update pointer 323 switch (mem_type) { 324 case CEED_MEM_HOST: 325 *(void **)data = impl->h_data; 326 break; 327 case CEED_MEM_DEVICE: 328 *(void **)data = impl->d_data; 329 break; 330 } 331 332 return CEED_ERROR_SUCCESS; 333 } 334 335 //------------------------------------------------------------------------------ 336 // Get read-only access to the data 337 //------------------------------------------------------------------------------ 338 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, 339 const CeedMemType mem_type, void *data) { 340 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 341 } 342 343 //------------------------------------------------------------------------------ 344 // Get read/write access to the data 345 //------------------------------------------------------------------------------ 346 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, 347 const CeedMemType mem_type, void *data) { 348 int ierr; 349 CeedQFunctionContext_Cuda *impl; 350 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 351 352 ierr = CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 353 CeedChkBackend(ierr); 354 355 // Mark only pointer for requested memory as valid 356 ierr = CeedQFunctionContextSetAllInvalid_Cuda(ctx); CeedChkBackend(ierr); 357 switch (mem_type) { 358 case CEED_MEM_HOST: 359 impl->h_data = *(void **)data; 360 break; 361 case CEED_MEM_DEVICE: 362 impl->d_data = *(void **)data; 363 break; 364 } 365 366 return CEED_ERROR_SUCCESS; 367 } 368 369 //------------------------------------------------------------------------------ 370 // Destroy the user context 371 //------------------------------------------------------------------------------ 372 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) { 373 int ierr; 374 Ceed ceed; 375 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 376 CeedQFunctionContext_Cuda *impl; 377 ierr = CeedQFunctionContextGetBackendData(ctx, &impl); CeedChkBackend(ierr); 378 379 ierr = cudaFree(impl->d_data_owned); CeedChk_Cu(ceed, ierr); 380 ierr = CeedFree(&impl->h_data_owned); CeedChkBackend(ierr); 381 ierr = CeedFree(&impl); CeedChkBackend(ierr); 382 383 return CEED_ERROR_SUCCESS; 384 } 385 386 //------------------------------------------------------------------------------ 387 // QFunctionContext Create 388 //------------------------------------------------------------------------------ 389 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) { 390 int ierr; 391 CeedQFunctionContext_Cuda *impl; 392 Ceed ceed; 393 ierr = CeedQFunctionContextGetCeed(ctx, &ceed); CeedChkBackend(ierr); 394 395 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", 396 CeedQFunctionContextHasValidData_Cuda); 397 CeedChkBackend(ierr); 398 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, 399 "HasBorrowedDataOfType", 400 CeedQFunctionContextHasBorrowedDataOfType_Cuda); 401 CeedChkBackend(ierr); 402 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", 403 CeedQFunctionContextSetData_Cuda); CeedChkBackend(ierr); 404 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", 405 CeedQFunctionContextTakeData_Cuda); CeedChkBackend(ierr); 406 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", 407 CeedQFunctionContextGetData_Cuda); CeedChkBackend(ierr); 408 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", 409 CeedQFunctionContextGetDataRead_Cuda); CeedChkBackend(ierr); 410 ierr = CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", 411 CeedQFunctionContextDestroy_Cuda); CeedChkBackend(ierr); 412 413 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 414 ierr = CeedQFunctionContextSetBackendData(ctx, impl); CeedChkBackend(ierr); 415 416 return CEED_ERROR_SUCCESS; 417 } 418 //------------------------------------------------------------------------------ 419