1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <stddef.h>
11
12 #include "../sycl/ceed-sycl-compile.hpp"
13 #include "ceed-sycl-gen-operator-build.hpp"
14 #include "ceed-sycl-gen.hpp"
15
16 //------------------------------------------------------------------------------
17 // Destroy operator
18 //------------------------------------------------------------------------------
CeedOperatorDestroy_Sycl_gen(CeedOperator op)19 static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) {
20 CeedOperator_Sycl_gen *impl;
21
22 CeedCallBackend(CeedOperatorGetData(op, &impl));
23 CeedCallBackend(CeedFree(&impl));
24 return CEED_ERROR_SUCCESS;
25 }
26
27 //------------------------------------------------------------------------------
28 // Apply and add to output
29 //------------------------------------------------------------------------------
CeedOperatorApplyAdd_Sycl_gen(CeedOperator op,CeedVector input_vec,CeedVector output_vec,CeedRequest * request)30 static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
31 Ceed ceed;
32 Ceed_Sycl *ceed_Sycl;
33 CeedInt num_elem, num_input_fields, num_output_fields;
34 CeedEvalMode eval_mode;
35 CeedVector output_vecs[CEED_FIELD_MAX] = {};
36 CeedQFunctionField *qf_input_fields, *qf_output_fields;
37 CeedQFunction_Sycl_gen *qf_impl;
38 CeedQFunction qf;
39 CeedOperatorField *op_input_fields, *op_output_fields;
40 CeedOperator_Sycl_gen *impl;
41
42 // Check for tensor-product bases
43 {
44 bool has_tensor_bases;
45
46 CeedCallBackend(CeedOperatorHasTensorBases(op, &has_tensor_bases));
47 // -- Fallback to ref if not all bases are tensor-product
48 if (!has_tensor_bases) {
49 CeedOperator op_fallback;
50
51 CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to sycl/ref CeedOperator due to non-tensor bases");
52 CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
53 CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
54 return CEED_ERROR_SUCCESS;
55 }
56 }
57
58 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
59 CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
60 CeedCallBackend(CeedOperatorGetData(op, &impl));
61 CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
62 CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
63 CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
64 CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
65 CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
66
67 // Creation of the operator
68 CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op));
69
70 // Input vectors
71 for (CeedInt i = 0; i < num_input_fields; i++) {
72 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
73 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
74 impl->fields->inputs[i] = NULL;
75 } else {
76 bool is_active;
77 CeedVector vec;
78
79 // Get input vector
80 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
81 is_active = vec == CEED_VECTOR_ACTIVE;
82 if (is_active) vec = input_vec;
83 CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i]));
84 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
85 }
86 }
87
88 // Output vectors
89 for (CeedInt i = 0; i < num_output_fields; i++) {
90 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
91 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
92 impl->fields->outputs[i] = NULL;
93 } else {
94 bool is_active;
95 CeedVector vec;
96
97 // Get output vector
98 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
99 is_active = vec == CEED_VECTOR_ACTIVE;
100 if (is_active) vec = output_vec;
101 output_vecs[i] = vec;
102 // Check for multiple output modes
103 CeedInt index = -1;
104 for (CeedInt j = 0; j < i; j++) {
105 if (vec == output_vecs[j]) {
106 index = j;
107 break;
108 }
109 }
110 if (index == -1) {
111 CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i]));
112 } else {
113 impl->fields->outputs[i] = impl->fields->outputs[index];
114 }
115 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
116 }
117 }
118
119 // Get context data
120 CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c));
121
122 // Apply operator
123 const CeedInt dim = impl->dim;
124 const CeedInt Q_1d = impl->Q_1d;
125 const CeedInt P_1d = impl->max_P_1d;
126 CeedInt block_sizes[3], grid = 0;
127
128 CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
129 if (dim == 1) {
130 grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
131 // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
132 } else if (dim == 2) {
133 grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
134 // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
135 } else if (dim == 3) {
136 grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
137 // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
138 }
139
140 sycl::range<3> local_range(block_sizes[2], block_sizes[1], block_sizes[0]);
141 sycl::range<3> global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]);
142 sycl::nd_range<3> kernel_range(global_range, local_range);
143
144 //-----------
145 std::vector<sycl::event> e;
146
147 if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
148
149 CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
150 cgh.depends_on(e);
151 cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W);
152 cgh.parallel_for(kernel_range, *(impl->op));
153 }));
154 CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw());
155
156 // Restore input arrays
157 for (CeedInt i = 0; i < num_input_fields; i++) {
158 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
159 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
160 } else {
161 bool is_active;
162 CeedVector vec;
163
164 CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
165 is_active = vec == CEED_VECTOR_ACTIVE;
166 if (is_active) vec = input_vec;
167 CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i]));
168 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
169 }
170 }
171
172 // Restore output arrays
173 for (CeedInt i = 0; i < num_output_fields; i++) {
174 CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
175 if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
176 } else {
177 bool is_active;
178 CeedVector vec;
179
180 CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
181 is_active = vec == CEED_VECTOR_ACTIVE;
182 if (is_active) vec = output_vec;
183 // Check for multiple output modes
184 CeedInt index = -1;
185
186 for (CeedInt j = 0; j < i; j++) {
187 if (vec == output_vecs[j]) {
188 index = j;
189 break;
190 }
191 }
192 if (index == -1) {
193 CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i]));
194 }
195 if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
196 }
197 }
198
199 // Restore context data
200 CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c));
201 CeedCallBackend(CeedDestroy(&ceed));
202 CeedCallBackend(CeedQFunctionDestroy(&qf));
203 return CEED_ERROR_SUCCESS;
204 }
205
206 //------------------------------------------------------------------------------
207 // Create operator
208 //------------------------------------------------------------------------------
CeedOperatorCreate_Sycl_gen(CeedOperator op)209 int CeedOperatorCreate_Sycl_gen(CeedOperator op) {
210 Ceed ceed;
211 Ceed_Sycl *sycl_data;
212 CeedOperator_Sycl_gen *impl;
213
214 CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
215 CeedCallBackend(CeedGetData(ceed, &sycl_data));
216
217 CeedCallBackend(CeedCalloc(1, &impl));
218 CeedCallBackend(CeedOperatorSetData(op, impl));
219
220 impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
221 impl->fields = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context);
222 impl->B = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
223 impl->G = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
224 impl->W = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context);
225
226 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen));
227 CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen));
228 CeedCallBackend(CeedDestroy(&ceed));
229 return CEED_ERROR_SUCCESS;
230 }
231
232 //------------------------------------------------------------------------------
233