1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <string> 12 #include <sycl/sycl.hpp> 13 14 #include "ceed-sycl-ref.hpp" 15 16 //------------------------------------------------------------------------------ 17 // Sync host to device 18 //------------------------------------------------------------------------------ 19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) { 20 CeedQFunctionContext_Sycl *impl; 21 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 22 Ceed ceed; 23 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 24 Ceed_Sycl *sycl_data; 25 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 26 27 if (!impl->h_data) { 28 // LCOV_EXCL_START 29 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 30 // LCOV_EXCL_STOP 31 } 32 33 size_t ctxsize; 34 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 35 36 if (impl->d_data_borrowed) { 37 impl->d_data = impl->d_data_borrowed; 38 } else if (impl->d_data_owned) { 39 impl->d_data = impl->d_data_owned; 40 } else { 41 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctxsize, sycl_data->sycl_device, sycl_data->sycl_context)); 42 impl->d_data = impl->d_data_owned; 43 } 44 // Order queue 45 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 46 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctxsize, {e}); 47 CeedCallSycl(ceed, copy_event.wait_and_throw()); 48 49 return CEED_ERROR_SUCCESS; 50 } 51 52 //------------------------------------------------------------------------------ 53 // Sync device to host 54 //------------------------------------------------------------------------------ 55 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) { 56 CeedQFunctionContext_Sycl *impl; 57 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 58 Ceed ceed; 59 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 60 Ceed_Sycl *sycl_data; 61 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 62 63 if (!impl->d_data) { 64 // LCOV_EXCL_START 65 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 66 // LCOV_EXCL_STOP 67 } 68 69 size_t ctxsize; 70 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 71 72 if (impl->h_data_borrowed) { 73 impl->h_data = impl->h_data_borrowed; 74 } else if (impl->h_data_owned) { 75 impl->h_data = impl->h_data_owned; 76 } else { 77 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 78 impl->h_data = impl->h_data_owned; 79 } 80 81 // Order queue 82 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 83 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctxsize, {e}); 84 CeedCallSycl(ceed, copy_event.wait_and_throw()); 85 86 return CEED_ERROR_SUCCESS; 87 } 88 89 //------------------------------------------------------------------------------ 90 // Sync data of type 91 //------------------------------------------------------------------------------ 92 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) { 93 switch (mem_type) { 94 case CEED_MEM_HOST: 95 return CeedQFunctionContextSyncD2H_Sycl(ctx); 96 case CEED_MEM_DEVICE: 97 return CeedQFunctionContextSyncH2D_Sycl(ctx); 98 } 99 return CEED_ERROR_UNSUPPORTED; 100 } 101 102 //------------------------------------------------------------------------------ 103 // Set all pointers as invalid 104 //------------------------------------------------------------------------------ 105 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) { 106 CeedQFunctionContext_Sycl *impl; 107 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 108 109 impl->h_data = NULL; 110 impl->d_data = NULL; 111 112 return CEED_ERROR_SUCCESS; 113 } 114 115 //------------------------------------------------------------------------------ 116 // Check if ctx has valid data 117 //------------------------------------------------------------------------------ 118 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) { 119 CeedQFunctionContext_Sycl *impl; 120 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 121 122 *has_valid_data = impl && (!!impl->h_data || !!impl->d_data); 123 124 return CEED_ERROR_SUCCESS; 125 } 126 127 //------------------------------------------------------------------------------ 128 // Check if ctx has borrowed data 129 //------------------------------------------------------------------------------ 130 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, 131 bool *has_borrowed_data_of_type) { 132 CeedQFunctionContext_Sycl *impl; 133 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 134 135 switch (mem_type) { 136 case CEED_MEM_HOST: 137 *has_borrowed_data_of_type = !!impl->h_data_borrowed; 138 break; 139 case CEED_MEM_DEVICE: 140 *has_borrowed_data_of_type = !!impl->d_data_borrowed; 141 break; 142 } 143 144 return CEED_ERROR_SUCCESS; 145 } 146 147 //------------------------------------------------------------------------------ 148 // Check if data of given type needs sync 149 //------------------------------------------------------------------------------ 150 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 151 CeedQFunctionContext_Sycl *impl; 152 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 153 154 bool has_valid_data = true; 155 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 156 switch (mem_type) { 157 case CEED_MEM_HOST: 158 *need_sync = has_valid_data && !impl->h_data; 159 break; 160 case CEED_MEM_DEVICE: 161 *need_sync = has_valid_data && !impl->d_data; 162 break; 163 } 164 165 return CEED_ERROR_SUCCESS; 166 } 167 168 //------------------------------------------------------------------------------ 169 // Set data from host 170 //------------------------------------------------------------------------------ 171 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 172 CeedQFunctionContext_Sycl *impl; 173 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 174 175 CeedCallBackend(CeedFree(&impl->h_data_owned)); 176 switch (copy_mode) { 177 case CEED_COPY_VALUES: 178 size_t ctxsize; 179 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 180 CeedCallBackend(CeedMallocArray(1, ctxsize, &impl->h_data_owned)); 181 impl->h_data_borrowed = NULL; 182 impl->h_data = impl->h_data_owned; 183 memcpy(impl->h_data, data, ctxsize); 184 break; 185 case CEED_OWN_POINTER: 186 impl->h_data_owned = data; 187 impl->h_data_borrowed = NULL; 188 impl->h_data = data; 189 break; 190 case CEED_USE_POINTER: 191 impl->h_data_borrowed = data; 192 impl->h_data = data; 193 break; 194 } 195 196 return CEED_ERROR_SUCCESS; 197 } 198 199 //------------------------------------------------------------------------------ 200 // Set data from device 201 //------------------------------------------------------------------------------ 202 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 203 CeedQFunctionContext_Sycl *impl; 204 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 205 Ceed ceed; 206 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 207 Ceed_Sycl *sycl_data; 208 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 209 210 // Order queue 211 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 212 213 // Wait for all work to finish before freeing memory 214 if (impl->d_data_owned) { 215 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 216 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 217 impl->d_data_owned = NULL; 218 } 219 220 switch (copy_mode) { 221 case CEED_COPY_VALUES: { 222 size_t ctxsize; 223 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctxsize)); 224 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctxsize, sycl_data->sycl_device, sycl_data->sycl_context)); 225 impl->d_data_borrowed = NULL; 226 impl->d_data = impl->d_data_owned; 227 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctxsize, {e}); 228 CeedCallSycl(ceed, copy_event.wait_and_throw()); 229 } break; 230 case CEED_OWN_POINTER: { 231 impl->d_data_owned = data; 232 impl->d_data_borrowed = NULL; 233 impl->d_data = data; 234 } break; 235 case CEED_USE_POINTER: { 236 impl->d_data_owned = NULL; 237 impl->d_data_borrowed = data; 238 impl->d_data = data; 239 } break; 240 } 241 242 return CEED_ERROR_SUCCESS; 243 } 244 245 //------------------------------------------------------------------------------ 246 // Set the data used by a user context, 247 // freeing any previously allocated data if applicable 248 //------------------------------------------------------------------------------ 249 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 250 Ceed ceed; 251 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 252 253 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 254 switch (mem_type) { 255 case CEED_MEM_HOST: 256 return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data); 257 case CEED_MEM_DEVICE: 258 return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data); 259 } 260 261 return CEED_ERROR_UNSUPPORTED; 262 } 263 264 //------------------------------------------------------------------------------ 265 // Take data 266 //------------------------------------------------------------------------------ 267 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 268 Ceed ceed; 269 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 270 CeedQFunctionContext_Sycl *impl; 271 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 272 273 Ceed_Sycl *ceedSycl; 274 CeedCallBackend(CeedGetData(ceed, &ceedSycl)); 275 276 // Order queue 277 ceedSycl->sycl_queue.ext_oneapi_submit_barrier(); 278 279 // Sync data to requested mem_type 280 bool need_sync = false; 281 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 282 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 283 284 // Update pointer 285 switch (mem_type) { 286 case CEED_MEM_HOST: 287 *(void **)data = impl->h_data_borrowed; 288 impl->h_data_borrowed = NULL; 289 impl->h_data = NULL; 290 break; 291 case CEED_MEM_DEVICE: 292 *(void **)data = impl->d_data_borrowed; 293 impl->d_data_borrowed = NULL; 294 impl->d_data = NULL; 295 break; 296 } 297 298 return CEED_ERROR_SUCCESS; 299 } 300 301 //------------------------------------------------------------------------------ 302 // Core logic for GetData. 303 // If a different memory type is most up to date, this will perform a copy 304 //------------------------------------------------------------------------------ 305 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 306 Ceed ceed; 307 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 308 CeedQFunctionContext_Sycl *impl; 309 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 310 311 // Sync data to requested mem_type 312 bool need_sync = false; 313 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 314 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 315 316 // Update pointer 317 switch (mem_type) { 318 case CEED_MEM_HOST: 319 *(void **)data = impl->h_data; 320 break; 321 case CEED_MEM_DEVICE: 322 *(void **)data = impl->d_data; 323 break; 324 } 325 326 return CEED_ERROR_SUCCESS; 327 } 328 329 //------------------------------------------------------------------------------ 330 // Get read-only access to the data 331 //------------------------------------------------------------------------------ 332 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 333 return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data); 334 } 335 336 //------------------------------------------------------------------------------ 337 // Get read/write access to the data 338 //------------------------------------------------------------------------------ 339 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 340 CeedQFunctionContext_Sycl *impl; 341 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 342 Ceed ceed; 343 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 344 345 CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data)); 346 347 // Mark only pointer for requested memory as valid 348 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 349 switch (mem_type) { 350 case CEED_MEM_HOST: 351 impl->h_data = *(void **)data; 352 break; 353 case CEED_MEM_DEVICE: 354 impl->d_data = *(void **)data; 355 break; 356 } 357 358 return CEED_ERROR_SUCCESS; 359 } 360 361 //------------------------------------------------------------------------------ 362 // Destroy the user context 363 //------------------------------------------------------------------------------ 364 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) { 365 Ceed ceed; 366 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 367 CeedQFunctionContext_Sycl *impl; 368 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 369 Ceed_Sycl *sycl_data; 370 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 371 372 // Wait for all work to finish before freeing memory 373 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 374 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 375 CeedCallBackend(CeedFree(&impl->h_data_owned)); 376 CeedCallBackend(CeedFree(&impl)); 377 378 return CEED_ERROR_SUCCESS; 379 } 380 381 //------------------------------------------------------------------------------ 382 // QFunctionContext Create 383 //------------------------------------------------------------------------------ 384 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) { 385 CeedQFunctionContext_Sycl *impl; 386 Ceed ceed; 387 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 388 389 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl)); 390 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl)); 391 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl)); 392 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl)); 393 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl)); 394 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl)); 395 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl)); 396 397 CeedCallBackend(CeedCalloc(1, &impl)); 398 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 399 400 return CEED_ERROR_SUCCESS; 401 } 402 //------------------------------------------------------------------------------ 403