1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed/backend.h> 9 #include <ceed/ceed.h> 10 11 #include <string> 12 #include <sycl/sycl.hpp> 13 14 #include "ceed-sycl-ref.hpp" 15 16 //------------------------------------------------------------------------------ 17 // Sync host to device 18 //------------------------------------------------------------------------------ 19 static inline int CeedQFunctionContextSyncH2D_Sycl(const CeedQFunctionContext ctx) { 20 Ceed ceed; 21 Ceed_Sycl *sycl_data; 22 size_t ctx_size; 23 CeedQFunctionContext_Sycl *impl; 24 25 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 26 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 27 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 28 CeedCheck(impl->h_data, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device"); 29 30 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 31 32 if (impl->d_data_borrowed) { 33 impl->d_data = impl->d_data_borrowed; 34 } else if (impl->d_data_owned) { 35 impl->d_data = impl->d_data_owned; 36 } else { 37 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 38 impl->d_data = impl->d_data_owned; 39 } 40 // Order queue 41 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 42 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, impl->h_data, ctx_size, {e}); 43 CeedCallSycl(ceed, copy_event.wait_and_throw()); 44 return CEED_ERROR_SUCCESS; 45 } 46 47 //------------------------------------------------------------------------------ 48 // Sync device to host 49 //------------------------------------------------------------------------------ 50 static inline int CeedQFunctionContextSyncD2H_Sycl(const CeedQFunctionContext ctx) { 51 Ceed ceed; 52 Ceed_Sycl *sycl_data; 53 size_t ctx_size; 54 CeedQFunctionContext_Sycl *impl; 55 56 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 57 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 58 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 59 CeedCheck(impl->d_data, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host"); 60 61 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 62 63 if (impl->h_data_borrowed) { 64 impl->h_data = impl->h_data_borrowed; 65 } else if (impl->h_data_owned) { 66 impl->h_data = impl->h_data_owned; 67 } else { 68 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 69 impl->h_data = impl->h_data_owned; 70 } 71 72 // Order queue 73 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 74 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->h_data, impl->d_data, ctx_size, {e}); 75 CeedCallSycl(ceed, copy_event.wait_and_throw()); 76 return CEED_ERROR_SUCCESS; 77 } 78 79 //------------------------------------------------------------------------------ 80 // Sync data of type 81 //------------------------------------------------------------------------------ 82 static inline int CeedQFunctionContextSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type) { 83 switch (mem_type) { 84 case CEED_MEM_HOST: 85 return CeedQFunctionContextSyncD2H_Sycl(ctx); 86 case CEED_MEM_DEVICE: 87 return CeedQFunctionContextSyncH2D_Sycl(ctx); 88 } 89 return CEED_ERROR_UNSUPPORTED; 90 } 91 92 //------------------------------------------------------------------------------ 93 // Set all pointers as invalid 94 //------------------------------------------------------------------------------ 95 static inline int CeedQFunctionContextSetAllInvalid_Sycl(const CeedQFunctionContext ctx) { 96 CeedQFunctionContext_Sycl *impl; 97 98 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 99 impl->h_data = NULL; 100 impl->d_data = NULL; 101 return CEED_ERROR_SUCCESS; 102 } 103 104 //------------------------------------------------------------------------------ 105 // Check if ctx has valid data 106 //------------------------------------------------------------------------------ 107 static inline int CeedQFunctionContextHasValidData_Sycl(const CeedQFunctionContext ctx, bool *has_valid_data) { 108 CeedQFunctionContext_Sycl *impl; 109 110 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 111 *has_valid_data = impl && (impl->h_data || impl->d_data); 112 return CEED_ERROR_SUCCESS; 113 } 114 115 //------------------------------------------------------------------------------ 116 // Check if ctx has borrowed data 117 //------------------------------------------------------------------------------ 118 static inline int CeedQFunctionContextHasBorrowedDataOfType_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, 119 bool *has_borrowed_data_of_type) { 120 CeedQFunctionContext_Sycl *impl; 121 122 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 123 switch (mem_type) { 124 case CEED_MEM_HOST: 125 *has_borrowed_data_of_type = impl->h_data_borrowed; 126 break; 127 case CEED_MEM_DEVICE: 128 *has_borrowed_data_of_type = impl->d_data_borrowed; 129 break; 130 } 131 return CEED_ERROR_SUCCESS; 132 } 133 134 //------------------------------------------------------------------------------ 135 // Check if data of given type needs sync 136 //------------------------------------------------------------------------------ 137 static inline int CeedQFunctionContextNeedSync_Sycl(const CeedQFunctionContext ctx, CeedMemType mem_type, bool *need_sync) { 138 bool has_valid_data = true; 139 CeedQFunctionContext_Sycl *impl; 140 141 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 142 CeedCallBackend(CeedQFunctionContextHasValidData(ctx, &has_valid_data)); 143 switch (mem_type) { 144 case CEED_MEM_HOST: 145 *need_sync = has_valid_data && !impl->h_data; 146 break; 147 case CEED_MEM_DEVICE: 148 *need_sync = has_valid_data && !impl->d_data; 149 break; 150 } 151 return CEED_ERROR_SUCCESS; 152 } 153 154 //------------------------------------------------------------------------------ 155 // Set data from host 156 //------------------------------------------------------------------------------ 157 static int CeedQFunctionContextSetDataHost_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 158 CeedQFunctionContext_Sycl *impl; 159 160 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 161 CeedCallBackend(CeedFree(&impl->h_data_owned)); 162 switch (copy_mode) { 163 case CEED_COPY_VALUES: 164 size_t ctx_size; 165 166 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 167 CeedCallBackend(CeedMallocArray(1, ctx_size, &impl->h_data_owned)); 168 impl->h_data_borrowed = NULL; 169 impl->h_data = impl->h_data_owned; 170 memcpy(impl->h_data, data, ctx_size); 171 break; 172 case CEED_OWN_POINTER: 173 impl->h_data_owned = data; 174 impl->h_data_borrowed = NULL; 175 impl->h_data = data; 176 break; 177 case CEED_USE_POINTER: 178 impl->h_data_borrowed = data; 179 impl->h_data = data; 180 break; 181 } 182 return CEED_ERROR_SUCCESS; 183 } 184 185 //------------------------------------------------------------------------------ 186 // Set data from device 187 //------------------------------------------------------------------------------ 188 static int CeedQFunctionContextSetDataDevice_Sycl(const CeedQFunctionContext ctx, const CeedCopyMode copy_mode, void *data) { 189 Ceed ceed; 190 Ceed_Sycl *sycl_data; 191 CeedQFunctionContext_Sycl *impl; 192 193 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 194 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 195 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 196 197 // Order queue 198 sycl::event e = sycl_data->sycl_queue.ext_oneapi_submit_barrier(); 199 200 // Wait for all work to finish before freeing memory 201 if (impl->d_data_owned) { 202 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 203 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 204 impl->d_data_owned = NULL; 205 } 206 207 switch (copy_mode) { 208 case CEED_COPY_VALUES: { 209 size_t ctx_size; 210 211 CeedCallBackend(CeedQFunctionContextGetContextSize(ctx, &ctx_size)); 212 CeedCallSycl(ceed, impl->d_data_owned = sycl::malloc_device(ctx_size, sycl_data->sycl_device, sycl_data->sycl_context)); 213 impl->d_data_borrowed = NULL; 214 impl->d_data = impl->d_data_owned; 215 sycl::event copy_event = sycl_data->sycl_queue.memcpy(impl->d_data, data, ctx_size, {e}); 216 CeedCallSycl(ceed, copy_event.wait_and_throw()); 217 } break; 218 case CEED_OWN_POINTER: { 219 impl->d_data_owned = data; 220 impl->d_data_borrowed = NULL; 221 impl->d_data = data; 222 } break; 223 case CEED_USE_POINTER: { 224 impl->d_data_owned = NULL; 225 impl->d_data_borrowed = data; 226 impl->d_data = data; 227 } break; 228 } 229 return CEED_ERROR_SUCCESS; 230 } 231 232 //------------------------------------------------------------------------------ 233 // Set the data used by a user context, 234 // freeing any previously allocated data if applicable 235 //------------------------------------------------------------------------------ 236 static int CeedQFunctionContextSetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, const CeedCopyMode copy_mode, void *data) { 237 Ceed ceed; 238 239 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 240 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 241 switch (mem_type) { 242 case CEED_MEM_HOST: 243 return CeedQFunctionContextSetDataHost_Sycl(ctx, copy_mode, data); 244 case CEED_MEM_DEVICE: 245 return CeedQFunctionContextSetDataDevice_Sycl(ctx, copy_mode, data); 246 } 247 return CEED_ERROR_UNSUPPORTED; 248 } 249 250 //------------------------------------------------------------------------------ 251 // Take data 252 //------------------------------------------------------------------------------ 253 static int CeedQFunctionContextTakeData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 254 Ceed ceed; 255 Ceed_Sycl *ceedSycl; 256 bool need_sync = false; 257 CeedQFunctionContext_Sycl *impl; 258 259 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 260 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 261 CeedCallBackend(CeedGetData(ceed, &ceedSycl)); 262 263 // Order queue 264 ceedSycl->sycl_queue.ext_oneapi_submit_barrier(); 265 266 // Sync data to requested mem_type 267 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 268 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 269 270 // Update pointer 271 switch (mem_type) { 272 case CEED_MEM_HOST: 273 *(void **)data = impl->h_data_borrowed; 274 impl->h_data_borrowed = NULL; 275 impl->h_data = NULL; 276 break; 277 case CEED_MEM_DEVICE: 278 *(void **)data = impl->d_data_borrowed; 279 impl->d_data_borrowed = NULL; 280 impl->d_data = NULL; 281 break; 282 } 283 return CEED_ERROR_SUCCESS; 284 } 285 286 //------------------------------------------------------------------------------ 287 // Core logic for GetData. 288 // If a different memory type is most up to date, this will perform a copy 289 //------------------------------------------------------------------------------ 290 static int CeedQFunctionContextGetDataCore_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 291 Ceed ceed; 292 bool need_sync = false; 293 CeedQFunctionContext_Sycl *impl; 294 295 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 296 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 297 298 // Sync data to requested mem_type 299 CeedCallBackend(CeedQFunctionContextNeedSync_Sycl(ctx, mem_type, &need_sync)); 300 if (need_sync) CeedCallBackend(CeedQFunctionContextSync_Sycl(ctx, mem_type)); 301 302 // Update pointer 303 switch (mem_type) { 304 case CEED_MEM_HOST: 305 *(void **)data = impl->h_data; 306 break; 307 case CEED_MEM_DEVICE: 308 *(void **)data = impl->d_data; 309 break; 310 } 311 return CEED_ERROR_SUCCESS; 312 } 313 314 //------------------------------------------------------------------------------ 315 // Get read-only access to the data 316 //------------------------------------------------------------------------------ 317 static int CeedQFunctionContextGetDataRead_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 318 return CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data); 319 } 320 321 //------------------------------------------------------------------------------ 322 // Get read/write access to the data 323 //------------------------------------------------------------------------------ 324 static int CeedQFunctionContextGetData_Sycl(const CeedQFunctionContext ctx, const CeedMemType mem_type, void *data) { 325 Ceed ceed; 326 CeedQFunctionContext_Sycl *impl; 327 328 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 329 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 330 CeedCallBackend(CeedQFunctionContextGetDataCore_Sycl(ctx, mem_type, data)); 331 332 // Mark only pointer for requested memory as valid 333 CeedCallBackend(CeedQFunctionContextSetAllInvalid_Sycl(ctx)); 334 switch (mem_type) { 335 case CEED_MEM_HOST: 336 impl->h_data = *(void **)data; 337 break; 338 case CEED_MEM_DEVICE: 339 impl->d_data = *(void **)data; 340 break; 341 } 342 return CEED_ERROR_SUCCESS; 343 } 344 345 //------------------------------------------------------------------------------ 346 // Destroy the user context 347 //------------------------------------------------------------------------------ 348 static int CeedQFunctionContextDestroy_Sycl(const CeedQFunctionContext ctx) { 349 Ceed ceed; 350 Ceed_Sycl *sycl_data; 351 CeedQFunctionContext_Sycl *impl; 352 353 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 354 CeedCallBackend(CeedQFunctionContextGetBackendData(ctx, &impl)); 355 CeedCallBackend(CeedGetData(ceed, &sycl_data)); 356 357 // Wait for all work to finish before freeing memory 358 CeedCallSycl(ceed, sycl_data->sycl_queue.wait_and_throw()); 359 CeedCallSycl(ceed, sycl::free(impl->d_data_owned, sycl_data->sycl_context)); 360 CeedCallBackend(CeedFree(&impl->h_data_owned)); 361 CeedCallBackend(CeedFree(&impl)); 362 return CEED_ERROR_SUCCESS; 363 } 364 365 //------------------------------------------------------------------------------ 366 // QFunctionContext Create 367 //------------------------------------------------------------------------------ 368 int CeedQFunctionContextCreate_Sycl(CeedQFunctionContext ctx) { 369 Ceed ceed; 370 CeedQFunctionContext_Sycl *impl; 371 372 CeedCallBackend(CeedQFunctionContextGetCeed(ctx, &ceed)); 373 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasValidData", CeedQFunctionContextHasValidData_Sycl)); 374 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "HasBorrowedDataOfType", CeedQFunctionContextHasBorrowedDataOfType_Sycl)); 375 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "SetData", CeedQFunctionContextSetData_Sycl)); 376 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "TakeData", CeedQFunctionContextTakeData_Sycl)); 377 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetData", CeedQFunctionContextGetData_Sycl)); 378 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "GetDataRead", CeedQFunctionContextGetDataRead_Sycl)); 379 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunctionContext", ctx, "Destroy", CeedQFunctionContextDestroy_Sycl)); 380 CeedCallBackend(CeedCalloc(1, &impl)); 381 CeedCallBackend(CeedQFunctionContextSetBackendData(ctx, impl)); 382 return CEED_ERROR_SUCCESS; 383 } 384 385 //------------------------------------------------------------------------------ 386