1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <cuda_runtime.h> 11 #include <stdbool.h> 12 #include <string.h> 13 14 #include "../cuda/ceed-cuda-common.h" 15 #include "ceed-cuda-ref.h" 16 17 //------------------------------------------------------------------------------ 18 // Sync host to device 19 //------------------------------------------------------------------------------ 20 static inline int CeedQFunctionContextSyncH2D_Cuda(const CeedQFunctionContext ctx) { 21 Ceed ceed; 22 size_t ctx_size; 23 CeedQFunctionContext_Cuda *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 26 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 27 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 if (impl->d_data_borrowed) { 32 impl->d_data = impl->d_data_borrowed; 33 } else if (impl->d_data_owned) { 34 impl->d_data = impl->d_data_owned; 35 } else { 36 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size)); 37 impl->d_data = impl->d_data_owned; 38 } 39 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, impl->h_data, ctx_size, cudaMemcpyHostToDevice)); 40 return CEED_ERROR_SUCCESS; 41 } 42 43 //------------------------------------------------------------------------------ 44 // Sync device to host 45 //------------------------------------------------------------------------------ 46 static inline int CeedQFunctionContextSyncD2H_Cuda(const CeedQFunctionContext ctx) { 47 Ceed ceed; 48 size_t ctx_size; 49 CeedQFunctionContext_Cuda *impl; 50 51 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 52 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 53 54 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 55 56 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 57 58 if (impl->h_data_borrowed) { 59 impl->h_data = impl->h_data_borrowed; 60 } else if (impl->h_data_owned) { 61 impl->h_data = impl->h_data_owned; 62 } else { 63 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 64 impl->h_data = impl->h_data_owned; 65 } 66 CeedCallCuda(ceed, cudaMemcpy(impl->h_data, impl->d_data, ctx_size, cudaMemcpyDeviceToHost)); 67 return CEED_ERROR_SUCCESS; 68 } 69 70 //------------------------------------------------------------------------------ 71 // Sync data of type 72 //------------------------------------------------------------------------------ 73 static inline int CeedQFunctionContextSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type) { 74 switch (mem_type) { 75 case CEED_MEM_HOST: 76 return CeedQFunctionContextSyncD2H_Cuda(ctx); 77 case CEED_MEM_DEVICE: 78 return CeedQFunctionContextSyncH2D_Cuda(ctx); 79 } 80 return CEED_ERROR_UNSUPPORTED; 81 } 82 83 //------------------------------------------------------------------------------ 84 // Set all pointers as invalid 85 //------------------------------------------------------------------------------ 86 static inline int CeedQFunctionContextSetAllInvalid_Cuda(const CeedQFunctionContext ctx) { 87 CeedQFunctionContext_Cuda *impl; 88 89 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 90 impl->h_data = NULL; 91 impl->d_data = NULL; 92 return CEED_ERROR_SUCCESS; 93 } 94 95 //------------------------------------------------------------------------------ 96 // Check if ctx has valid data 97 //------------------------------------------------------------------------------ 98 static inline int CeedQFunctionContextHasValidData_Cuda(const CeedQFunctionContext ctx, bool *has_valid_data) { 99 CeedQFunctionContext_Cuda *impl; 100 101 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 102 *has_valid_data = impl && (impl->h_data || impl->d_data); 103 return CEED_ERROR_SUCCESS; 104 } 105 106 //------------------------------------------------------------------------------ 107 // Check if ctx has borrowed data 108 //------------------------------------------------------------------------------ 109 static inline int CeedQFunctionContextHasBorrowedDataOfType_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, 110 bool *has_borrowed_data_of_type) { 111 CeedQFunctionContext_Cuda *impl; 112 113 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 114 switch (mem_type) { 115 case CEED_MEM_HOST: 116 *has_borrowed_data_of_type = impl->h_data_borrowed; 117 break; 118 case CEED_MEM_DEVICE: 119 *has_borrowed_data_of_type = impl->d_data_borrowed; 120 break; 121 } 122 return CEED_ERROR_SUCCESS; 123 } 124 125 //------------------------------------------------------------------------------ 126 // Check if data of given type needs sync 127 //------------------------------------------------------------------------------ 128 static inline int CeedQFunctionContextNeedSync_Cuda(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 129 bool has_valid_data = true; 130 CeedQFunctionContext_Cuda *impl; 131 132 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 133 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 134 switch (mem_type) { 135 case CEED_MEM_HOST: 136 *need_sync = has_valid_data && !impl->h_data; 137 break; 138 case CEED_MEM_DEVICE: 139 *need_sync = has_valid_data && !impl->d_data; 140 break; 141 } 142 return CEED_ERROR_SUCCESS; 143 } 144 145 //------------------------------------------------------------------------------ 146 // Set data from host 147 //------------------------------------------------------------------------------ 148 static int CeedQFunctionContextSetDataHost_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 149 CeedQFunctionContext_Cuda *impl; 150 151 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 152 153 CeedCallBackend(CeedFree(&impl->h_data_owned)); 154 switch (copy_mode) { 155 case CEED_COPY_VALUES: { 156 size_t ctx_size; 157 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 158 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 159 impl->h_data_borrowed = NULL; 160 impl->h_data = impl->h_data_owned; 161 memcpy(impl->h_data, data, ctx_size); 162 } break; 163 case CEED_OWN_POINTER: 164 impl->h_data_owned = data; 165 impl->h_data_borrowed = NULL; 166 impl->h_data = data; 167 break; 168 case CEED_USE_POINTER: 169 impl->h_data_borrowed = data; 170 impl->h_data = data; 171 break; 172 } 173 return CEED_ERROR_SUCCESS; 174 } 175 176 //------------------------------------------------------------------------------ 177 // Set data from device 178 //------------------------------------------------------------------------------ 179 static int CeedQFunctionContextSetDataDevice_Cuda(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 180 Ceed ceed; 181 CeedQFunctionContext_Cuda *impl; 182 183 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 184 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 185 186 CeedCallCuda(ceed, cudaFree(impl->d_data_owned)); 187 impl->d_data_owned = NULL; 188 switch (copy_mode) { 189 case CEED_COPY_VALUES: { 190 size_t ctx_size; 191 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 192 CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_data_owned, ctx_size)); 193 impl->d_data_borrowed = NULL; 194 impl->d_data = impl->d_data_owned; 195 CeedCallCuda(ceed, cudaMemcpy(impl->d_data, data, ctx_size, cudaMemcpyDeviceToDevice)); 196 } break; 197 case CEED_OWN_POINTER: 198 impl->d_data_owned = data; 199 impl->d_data_borrowed = NULL; 200 impl->d_data = data; 201 break; 202 case CEED_USE_POINTER: 203 impl->d_data_owned = NULL; 204 impl->d_data_borrowed = data; 205 impl->d_data = data; 206 break; 207 } 208 return CEED_ERROR_SUCCESS; 209 } 210 211 //------------------------------------------------------------------------------ 212 // Set the data used by a user context, 213 // freeing any previously allocated data if applicable 214 //------------------------------------------------------------------------------ 215 static int CeedQFunctionContextSetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 216 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 217 switch (mem_type) { 218 case CEED_MEM_HOST: 219 return CeedQFunctionContextSetDataHost_Cuda(ctx, copy_mode, data); 220 case CEED_MEM_DEVICE: 221 return CeedQFunctionContextSetDataDevice_Cuda(ctx, copy_mode, data); 222 } 223 return CEED_ERROR_UNSUPPORTED; 224 } 225 226 //------------------------------------------------------------------------------ 227 // Take data 228 //------------------------------------------------------------------------------ 229 static int CeedQFunctionContextTakeData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 230 CeedQFunctionContext_Cuda *impl; 231 232 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 233 234 // Sync data to requested mem_type 235 bool need_sync = false; 236 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 237 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 238 239 // Update pointer 240 switch (mem_type) { 241 case CEED_MEM_HOST: 242 *(void **)data = impl->h_data_borrowed; 243 impl->h_data_borrowed = NULL; 244 impl->h_data = NULL; 245 break; 246 case CEED_MEM_DEVICE: 247 *(void **)data = impl->d_data_borrowed; 248 impl->d_data_borrowed = NULL; 249 impl->d_data = NULL; 250 break; 251 } 252 return CEED_ERROR_SUCCESS; 253 } 254 255 //------------------------------------------------------------------------------ 256 // Core logic for GetData. 257 // If a different memory type is most up to date, this will perform a copy 258 //------------------------------------------------------------------------------ 259 static int CeedQFunctionContextGetDataCore_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 260 bool need_sync = false; 261 CeedQFunctionContext_Cuda *impl; 262 263 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 264 265 // Sync data to requested mem_type 266 CeedCallBackend(CeedQFunctionContextNeedSync_Cuda(ctx, mem_type, &need_sync)); 267 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Cuda(ctx, mem_type)); 268 269 // Update pointer 270 switch (mem_type) { 271 case CEED_MEM_HOST: 272 *(void **)data = impl->h_data; 273 break; 274 case CEED_MEM_DEVICE: 275 *(void **)data = impl->d_data; 276 break; 277 } 278 return CEED_ERROR_SUCCESS; 279 } 280 281 //------------------------------------------------------------------------------ 282 // Get read-only access to the data 283 //------------------------------------------------------------------------------ 284 static int CeedQFunctionContextGetDataRead_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 285 return CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data); 286 } 287 288 //------------------------------------------------------------------------------ 289 // Get read/write access to the data 290 //------------------------------------------------------------------------------ 291 static int CeedQFunctionContextGetData_Cuda(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 292 CeedQFunctionContext_Cuda *impl; 293 294 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 295 CeedCallBackend(CeedQFunctionContextGetDataCore_Cuda(ctx, mem_type, data)); 296 297 // Mark only pointer for requested memory as valid 298 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Cuda(ctx)); 299 switch (mem_type) { 300 case CEED_MEM_HOST: 301 impl->h_data = *(void **)data; 302 break; 303 case CEED_MEM_DEVICE: 304 impl->d_data = *(void **)data; 305 break; 306 } 307 return CEED_ERROR_SUCCESS; 308 } 309 310 //------------------------------------------------------------------------------ 311 // Destroy the user context 312 //------------------------------------------------------------------------------ 313 static int CeedQFunctionContextDestroy_Cuda(const CeedQFunctionContext ctx) { 314 CeedQFunctionContext_Cuda *impl; 315 316 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 317 CeedCallCuda(CeedQFunctionContextReturnCeed(ctx), cudaFree(impl->d_data_owned)); 318 CeedCallBackend(CeedFree(&impl->h_data_owned)); 319 CeedCallBackend(CeedFree(&impl)); 320 return CEED_ERROR_SUCCESS; 321 } 322 323 //------------------------------------------------------------------------------ 324 // QFunctionContext Create 325 //------------------------------------------------------------------------------ 326 int CeedQFunctionContextCreate_Cuda(CeedQFunctionContext ctx) { 327 CeedQFunctionContext_Cuda *impl; 328 Ceed ceed; 329 330 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 331 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Cuda)); 332 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Cuda)); 333 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Cuda)); 334 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Cuda)); 335 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Cuda)); 336 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Cuda)); 337 CeedCallBackend(CeedSetBackendFunction(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Cuda)); 338 CeedCallBackend(CeedCalloc(1, &impl)); 339 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 340 return CEED_ERROR_SUCCESS; 341 } 342 343 //------------------------------------------------------------------------------ 344