1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <string> 12 #include <sycl/sycl.hpp> 13 14 #include "ceed-sycl-ref.hpp" 15 16 //------------------------------------------------------------------------------ 17 // Sync host to device 18 //------------------------------------------------------------------------------ 19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) { 20 Ceed ceed; 21 Ceed_Sycl *sycl_data; 22 size_t ctx_size; 23 CeedQFunctionContext_Sycl *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 26 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 27 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 28 29 if (!impl->h_data) { 30 // LCOV_EXCL_START 31 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 32 // LCOV_EXCL_STOP 33 } 34 35 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 36 37 if (impl->d_data_borrowed) { 38 impl->d_data = impl->d_data_borrowed; 39 } else if (impl->d_data_owned) { 40 impl->d_data = impl->d_data_owned; 41 } else { 42 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 43 impl->d_data = impl->d_data_owned; 44 } 45 // Order queue 46 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 47 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, {e}); 48 CeedCallSycl(ceed, copy_event.wait_and_throw()); 49 return CEED_ERROR_SUCCESS; 50 } 51 52 //------------------------------------------------------------------------------ 53 // Sync device to host 54 //------------------------------------------------------------------------------ 55 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) { 56 Ceed ceed; 57 Ceed_Sycl *sycl_data; 58 size_t ctx_size; 59 CeedQFunctionContext_Sycl *impl; 60 61 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 62 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 63 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 64 65 if (!impl->d_data) { 66 // LCOV_EXCL_START 67 return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 68 // LCOV_EXCL_STOP 69 } 70 71 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 72 73 if (impl->h_data_borrowed) { 74 impl->h_data = impl->h_data_borrowed; 75 } else if (impl->h_data_owned) { 76 impl->h_data = impl->h_data_owned; 77 } else { 78 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 79 impl->h_data = impl->h_data_owned; 80 } 81 82 // Order queue 83 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 84 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, {e}); 85 CeedCallSycl(ceed, copy_event.wait_and_throw()); 86 return CEED_ERROR_SUCCESS; 87 } 88 89 //------------------------------------------------------------------------------ 90 // Sync data of type 91 //------------------------------------------------------------------------------ 92 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) { 93 switch (mem_type) { 94 case CEED_MEM_HOST: 95 return CeedQFunctionContextSyncD2H_Sycl(ctx); 96 case CEED_MEM_DEVICE: 97 return CeedQFunctionContextSyncH2D_Sycl(ctx); 98 } 99 return CEED_ERROR_UNSUPPORTED; 100 } 101 102 //------------------------------------------------------------------------------ 103 // Set all pointers as invalid 104 //------------------------------------------------------------------------------ 105 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) { 106 CeedQFunctionContext_Sycl *impl; 107 108 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 109 impl->h_data = NULL; 110 impl->d_data = NULL; 111 return CEED_ERROR_SUCCESS; 112 } 113 114 //------------------------------------------------------------------------------ 115 // Check if ctx has valid data 116 //------------------------------------------------------------------------------ 117 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) { 118 CeedQFunctionContext_Sycl *impl; 119 120 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 121 *has_valid_data = impl && (impl->h_data || impl->d_data); 122 return CEED_ERROR_SUCCESS; 123 } 124 125 //------------------------------------------------------------------------------ 126 // Check if ctx has borrowed data 127 //------------------------------------------------------------------------------ 128 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, 129 bool *has_borrowed_data_of_type) { 130 CeedQFunctionContext_Sycl *impl; 131 132 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 133 switch (mem_type) { 134 case CEED_MEM_HOST: 135 *has_borrowed_data_of_type = impl->h_data_borrowed; 136 break; 137 case CEED_MEM_DEVICE: 138 *has_borrowed_data_of_type = impl->d_data_borrowed; 139 break; 140 } 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Check if data of given type needs sync 146 //------------------------------------------------------------------------------ 147 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 148 bool has_valid_data = true; 149 CeedQFunctionContext_Sycl *impl; 150 151 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 152 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 153 switch (mem_type) { 154 case CEED_MEM_HOST: 155 *need_sync = has_valid_data && !impl->h_data; 156 break; 157 case CEED_MEM_DEVICE: 158 *need_sync = has_valid_data && !impl->d_data; 159 break; 160 } 161 return CEED_ERROR_SUCCESS; 162 } 163 164 //------------------------------------------------------------------------------ 165 // Set data from host 166 //------------------------------------------------------------------------------ 167 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 168 CeedQFunctionContext_Sycl *impl; 169 170 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 171 CeedCallBackend(CeedFree(&impl->h_data_owned)); 172 switch (copy_mode) { 173 case CEED_COPY_VALUES: 174 size_t ctx_size; 175 176 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 177 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 178 impl->h_data_borrowed = NULL; 179 impl->h_data = impl->h_data_owned; 180 memcpy(impl->h_data, data, ctx_size); 181 break; 182 case CEED_OWN_POINTER: 183 impl->h_data_owned = data; 184 impl->h_data_borrowed = NULL; 185 impl->h_data = data; 186 break; 187 case CEED_USE_POINTER: 188 impl->h_data_borrowed = data; 189 impl->h_data = data; 190 break; 191 } 192 return CEED_ERROR_SUCCESS; 193 } 194 195 //------------------------------------------------------------------------------ 196 // Set data from device 197 //------------------------------------------------------------------------------ 198 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 199 Ceed ceed; 200 Ceed_Sycl *sycl_data; 201 CeedQFunctionContext_Sycl *impl; 202 203 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 204 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 205 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 206 207 // Order queue 208 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 209 210 // Wait for all work to finish before freeing memory 211 if (impl->d_data_owned) { 212 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 213 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 214 impl->d_data_owned = NULL; 215 } 216 217 switch (copy_mode) { 218 case CEED_COPY_VALUES: { 219 size_t ctx_size; 220 221 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 222 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 223 impl->d_data_borrowed = NULL; 224 impl->d_data = impl->d_data_owned; 225 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, {e}); 226 CeedCallSycl(ceed, copy_event.wait_and_throw()); 227 } break; 228 case CEED_OWN_POINTER: { 229 impl->d_data_owned = data; 230 impl->d_data_borrowed = NULL; 231 impl->d_data = data; 232 } break; 233 case CEED_USE_POINTER: { 234 impl->d_data_owned = NULL; 235 impl->d_data_borrowed = data; 236 impl->d_data = data; 237 } break; 238 } 239 return CEED_ERROR_SUCCESS; 240 } 241 242 //------------------------------------------------------------------------------ 243 // Set the data used by a user context, 244 // freeing any previously allocated data if applicable 245 //------------------------------------------------------------------------------ 246 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 247 Ceed ceed; 248 249 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 250 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 251 switch (mem_type) { 252 case CEED_MEM_HOST: 253 return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data); 254 case CEED_MEM_DEVICE: 255 return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data); 256 } 257 return CEED_ERROR_UNSUPPORTED; 258 } 259 260 //------------------------------------------------------------------------------ 261 // Take data 262 //------------------------------------------------------------------------------ 263 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 264 Ceed ceed; 265 Ceed_Sycl *ceedSycl; 266 bool need_sync = false; 267 CeedQFunctionContext_Sycl *impl; 268 269 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 270 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 271 CeedCallBackend(CeedGetData(ceed, &ceedSycl)); 272 273 // Order queue 274 ceedSycl->sycl_queue.ext_oneapi_submit_barrier(); 275 276 // Sync data to requested mem_type 277 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 278 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 279 280 // Update pointer 281 switch (mem_type) { 282 case CEED_MEM_HOST: 283 *(void **)data = impl->h_data_borrowed; 284 impl->h_data_borrowed = NULL; 285 impl->h_data = NULL; 286 break; 287 case CEED_MEM_DEVICE: 288 *(void **)data = impl->d_data_borrowed; 289 impl->d_data_borrowed = NULL; 290 impl->d_data = NULL; 291 break; 292 } 293 return CEED_ERROR_SUCCESS; 294 } 295 296 //------------------------------------------------------------------------------ 297 // Core logic for GetData. 298 // If a different memory type is most up to date, this will perform a copy 299 //------------------------------------------------------------------------------ 300 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 301 Ceed ceed; 302 bool need_sync = false; 303 CeedQFunctionContext_Sycl *impl; 304 305 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 306 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 307 308 // Sync data to requested mem_type 309 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 310 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 311 312 // Update pointer 313 switch (mem_type) { 314 case CEED_MEM_HOST: 315 *(void **)data = impl->h_data; 316 break; 317 case CEED_MEM_DEVICE: 318 *(void **)data = impl->d_data; 319 break; 320 } 321 return CEED_ERROR_SUCCESS; 322 } 323 324 //------------------------------------------------------------------------------ 325 // Get read-only access to the data 326 //------------------------------------------------------------------------------ 327 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 328 return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data); 329 } 330 331 //------------------------------------------------------------------------------ 332 // Get read/write access to the data 333 //------------------------------------------------------------------------------ 334 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 335 Ceed ceed; 336 CeedQFunctionContext_Sycl *impl; 337 338 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 339 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 340 CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data)); 341 342 // Mark only pointer for requested memory as valid 343 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 344 switch (mem_type) { 345 case CEED_MEM_HOST: 346 impl->h_data = *(void **)data; 347 break; 348 case CEED_MEM_DEVICE: 349 impl->d_data = *(void **)data; 350 break; 351 } 352 return CEED_ERROR_SUCCESS; 353 } 354 355 //------------------------------------------------------------------------------ 356 // Destroy the user context 357 //------------------------------------------------------------------------------ 358 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) { 359 Ceed ceed; 360 Ceed_Sycl *sycl_data; 361 CeedQFunctionContext_Sycl *impl; 362 363 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 364 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 365 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 366 367 // Wait for all work to finish before freeing memory 368 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 369 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 370 CeedCallBackend(CeedFree(&impl->h_data_owned)); 371 CeedCallBackend(CeedFree(&impl)); 372 return CEED_ERROR_SUCCESS; 373 } 374 375 //------------------------------------------------------------------------------ 376 // QFunctionContext Create 377 //------------------------------------------------------------------------------ 378 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) { 379 Ceed ceed; 380 CeedQFunctionContext_Sycl *impl; 381 382 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 383 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl)); 384 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl)); 385 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl)); 386 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl)); 387 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl)); 388 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl)); 389 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl)); 390 CeedCallBackend(CeedCalloc(1, &impl)); 391 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 392 return CEED_ERROR_SUCCESS; 393 } 394 395 //------------------------------------------------------------------------------ 396