1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 #include <cuda_runtime.h> 11 #include <string.h> 12 13 #include "ceed-cuda-ref.h" 14 15 //------------------------------------------------------------------------------ 16 // Sync host to device 17 //------------------------------------------------------------------------------ 18 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) { 19 Ceed ceed; 20 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 21 CeedQFunctionContext_Cuda *impl; 22 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 23 24 if (!impl->h_data) { 25 // LCOV_EXCL_START 26 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 27 // LCOV_EXCL_STOP 28 } 29 30 size_t ctxsize; 31 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 32 33 if (impl->d_data_borrowed) { 34 impl->d_data = impl->d_data_borrowed; 35 } else if (impl->d_data_owned) { 36 impl->d_data = impl->d_data_owned; 37 } else { 38 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize)); 39 impl->d_data = impl->d_data_owned; 40 } 41 42 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctxsize, cudaMemcpyHostToDevice)); 43 44 return CEED_ERROR_SUCCESS; 45 } 46 47 //------------------------------------------------------------------------------ 48 // Sync device to host 49 //------------------------------------------------------------------------------ 50 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) { 51 Ceed ceed; 52 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 53 CeedQFunctionContext_Cuda *impl; 54 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 55 56 if (!impl->d_data) { 57 // LCOV_EXCL_START 58 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 59 // LCOV_EXCL_STOP 60 } 61 62 size_t ctxsize; 63 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 64 65 if (impl->h_data_borrowed) { 66 impl->h_data = impl->h_data_borrowed; 67 } else if (impl->h_data_owned) { 68 impl->h_data = impl->h_data_owned; 69 } else { 70 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 71 impl->h_data = impl->h_data_owned; 72 } 73 74 CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctxsize, cudaMemcpyDeviceToHost)); 75 76 return CEED_ERROR_SUCCESS; 77 } 78 79 //------------------------------------------------------------------------------ 80 // Sync data of type 81 //------------------------------------------------------------------------------ 82 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) { 83 switch (mem_type) { 84 case CEED_MEM_HOST: 85 return CeedQFunctionContextSyncD2H_Cuda(ctx); 86 case CEED_MEM_DEVICE: 87 return CeedQFunctionContextSyncH2D_Cuda(ctx); 88 } 89 return CEED_ERROR_UNSUPPORTED; 90 } 91 92 //------------------------------------------------------------------------------ 93 // Set all pointers as invalid 94 //------------------------------------------------------------------------------ 95 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) { 96 CeedQFunctionContext_Cuda *impl; 97 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 98 99 impl->h_data = NULL; 100 impl->d_data = NULL; 101 102 return CEED_ERROR_SUCCESS; 103 } 104 105 //------------------------------------------------------------------------------ 106 // Check if ctx has valid data 107 //------------------------------------------------------------------------------ 108 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) { 109 CeedQFunctionContext_Cuda *impl; 110 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 111 112 *has_valid_data = impl && (!!impl->h_data || !!impl->d_data); 113 114 return CEED_ERROR_SUCCESS; 115 } 116 117 //------------------------------------------------------------------------------ 118 // Check if ctx has borrowed data 119 //------------------------------------------------------------------------------ 120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, 121 bool *has_borrowed_data_of_type) { 122 CeedQFunctionContext_Cuda *impl; 123 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 124 125 switch (mem_type) { 126 case CEED_MEM_HOST: 127 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 128 break; 129 case CEED_MEM_DEVICE: 130 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 131 break; 132 } 133 134 return CEED_ERROR_SUCCESS; 135 } 136 137 //------------------------------------------------------------------------------ 138 // Check if data of given type needs sync 139 //------------------------------------------------------------------------------ 140 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 141 CeedQFunctionContext_Cuda *impl; 142 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 143 144 bool has_valid_data = true; 145 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 146 switch (mem_type) { 147 case CEED_MEM_HOST: 148 *need_sync = has_valid_data && !impl->h_data; 149 break; 150 case CEED_MEM_DEVICE: 151 *need_sync = has_valid_data && !impl->d_data; 152 break; 153 } 154 155 return CEED_ERROR_SUCCESS; 156 } 157 158 //------------------------------------------------------------------------------ 159 // Set data from host 160 //------------------------------------------------------------------------------ 161 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 162 CeedQFunctionContext_Cuda *impl; 163 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 164 165 CeedCallBackend(CeedFree(&impl->h_data_owned)); 166 switch (copy_mode) { 167 case CEED_COPY_VALUES: { 168 size_t ctxsize; 169 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 170 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 171 impl->h_data_borrowed = NULL; 172 impl->h_data = impl->h_data_owned; 173 memcpy(impl->h_data, data, ctxsize); 174 } break; 175 case CEED_OWN_POINTER: 176 impl->h_data_owned = data; 177 impl->h_data_borrowed = NULL; 178 impl->h_data = data; 179 break; 180 case CEED_USE_POINTER: 181 impl->h_data_borrowed = data; 182 impl->h_data = data; 183 break; 184 } 185 186 return CEED_ERROR_SUCCESS; 187 } 188 189 //------------------------------------------------------------------------------ 190 // Set data from device 191 //------------------------------------------------------------------------------ 192 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 193 Ceed ceed; 194 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 195 CeedQFunctionContext_Cuda *impl; 196 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 197 198 CeedCallCuda(ceed, cudaFree(impl->d_data_owned)); 199 impl->d_data_owned = NULL; 200 switch (copy_mode) { 201 case CEED_COPY_VALUES: { 202 size_t ctxsize; 203 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 204 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctxsize)); 205 impl->d_data_borrowed = NULL; 206 impl->d_data = impl->d_data_owned; 207 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctxsize, cudaMemcpyDeviceToDevice)); 208 } break; 209 case CEED_OWN_POINTER: 210 impl->d_data_owned = data; 211 impl->d_data_borrowed = NULL; 212 impl->d_data = data; 213 break; 214 case CEED_USE_POINTER: 215 impl->d_data_owned = NULL; 216 impl->d_data_borrowed = data; 217 impl->d_data = data; 218 break; 219 } 220 221 return CEED_ERROR_SUCCESS; 222 } 223 224 //------------------------------------------------------------------------------ 225 // Set the data used by a user context, 226 // freeing any previously allocated data if applicable 227 //------------------------------------------------------------------------------ 228 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 229 Ceed ceed; 230 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 231 232 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 233 switch (mem_type) { 234 case CEED_MEM_HOST: 235 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data); 236 case CEED_MEM_DEVICE: 237 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data); 238 } 239 240 return CEED_ERROR_UNSUPPORTED; 241 } 242 243 //------------------------------------------------------------------------------ 244 // Take data 245 //------------------------------------------------------------------------------ 246 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 247 Ceed ceed; 248 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 249 CeedQFunctionContext_Cuda *impl; 250 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 251 252 // Sync data to requested mem_type 253 bool need_sync = false; 254 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 255 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 256 257 // Update pointer 258 switch (mem_type) { 259 case CEED_MEM_HOST: 260 *(void **)data = impl->h_data_borrowed; 261 impl->h_data_borrowed = NULL; 262 impl->h_data = NULL; 263 break; 264 case CEED_MEM_DEVICE: 265 *(void **)data = impl->d_data_borrowed; 266 impl->d_data_borrowed = NULL; 267 impl->d_data = NULL; 268 break; 269 } 270 271 return CEED_ERROR_SUCCESS; 272 } 273 274 //------------------------------------------------------------------------------ 275 // Core logic for GetData. 276 // If a different memory type is most up to date, this will perform a copy 277 //------------------------------------------------------------------------------ 278 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 279 Ceed ceed; 280 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 281 CeedQFunctionContext_Cuda *impl; 282 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 283 284 // Sync data to requested mem_type 285 bool need_sync = false; 286 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 287 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 288 289 // Update pointer 290 switch (mem_type) { 291 case CEED_MEM_HOST: 292 *(void **)data = impl->h_data; 293 break; 294 case CEED_MEM_DEVICE: 295 *(void **)data = impl->d_data; 296 break; 297 } 298 299 return CEED_ERROR_SUCCESS; 300 } 301 302 //------------------------------------------------------------------------------ 303 // Get read-only access to the data 304 //------------------------------------------------------------------------------ 305 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 306 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 307 } 308 309 //------------------------------------------------------------------------------ 310 // Get read/write access to the data 311 //------------------------------------------------------------------------------ 312 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 313 CeedQFunctionContext_Cuda *impl; 314 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 315 316 CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data)); 317 318 // Mark only pointer for requested memory as valid 319 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 320 switch (mem_type) { 321 case CEED_MEM_HOST: 322 impl->h_data = *(void **)data; 323 break; 324 case CEED_MEM_DEVICE: 325 impl->d_data = *(void **)data; 326 break; 327 } 328 329 return CEED_ERROR_SUCCESS; 330 } 331 332 //------------------------------------------------------------------------------ 333 // Destroy the user context 334 //------------------------------------------------------------------------------ 335 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) { 336 Ceed ceed; 337 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 338 CeedQFunctionContext_Cuda *impl; 339 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 340 341 CeedCallCuda(ceed, cudaFree(impl->d_data_owned)); 342 CeedCallBackend(CeedFree(&impl->h_data_owned)); 343 CeedCallBackend(CeedFree(&impl)); 344 345 return CEED_ERROR_SUCCESS; 346 } 347 348 //------------------------------------------------------------------------------ 349 // QFunctionContext Create 350 //------------------------------------------------------------------------------ 351 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) { 352 CeedQFunctionContext_Cuda *impl; 353 Ceed ceed; 354 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 355 356 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda)); 357 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda)); 358 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda)); 359 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda)); 360 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda)); 361 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda)); 362 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda)); 363 364 CeedCallBackend(CeedCalloc(1, &impl)); 365 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 366 367 return CEED_ERROR_SUCCESS; 368 } 369 //------------------------------------------------------------------------------ 370