1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <string> 12 #include <sycl/sycl.hpp> 13 14 #include "ceed-sycl-ref.hpp" 15 16 //------------------------------------------------------------------------------ 17 // Sync host to device 18 //------------------------------------------------------------------------------ 19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) { 20 Ceed ceed; 21 Ceed_Sycl *sycl_data; 22 size_t ctx_size; 23 CeedQFunctionContext_Sycl *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 26 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 27 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 32 if (impl->d_data_borrowed) { 33 impl->d_data = impl->d_data_borrowed; 34 } else if (impl->d_data_owned) { 35 impl->d_data = impl->d_data_owned; 36 } else { 37 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 38 impl->d_data = impl->d_data_owned; 39 } 40 std::vector<sycl::event> e; 41 42 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()}; 43 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, e); 44 CeedCallSycl(ceed, copy_event.wait_and_throw()); 45 return CEED_ERROR_SUCCESS; 46 } 47 48 //------------------------------------------------------------------------------ 49 // Sync device to host 50 //------------------------------------------------------------------------------ 51 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) { 52 Ceed ceed; 53 Ceed_Sycl *sycl_data; 54 size_t ctx_size; 55 CeedQFunctionContext_Sycl *impl; 56 57 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 58 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 59 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 60 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 61 62 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 63 64 if (impl->h_data_borrowed) { 65 impl->h_data = impl->h_data_borrowed; 66 } else if (impl->h_data_owned) { 67 impl->h_data = impl->h_data_owned; 68 } else { 69 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 70 impl->h_data = impl->h_data_owned; 71 } 72 73 std::vector<sycl::event> e; 74 75 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()}; 76 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, e); 77 CeedCallSycl(ceed, copy_event.wait_and_throw()); 78 return CEED_ERROR_SUCCESS; 79 } 80 81 //------------------------------------------------------------------------------ 82 // Sync data of type 83 //------------------------------------------------------------------------------ 84 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) { 85 switch (mem_type) { 86 case CEED_MEM_HOST: 87 return CeedQFunctionContextSyncD2H_Sycl(ctx); 88 case CEED_MEM_DEVICE: 89 return CeedQFunctionContextSyncH2D_Sycl(ctx); 90 } 91 return CEED_ERROR_UNSUPPORTED; 92 } 93 94 //------------------------------------------------------------------------------ 95 // Set all pointers as invalid 96 //------------------------------------------------------------------------------ 97 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) { 98 CeedQFunctionContext_Sycl *impl; 99 100 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 101 impl->h_data = NULL; 102 impl->d_data = NULL; 103 return CEED_ERROR_SUCCESS; 104 } 105 106 //------------------------------------------------------------------------------ 107 // Check if ctx has valid data 108 //------------------------------------------------------------------------------ 109 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) { 110 CeedQFunctionContext_Sycl *impl; 111 112 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 113 *has_valid_data = impl && (impl->h_data || impl->d_data); 114 return CEED_ERROR_SUCCESS; 115 } 116 117 //------------------------------------------------------------------------------ 118 // Check if ctx has borrowed data 119 //------------------------------------------------------------------------------ 120 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, 121 bool *has_borrowed_data_of_type) { 122 CeedQFunctionContext_Sycl *impl; 123 124 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 125 switch (mem_type) { 126 case CEED_MEM_HOST: 127 *has_borrowed_data_of_type = impl->h_data_borrowed; 128 break; 129 case CEED_MEM_DEVICE: 130 *has_borrowed_data_of_type = impl->d_data_borrowed; 131 break; 132 } 133 return CEED_ERROR_SUCCESS; 134 } 135 136 //------------------------------------------------------------------------------ 137 // Check if data of given type needs sync 138 //------------------------------------------------------------------------------ 139 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 140 bool has_valid_data = true; 141 CeedQFunctionContext_Sycl *impl; 142 143 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 144 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 145 switch (mem_type) { 146 case CEED_MEM_HOST: 147 *need_sync = has_valid_data && !impl->h_data; 148 break; 149 case CEED_MEM_DEVICE: 150 *need_sync = has_valid_data && !impl->d_data; 151 break; 152 } 153 return CEED_ERROR_SUCCESS; 154 } 155 156 //------------------------------------------------------------------------------ 157 // Set data from host 158 //------------------------------------------------------------------------------ 159 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 160 CeedQFunctionContext_Sycl *impl; 161 162 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 163 CeedCallBackend(CeedFree(&impl->h_data_owned)); 164 switch (copy_mode) { 165 case CEED_COPY_VALUES: 166 size_t ctx_size; 167 168 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 169 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 170 impl->h_data_borrowed = NULL; 171 impl->h_data = impl->h_data_owned; 172 memcpy(impl->h_data, data, ctx_size); 173 break; 174 case CEED_OWN_POINTER: 175 impl->h_data_owned = data; 176 impl->h_data_borrowed = NULL; 177 impl->h_data = data; 178 break; 179 case CEED_USE_POINTER: 180 impl->h_data_borrowed = data; 181 impl->h_data = data; 182 break; 183 } 184 return CEED_ERROR_SUCCESS; 185 } 186 187 //------------------------------------------------------------------------------ 188 // Set data from device 189 //------------------------------------------------------------------------------ 190 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 191 Ceed ceed; 192 Ceed_Sycl *sycl_data; 193 CeedQFunctionContext_Sycl *impl; 194 195 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 196 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 197 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 198 199 std::vector<sycl::event> e; 200 201 if (!sycl_data->sycl_queue.is_in_order()) e = {sycl_data->sycl_queue.ext_oneapi_submit_barrier()}; 202 203 // Wait for all work to finish before freeing memory 204 if (impl->d_data_owned) { 205 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 206 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 207 impl->d_data_owned = NULL; 208 } 209 210 switch (copy_mode) { 211 case CEED_COPY_VALUES: { 212 size_t ctx_size; 213 214 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 215 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 216 impl->d_data_borrowed = NULL; 217 impl->d_data = impl->d_data_owned; 218 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, e); 219 CeedCallSycl(ceed, copy_event.wait_and_throw()); 220 } break; 221 case CEED_OWN_POINTER: { 222 impl->d_data_owned = data; 223 impl->d_data_borrowed = NULL; 224 impl->d_data = data; 225 } break; 226 case CEED_USE_POINTER: { 227 impl->d_data_owned = NULL; 228 impl->d_data_borrowed = data; 229 impl->d_data = data; 230 } break; 231 } 232 return CEED_ERROR_SUCCESS; 233 } 234 235 //------------------------------------------------------------------------------ 236 // Set the data used by a user context, 237 // freeing any previously allocated data if applicable 238 //------------------------------------------------------------------------------ 239 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 240 Ceed ceed; 241 242 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 243 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 244 switch (mem_type) { 245 case CEED_MEM_HOST: 246 return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data); 247 case CEED_MEM_DEVICE: 248 return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data); 249 } 250 return CEED_ERROR_UNSUPPORTED; 251 } 252 253 //------------------------------------------------------------------------------ 254 // Take data 255 //------------------------------------------------------------------------------ 256 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 257 Ceed ceed; 258 Ceed_Sycl *ceedSycl; 259 bool need_sync = false; 260 CeedQFunctionContext_Sycl *impl; 261 262 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 263 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 264 CeedCallBackend(CeedGetData(ceed, &ceedSycl)); 265 266 // Order queue if needed 267 if (!ceedSycl->sycl_queue.is_in_order()) ceedSycl->sycl_queue.ext_oneapi_submit_barrier(); 268 269 // Sync data to requested mem_type 270 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 271 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 272 273 // Update pointer 274 switch (mem_type) { 275 case CEED_MEM_HOST: 276 *(void **)data = impl->h_data_borrowed; 277 impl->h_data_borrowed = NULL; 278 impl->h_data = NULL; 279 break; 280 case CEED_MEM_DEVICE: 281 *(void **)data = impl->d_data_borrowed; 282 impl->d_data_borrowed = NULL; 283 impl->d_data = NULL; 284 break; 285 } 286 return CEED_ERROR_SUCCESS; 287 } 288 289 //------------------------------------------------------------------------------ 290 // Core logic for GetData. 291 // If a different memory type is most up to date, this will perform a copy 292 //------------------------------------------------------------------------------ 293 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 294 Ceed ceed; 295 bool need_sync = false; 296 CeedQFunctionContext_Sycl *impl; 297 298 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 299 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 300 301 // Sync data to requested mem_type 302 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 303 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 304 305 // Update pointer 306 switch (mem_type) { 307 case CEED_MEM_HOST: 308 *(void **)data = impl->h_data; 309 break; 310 case CEED_MEM_DEVICE: 311 *(void **)data = impl->d_data; 312 break; 313 } 314 return CEED_ERROR_SUCCESS; 315 } 316 317 //------------------------------------------------------------------------------ 318 // Get read-only access to the data 319 //------------------------------------------------------------------------------ 320 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 321 return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data); 322 } 323 324 //------------------------------------------------------------------------------ 325 // Get read/write access to the data 326 //------------------------------------------------------------------------------ 327 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 328 Ceed ceed; 329 CeedQFunctionContext_Sycl *impl; 330 331 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 332 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 333 CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data)); 334 335 // Mark only pointer for requested memory as valid 336 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 337 switch (mem_type) { 338 case CEED_MEM_HOST: 339 impl->h_data = *(void **)data; 340 break; 341 case CEED_MEM_DEVICE: 342 impl->d_data = *(void **)data; 343 break; 344 } 345 return CEED_ERROR_SUCCESS; 346 } 347 348 //------------------------------------------------------------------------------ 349 // Destroy the user context 350 //------------------------------------------------------------------------------ 351 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) { 352 Ceed ceed; 353 Ceed_Sycl *sycl_data; 354 CeedQFunctionContext_Sycl *impl; 355 356 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 357 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 358 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 359 360 // Wait for all work to finish before freeing memory 361 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 362 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 363 CeedCallBackend(CeedFree(&impl->h_data_owned)); 364 CeedCallBackend(CeedFree(&impl)); 365 return CEED_ERROR_SUCCESS; 366 } 367 368 //------------------------------------------------------------------------------ 369 // QFunctionContext Create 370 //------------------------------------------------------------------------------ 371 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) { 372 Ceed ceed; 373 CeedQFunctionContext_Sycl *impl; 374 375 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 376 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl)); 377 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl)); 378 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl)); 379 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl)); 380 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl)); 381 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl)); 382 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl)); 383 CeedCallBackend(CeedCalloc(1, &impl)); 384 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 385 return CEED_ERROR_SUCCESS; 386 } 387 388 //------------------------------------------------------------------------------ 389