1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other 2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3 // files for details. 4 // 5 // SPDX-License-Identifier: BSD-2-Clause 6 // 7 // This file is part of CEED: http://github.com/ceed 8 9 #include <ceed/backend.h> 10 #include <ceed/ceed.h> 11 12 #include <string> 13 #include <sycl/sycl.hpp> 14 #include <vector> 15 16 #include "../sycl/ceed-sycl-common.hpp" 17 #include "../sycl/ceed-sycl-compile.hpp" 18 #include "ceed-sycl-ref-qfunction-load.hpp" 19 #include "ceed-sycl-ref.hpp" 20 21 #define WG_SIZE_QF 384 22 23 //------------------------------------------------------------------------------ 24 // Apply QFunction 25 //------------------------------------------------------------------------------ 26 static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) { 27 Ceed ceed; 28 Ceed_Sycl *ceed_Sycl; 29 void *context_data; 30 CeedInt num_input_fields, num_output_fields; 31 CeedQFunction_Sycl *impl; 32 33 CeedCallBackend(CeedQFunctionGetData(qf, &impl)); 34 35 // Build and compile kernel, if not done 36 if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf)); 37 38 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 39 CeedCallBackend(CeedGetData(ceed, &ceed_Sycl)); 40 CeedCallBackend(CeedDestroy(&ceed)); 41 42 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 43 44 // Read vectors 45 std::vector<const CeedScalar *> inputs(num_input_fields); 46 const CeedVector *U_i = U; 47 for (auto &input_i : inputs) { 48 CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i)); 49 ++U_i; 50 } 51 52 std::vector<CeedScalar *> outputs(num_output_fields); 53 CeedVector *V_i = V; 54 for (auto &output_i : outputs) { 55 CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i)); 56 ++V_i; 57 } 58 59 // Get context data 60 CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data)); 61 62 std::vector<sycl::event> e; 63 64 if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()}; 65 66 // Launch as a basic parallel_for over Q quadrature points 67 ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 68 cgh.depends_on(e); 69 70 int iarg{}; 71 cgh.set_arg(iarg, context_data); 72 ++iarg; 73 cgh.set_arg(iarg, Q); 74 ++iarg; 75 for (auto &input_i : inputs) { 76 cgh.set_arg(iarg, input_i); 77 ++iarg; 78 } 79 for (auto &output_i : outputs) { 80 cgh.set_arg(iarg, output_i); 81 ++iarg; 82 } 83 // Hard-coding the work-group size for now 84 // We could use the Level Zero API to query and set an appropriate size in future 85 // Equivalent of CUDA Occupancy Calculator 86 int wg_size = WG_SIZE_QF; 87 sycl::range<1> rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size; 88 sycl::nd_range<1> kernel_range(rounded_Q, wg_size); 89 cgh.parallel_for(kernel_range, *(impl->QFunction)); 90 }); 91 92 // Restore vectors 93 U_i = U; 94 for (auto &input_i : inputs) { 95 CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i)); 96 ++U_i; 97 } 98 99 V_i = V; 100 for (auto &output_i : outputs) { 101 CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i)); 102 ++V_i; 103 } 104 105 // Restore context 106 CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data)); 107 return CEED_ERROR_SUCCESS; 108 } 109 110 //------------------------------------------------------------------------------ 111 // Destroy QFunction 112 //------------------------------------------------------------------------------ 113 static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) { 114 Ceed ceed; 115 CeedQFunction_Sycl *impl; 116 117 CeedCallBackend(CeedQFunctionGetData(qf, &impl)); 118 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 119 delete impl->QFunction; 120 delete impl->sycl_module; 121 CeedCallBackend(CeedFree(&impl)); 122 CeedCallBackend(CeedDestroy(&ceed)); 123 return CEED_ERROR_SUCCESS; 124 } 125 126 //------------------------------------------------------------------------------ 127 // Create QFunction 128 //------------------------------------------------------------------------------ 129 int CeedQFunctionCreate_Sycl(CeedQFunction qf) { 130 Ceed ceed; 131 CeedQFunction_Sycl *impl; 132 133 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 134 CeedCallBackend(CeedCalloc(1, &impl)); 135 CeedCallBackend(CeedQFunctionSetData(qf, impl)); 136 // Register backend functions 137 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl)); 138 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl)); 139 CeedCallBackend(CeedDestroy(&ceed)); 140 return CEED_ERROR_SUCCESS; 141 } 142 143 //------------------------------------------------------------------------------ 144