1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other
2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3 // files for details.
4 //
5 // SPDX-License-Identifier: BSD-2-Clause
6 //
7 // This file is part of CEED: http://github.com/ceed
8
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11
12 #include <string>
13 #include <sycl/sycl.hpp>
14 #include <vector>
15
16 #include "../sycl/ceed-sycl-common.hpp"
17 #include "../sycl/ceed-sycl-compile.hpp"
18 #include "ceed-sycl-ref-qfunction-load.hpp"
19 #include "ceed-sycl-ref.hpp"
20
// Work-group size used when sizing the QFunction kernel launch in CeedQFunctionApply_Sycl.
// The value is hard-coded; the global range is rounded up to a multiple of it.
21 #define WG_SIZE_QF 384
22
23 //------------------------------------------------------------------------------
24 // Apply QFunction
25 //------------------------------------------------------------------------------
CeedQFunctionApply_Sycl(CeedQFunction qf,CeedInt Q,CeedVector * U,CeedVector * V)26 static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27 Ceed ceed;
28 Ceed_Sycl *ceed_Sycl;
29 void *context_data;
30 CeedInt num_input_fields, num_output_fields;
31 CeedQFunction_Sycl *impl;
32
33 CeedCallBackend(CeedQFunctionGetData(qf, &impl));
34
35 // Build and compile kernel, if not done
36 if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf));
37
38 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
39 CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
40 CeedCallBackend(CeedDestroy(&ceed));
41
42 CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
43
44 // Read vectors
45 std::vector<const CeedScalar *> inputs(num_input_fields);
46 const CeedVector *U_i = U;
47 for (auto &input_i : inputs) {
48 CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
49 ++U_i;
50 }
51
52 std::vector<CeedScalar *> outputs(num_output_fields);
53 CeedVector *V_i = V;
54 for (auto &output_i : outputs) {
55 CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
56 ++V_i;
57 }
58
59 // Get context data
60 CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
61
62 std::vector<sycl::event> e;
63
64 if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
65
66 // Launch as a basic parallel_for over Q quadrature points
67 ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
68 cgh.depends_on(e);
69
70 int iarg{};
71 cgh.set_arg(iarg, context_data);
72 ++iarg;
73 cgh.set_arg(iarg, Q);
74 ++iarg;
75 for (auto &input_i : inputs) {
76 cgh.set_arg(iarg, input_i);
77 ++iarg;
78 }
79 for (auto &output_i : outputs) {
80 cgh.set_arg(iarg, output_i);
81 ++iarg;
82 }
83 // Hard-coding the work-group size for now
84 // We could use the Level Zero API to query and set an appropriate size in future
85 // Equivalent of CUDA Occupancy Calculator
86 int wg_size = WG_SIZE_QF;
87 sycl::range<1> rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
88 sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
89 cgh.parallel_for(kernel_range, *(impl->QFunction));
90 });
91
92 // Restore vectors
93 U_i = U;
94 for (auto &input_i : inputs) {
95 CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
96 ++U_i;
97 }
98
99 V_i = V;
100 for (auto &output_i : outputs) {
101 CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
102 ++V_i;
103 }
104
105 // Restore context
106 CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
107 return CEED_ERROR_SUCCESS;
108 }
109
110 //------------------------------------------------------------------------------
111 // Destroy QFunction
112 //------------------------------------------------------------------------------
CeedQFunctionDestroy_Sycl(CeedQFunction qf)113 static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
114 Ceed ceed;
115 CeedQFunction_Sycl *impl;
116
117 CeedCallBackend(CeedQFunctionGetData(qf, &impl));
118 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
119 delete impl->QFunction;
120 delete impl->sycl_module;
121 CeedCallBackend(CeedFree(&impl));
122 CeedCallBackend(CeedDestroy(&ceed));
123 return CEED_ERROR_SUCCESS;
124 }
125
126 //------------------------------------------------------------------------------
127 // Create QFunction
128 //------------------------------------------------------------------------------
CeedQFunctionCreate_Sycl(CeedQFunction qf)129 int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
130 Ceed ceed;
131 CeedQFunction_Sycl *impl;
132
133 CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
134 CeedCallBackend(CeedCalloc(1, &impl));
135 CeedCallBackend(CeedQFunctionSetData(qf, impl));
136 // Register backend functions
137 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
138 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
139 CeedCallBackend(CeedDestroy(&ceed));
140 return CEED_ERROR_SUCCESS;
141 }
142
143 //------------------------------------------------------------------------------
144