1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <string> 12 #include <sycl/sycl.hpp> 13 14 #include "ceed-sycl-ref.hpp" 15 16 //------------------------------------------------------------------------------ 17 // Sync host to device 18 //------------------------------------------------------------------------------ 19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) { 20 Ceed ceed; 21 Ceed_Sycl *sycl_data; 22 size_t ctx_size; 23 CeedQFunctionContext_Sycl *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 26 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 27 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 32 if (impl->d_data_borrowed) { 33 impl->d_data = impl->d_data_borrowed; 34 } else if (impl->d_data_owned) { 35 impl->d_data = impl->d_data_owned; 36 } else { 37 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 38 impl->d_data = impl->d_data_owned; 39 } 40 std::vector<sycl::event> e; 41 42 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()}; 43 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, e); 44 CeedCallSycl(ceed, copy_event.wait_and_throw()); 45 CeedCallBackend(CeedDestroy(&ceed)); 46 return CEED_ERROR_SUCCESS; 47 } 48 49 //------------------------------------------------------------------------------ 50 // Sync device to host 51 //------------------------------------------------------------------------------ 52 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) { 53 Ceed ceed; 54 Ceed_Sycl *sycl_data; 55 size_t ctx_size; 56 CeedQFunctionContext_Sycl *impl; 57 58 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 59 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 60 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 61 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 62 63 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 64 65 if (impl->h_data_borrowed) { 66 impl->h_data = impl->h_data_borrowed; 67 } else if (impl->h_data_owned) { 68 impl->h_data = impl->h_data_owned; 69 } else { 70 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 71 impl->h_data = impl->h_data_owned; 72 } 73 74 std::vector<sycl::event> e; 75 76 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()}; 77 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, e); 78 CeedCallSycl(ceed, copy_event.wait_and_throw()); 79 CeedCallBackend(CeedDestroy(&ceed)); 80 return CEED_ERROR_SUCCESS; 81 } 82 83 //------------------------------------------------------------------------------ 84 // Sync data of type 85 //------------------------------------------------------------------------------ 86 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) { 87 switch (mem_type) { 88 case CEED_MEM_HOST: 89 return CeedQFunctionContextSyncD2H_Sycl(ctx); 90 case CEED_MEM_DEVICE: 91 return CeedQFunctionContextSyncH2D_Sycl(ctx); 92 } 93 return CEED_ERROR_UNSUPPORTED; 94 } 95 96 //------------------------------------------------------------------------------ 97 // Set all pointers as invalid 98 //------------------------------------------------------------------------------ 99 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) { 100 CeedQFunctionContext_Sycl *impl; 101 102 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 103 impl->h_data = NULL; 104 impl->d_data = NULL; 105 return CEED_ERROR_SUCCESS; 106 } 107 108 //------------------------------------------------------------------------------ 109 // Check if ctx has valid data 110 //------------------------------------------------------------------------------ 111 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) { 112 CeedQFunctionContext_Sycl *impl; 113 114 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 115 *has_valid_data = impl && (impl->h_data || impl->d_data); 116 return CEED_ERROR_SUCCESS; 117 } 118 119 //------------------------------------------------------------------------------ 120 // Check if ctx has borrowed data 121 //------------------------------------------------------------------------------ 122 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, 123 bool *has_borrowed_data_of_type) { 124 CeedQFunctionContext_Sycl *impl; 125 126 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 127 switch (mem_type) { 128 case CEED_MEM_HOST: 129 *has_borrowed_data_of_type = impl->h_data_borrowed; 130 break; 131 case CEED_MEM_DEVICE: 132 *has_borrowed_data_of_type = impl->d_data_borrowed; 133 break; 134 } 135 return CEED_ERROR_SUCCESS; 136 } 137 138 //------------------------------------------------------------------------------ 139 // Check if data of given type needs sync 140 //------------------------------------------------------------------------------ 141 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 142 bool has_valid_data = true; 143 CeedQFunctionContext_Sycl *impl; 144 145 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 146 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 147 switch (mem_type) { 148 case CEED_MEM_HOST: 149 *need_sync = has_valid_data && !impl->h_data; 150 break; 151 case CEED_MEM_DEVICE: 152 *need_sync = has_valid_data && !impl->d_data; 153 break; 154 } 155 return CEED_ERROR_SUCCESS; 156 } 157 158 //------------------------------------------------------------------------------ 159 // Set data from host 160 //------------------------------------------------------------------------------ 161 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 162 CeedQFunctionContext_Sycl *impl; 163 164 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 165 CeedCallBackend(CeedFree(&impl->h_data_owned)); 166 switch (copy_mode) { 167 case CEED_COPY_VALUES: 168 size_t ctx_size; 169 170 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 171 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 172 impl->h_data_borrowed = NULL; 173 impl->h_data = impl->h_data_owned; 174 memcpy(impl->h_data, data, ctx_size); 175 break; 176 case CEED_OWN_POINTER: 177 impl->h_data_owned = data; 178 impl->h_data_borrowed = NULL; 179 impl->h_data = data; 180 break; 181 case CEED_USE_POINTER: 182 impl->h_data_borrowed = data; 183 impl->h_data = data; 184 break; 185 } 186 return CEED_ERROR_SUCCESS; 187 } 188 189 //------------------------------------------------------------------------------ 190 // Set data from device 191 //------------------------------------------------------------------------------ 192 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 193 Ceed ceed; 194 Ceed_Sycl *sycl_data; 195 CeedQFunctionContext_Sycl *impl; 196 197 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 198 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 199 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 200 201 std::vector<sycl::event> e; 202 203 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()}; 204 205 // Wait for all work to finish before freeing memory 206 if (impl->d_data_owned) { 207 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 208 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 209 impl->d_data_owned = NULL; 210 } 211 212 switch (copy_mode) { 213 case CEED_COPY_VALUES: { 214 size_t ctx_size; 215 216 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 217 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 218 impl->d_data_borrowed = NULL; 219 impl->d_data = impl->d_data_owned; 220 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, e); 221 CeedCallSycl(ceed, copy_event.wait_and_throw()); 222 } break; 223 case CEED_OWN_POINTER: { 224 impl->d_data_owned = data; 225 impl->d_data_borrowed = NULL; 226 impl->d_data = data; 227 } break; 228 case CEED_USE_POINTER: { 229 impl->d_data_owned = NULL; 230 impl->d_data_borrowed = data; 231 impl->d_data = data; 232 } break; 233 } 234 CeedCallBackend(CeedDestroy(&ceed)); 235 return CEED_ERROR_SUCCESS; 236 } 237 238 //------------------------------------------------------------------------------ 239 // Set the data used by a user context, 240 // freeing any previously allocated data if applicable 241 //------------------------------------------------------------------------------ 242 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 243 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 244 switch (mem_type) { 245 case CEED_MEM_HOST: 246 return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data); 247 case CEED_MEM_DEVICE: 248 return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data); 249 } 250 return CEED_ERROR_UNSUPPORTED; 251 } 252 253 //------------------------------------------------------------------------------ 254 // Take data 255 //------------------------------------------------------------------------------ 256 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 257 Ceed ceed; 258 Ceed_Sycl *ceedSycl; 259 bool need_sync = false; 260 CeedQFunctionContext_Sycl *impl; 261 262 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 263 CeedCallBackend(CeedGetData(ceed, &ceedSycl)); 264 CeedCallBackend(CeedDestroy(&ceed)); 265 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 266 267 // Order queue if needed 268 if (!ceedSycl->sycl_queue.is_in_order()) ceedSycl->sycl_queue.ext_oneapi_submit_barrier(); 269 270 // Sync data to requested mem_type 271 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 272 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 273 274 // Update pointer 275 switch (mem_type) { 276 case CEED_MEM_HOST: 277 *(void **)data = impl->h_data_borrowed; 278 impl->h_data_borrowed = NULL; 279 impl->h_data = NULL; 280 break; 281 case CEED_MEM_DEVICE: 282 *(void **)data = impl->d_data_borrowed; 283 impl->d_data_borrowed = NULL; 284 impl->d_data = NULL; 285 break; 286 } 287 return CEED_ERROR_SUCCESS; 288 } 289 290 //------------------------------------------------------------------------------ 291 // Core logic for GetData. 292 // If a different memory type is most up to date, this will perform a copy 293 //------------------------------------------------------------------------------ 294 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 295 bool need_sync = false; 296 CeedQFunctionContext_Sycl *impl; 297 298 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 299 300 // Sync data to requested mem_type 301 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 302 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 303 304 // Update pointer 305 switch (mem_type) { 306 case CEED_MEM_HOST: 307 *(void **)data = impl->h_data; 308 break; 309 case CEED_MEM_DEVICE: 310 *(void **)data = impl->d_data; 311 break; 312 } 313 return CEED_ERROR_SUCCESS; 314 } 315 316 //------------------------------------------------------------------------------ 317 // Get read-only access to the data 318 //------------------------------------------------------------------------------ 319 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 320 return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data); 321 } 322 323 //------------------------------------------------------------------------------ 324 // Get read/write access to the data 325 //------------------------------------------------------------------------------ 326 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 327 CeedQFunctionContext_Sycl *impl; 328 329 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 330 CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data)); 331 332 // Mark only pointer for requested memory as valid 333 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 334 switch (mem_type) { 335 case CEED_MEM_HOST: 336 impl->h_data = *(void **)data; 337 break; 338 case CEED_MEM_DEVICE: 339 impl->d_data = *(void **)data; 340 break; 341 } 342 return CEED_ERROR_SUCCESS; 343 } 344 345 //------------------------------------------------------------------------------ 346 // Destroy the user context 347 //------------------------------------------------------------------------------ 348 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) { 349 Ceed ceed; 350 Ceed_Sycl *sycl_data; 351 CeedQFunctionContext_Sycl *impl; 352 353 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 354 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 355 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 356 357 // Wait for all work to finish before freeing memory 358 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 359 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 360 CeedCallBackend(CeedDestroy(&ceed)); 361 CeedCallBackend(CeedFree(&impl->h_data_owned)); 362 CeedCallBackend(CeedFree(&impl)); 363 return CEED_ERROR_SUCCESS; 364 } 365 366 //------------------------------------------------------------------------------ 367 // QFunctionContext Create 368 //------------------------------------------------------------------------------ 369 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) { 370 Ceed ceed; 371 CeedQFunctionContext_Sycl *impl; 372 373 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 374 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl)); 375 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl)); 376 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl)); 377 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl)); 378 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl)); 379 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl)); 380 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl)); 381 CeedCallBackend(CeedDestroy(&ceed)); 382 CeedCallBackend(CeedCalloc(1, &impl)); 383 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 384 return CEED_ERROR_SUCCESS; 385 } 386 387 //------------------------------------------------------------------------------ 388