xref: /libCEED/backends/sycl-gen/ceed-sycl-gen-operator.sycl.cpp (revision bdee0278611904727ee35fcc2d0d7c3bf83db4c4)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <stddef.h>
11 
12 #include "../sycl/ceed-sycl-compile.hpp"
13 #include "ceed-sycl-gen-operator-build.hpp"
14 #include "ceed-sycl-gen.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Destroy operator
18 //------------------------------------------------------------------------------
19 static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) {
20   CeedOperator_Sycl_gen *impl;
21 
22   CeedCallBackend(CeedOperatorGetData(op, &impl));
23   CeedCallBackend(CeedFree(&impl));
24   return CEED_ERROR_SUCCESS;
25 }
26 
27 //------------------------------------------------------------------------------
28 // Apply and add to output
29 //------------------------------------------------------------------------------
30 static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
31   Ceed                    ceed;
32   Ceed_Sycl              *ceed_Sycl;
33   CeedInt                 num_elem, num_input_fields, num_output_fields;
34   CeedEvalMode            eval_mode;
35   CeedVector              output_vecs[CEED_FIELD_MAX] = {};
36   CeedQFunctionField     *qf_input_fields, *qf_output_fields;
37   CeedQFunction_Sycl_gen *qf_impl;
38   CeedQFunction           qf;
39   CeedOperatorField      *op_input_fields, *op_output_fields;
40   CeedOperator_Sycl_gen  *impl;
41 
42   // Check for tensor-product bases
43   {
44     bool has_tensor_bases;
45 
46     CeedCallBackend(CeedOperatorHasTensorBases(op, &has_tensor_bases));
47     // -- Fallback to ref if not all bases are tensor-product
48     if (!has_tensor_bases) {
49       CeedOperator op_fallback;
50 
51       CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to sycl/ref CeedOperator due to non-tensor bases");
52       CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
53       CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
54       return CEED_ERROR_SUCCESS;
55     }
56   }
57 
58   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
59   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
60   CeedCallBackend(CeedOperatorGetData(op, &impl));
61   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
62   CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
63   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
64   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
65   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
66 
67   // Creation of the operator
68   CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op));
69 
70   // Input vectors
71   for (CeedInt i = 0; i < num_input_fields; i++) {
72     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
73     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
74       impl->fields->inputs[i] = NULL;
75     } else {
76       bool       is_active;
77       CeedVector vec;
78 
79       // Get input vector
80       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
81       is_active = vec == CEED_VECTOR_ACTIVE;
82       if (is_active) vec = input_vec;
83       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i]));
84       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
85     }
86   }
87 
88   // Output vectors
89   for (CeedInt i = 0; i < num_output_fields; i++) {
90     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
91     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
92       impl->fields->outputs[i] = NULL;
93     } else {
94       bool       is_active;
95       CeedVector vec;
96 
97       // Get output vector
98       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
99       is_active = vec == CEED_VECTOR_ACTIVE;
100       if (is_active) vec = output_vec;
101       output_vecs[i] = vec;
102       // Check for multiple output modes
103       CeedInt index = -1;
104       for (CeedInt j = 0; j < i; j++) {
105         if (vec == output_vecs[j]) {
106           index = j;
107           break;
108         }
109       }
110       if (index == -1) {
111         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i]));
112       } else {
113         impl->fields->outputs[i] = impl->fields->outputs[index];
114       }
115       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
116     }
117   }
118 
119   // Get context data
120   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c));
121 
122   // Apply operator
123   const CeedInt dim  = impl->dim;
124   const CeedInt Q_1d = impl->Q_1d;
125   const CeedInt P_1d = impl->max_P_1d;
126   CeedInt       block_sizes[3], grid = 0;
127 
128   CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
129   if (dim == 1) {
130     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
131     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
132   } else if (dim == 2) {
133     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
134     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
135   } else if (dim == 3) {
136     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
137     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
138   }
139 
140   sycl::range<3>    local_range(block_sizes[2], block_sizes[1], block_sizes[0]);
141   sycl::range<3>    global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]);
142   sycl::nd_range<3> kernel_range(global_range, local_range);
143 
144   //-----------
145   std::vector<sycl::event> e;
146 
147   if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
148 
149   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
150     cgh.depends_on(e);
151     cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W);
152     cgh.parallel_for(kernel_range, *(impl->op));
153   }));
154   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw());
155 
156   // Restore input arrays
157   for (CeedInt i = 0; i < num_input_fields; i++) {
158     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
159     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
160     } else {
161       bool       is_active;
162       CeedVector vec;
163 
164       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
165       is_active = vec == CEED_VECTOR_ACTIVE;
166       if (is_active) vec = input_vec;
167       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i]));
168       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
169     }
170   }
171 
172   // Restore output arrays
173   for (CeedInt i = 0; i < num_output_fields; i++) {
174     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
175     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
176     } else {
177       bool       is_active;
178       CeedVector vec;
179 
180       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
181       is_active = vec == CEED_VECTOR_ACTIVE;
182       if (is_active) vec = output_vec;
183       // Check for multiple output modes
184       CeedInt index = -1;
185 
186       for (CeedInt j = 0; j < i; j++) {
187         if (vec == output_vecs[j]) {
188           index = j;
189           break;
190         }
191       }
192       if (index == -1) {
193         CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i]));
194       }
195       if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
196     }
197   }
198 
199   // Restore context data
200   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c));
201   CeedCallBackend(CeedDestroy(&ceed));
202   CeedCallBackend(CeedQFunctionDestroy(&qf));
203   return CEED_ERROR_SUCCESS;
204 }
205 
206 //------------------------------------------------------------------------------
207 // Create operator
208 //------------------------------------------------------------------------------
209 int CeedOperatorCreate_Sycl_gen(CeedOperator op) {
210   Ceed                   ceed;
211   Ceed_Sycl             *sycl_data;
212   CeedOperator_Sycl_gen *impl;
213 
214   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
215   CeedCallBackend(CeedGetData(ceed, &sycl_data));
216 
217   CeedCallBackend(CeedCalloc(1, &impl));
218   CeedCallBackend(CeedOperatorSetData(op, impl));
219 
220   impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
221   impl->fields  = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context);
222   impl->B       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
223   impl->G       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
224   impl->W       = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context);
225 
226   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen));
227   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen));
228   CeedCallBackend(CeedDestroy(&ceed));
229   return CEED_ERROR_SUCCESS;
230 }
231 
232 //------------------------------------------------------------------------------
233