xref: /libCEED/backends/sycl-shared/ceed-sycl-shared-basis.sycl.cpp (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright #include <ceed/backend.h>
9bd882c8aSJames Wright #include <ceed/ceed.h>
10bd882c8aSJames Wright #include <ceed/jit-tools.h>
11bd882c8aSJames Wright 
12bd882c8aSJames Wright #include <map>
13bd882c8aSJames Wright #include <string_view>
14bd882c8aSJames Wright #include <sycl/sycl.hpp>
15bd882c8aSJames Wright 
16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
17bd882c8aSJames Wright #include "ceed-sycl-shared.hpp"
18bd882c8aSJames Wright 
19bd882c8aSJames Wright //------------------------------------------------------------------------------
20bd882c8aSJames Wright // Compute the local range of for basis kernels
21bd882c8aSJames Wright //------------------------------------------------------------------------------
ComputeLocalRange(Ceed ceed,CeedInt dim,CeedInt thread_1d,CeedInt * local_range,CeedInt max_group_size=256)22bd882c8aSJames Wright static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 256) {
23bd882c8aSJames Wright   local_range[0]               = thread_1d;
24bd882c8aSJames Wright   local_range[1]               = (dim > 1) ? thread_1d : 1;
25bd882c8aSJames Wright   const CeedInt min_group_size = local_range[0] * local_range[1];
26dd64fc84SJeremy L Thompson 
27bd882c8aSJames Wright   CeedCheck(min_group_size <= max_group_size, ceed, CEED_ERROR_BACKEND, "Requested group size is smaller than the required minimum.");
28bd882c8aSJames Wright 
29bd882c8aSJames Wright   local_range[2] = max_group_size / min_group_size;  // elements per group
30bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
31bd882c8aSJames Wright }
32bd882c8aSJames Wright 
33bd882c8aSJames Wright //------------------------------------------------------------------------------
34bd882c8aSJames Wright // Apply basis
35bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisApplyTensor_Sycl_shared(CeedBasis basis,const CeedInt num_elem,CeedTransposeMode t_mode,CeedEvalMode eval_mode,CeedVector u,CeedVector v)36bd882c8aSJames Wright int CeedBasisApplyTensor_Sycl_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
37bd882c8aSJames Wright                                      CeedVector v) {
38bd882c8aSJames Wright   Ceed                   ceed;
39bd882c8aSJames Wright   Ceed_Sycl             *ceed_Sycl;
40dd64fc84SJeremy L Thompson   const CeedScalar      *d_u;
41dd64fc84SJeremy L Thompson   CeedScalar            *d_v;
42bd882c8aSJames Wright   CeedBasis_Sycl_shared *impl;
43dd64fc84SJeremy L Thompson 
44dd64fc84SJeremy L Thompson   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
45dd64fc84SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
46bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
47bd882c8aSJames Wright 
480ae60fd3SJeremy L Thompson   // Get read/write access to u, v
490ae60fd3SJeremy L Thompson   if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
500ae60fd3SJeremy L Thompson   else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
51bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
52bd882c8aSJames Wright 
53bd882c8aSJames Wright   // Apply basis operation
54bd882c8aSJames Wright   switch (eval_mode) {
55bd882c8aSJames Wright     case CEED_EVAL_INTERP: {
56bd882c8aSJames Wright       CeedInt       *lrange         = impl->interp_local_range;
57bd882c8aSJames Wright       const CeedInt &elem_per_group = lrange[2];
58bd882c8aSJames Wright       const CeedInt  group_count    = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
59bd882c8aSJames Wright       //-----------
60bd882c8aSJames Wright       sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
61bd882c8aSJames Wright       sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
62bd882c8aSJames Wright       sycl::nd_range<3> kernel_range(global_range, local_range);
63bd882c8aSJames Wright       //-----------
64bd882c8aSJames Wright       sycl::kernel *interp_kernel = (t_mode == CEED_TRANSPOSE) ? impl->interp_transpose_kernel : impl->interp_kernel;
65bd882c8aSJames Wright 
661f4b1b45SUmesh Unnikrishnan       std::vector<sycl::event> e;
67bd882c8aSJames Wright 
681f4b1b45SUmesh Unnikrishnan       if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
69bd882c8aSJames Wright       ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
70bd882c8aSJames Wright         cgh.depends_on(e);
71bd882c8aSJames Wright         cgh.set_args(num_elem, impl->d_interp_1d, d_u, d_v);
72bd882c8aSJames Wright         cgh.parallel_for(kernel_range, *interp_kernel);
73bd882c8aSJames Wright       });
74bd882c8aSJames Wright 
75bd882c8aSJames Wright     } break;
76bd882c8aSJames Wright     case CEED_EVAL_GRAD: {
77bd882c8aSJames Wright       CeedInt       *lrange         = impl->grad_local_range;
78bd882c8aSJames Wright       const CeedInt &elem_per_group = lrange[2];
79bd882c8aSJames Wright       const CeedInt  group_count    = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
80bd882c8aSJames Wright       //-----------
81bd882c8aSJames Wright       sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
82bd882c8aSJames Wright       sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
83bd882c8aSJames Wright       sycl::nd_range<3> kernel_range(global_range, local_range);
84bd882c8aSJames Wright       //-----------
85bd882c8aSJames Wright       sycl::kernel     *grad_kernel = (t_mode == CEED_TRANSPOSE) ? impl->grad_transpose_kernel : impl->grad_kernel;
86bd882c8aSJames Wright       const CeedScalar *d_grad_1d   = (impl->d_collo_grad_1d) ? impl->d_collo_grad_1d : impl->d_grad_1d;
871f4b1b45SUmesh Unnikrishnan 
881f4b1b45SUmesh Unnikrishnan       std::vector<sycl::event> e;
891f4b1b45SUmesh Unnikrishnan 
901f4b1b45SUmesh Unnikrishnan       if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
91bd882c8aSJames Wright 
92bd882c8aSJames Wright       ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
93bd882c8aSJames Wright         cgh.depends_on(e);
94bd882c8aSJames Wright         cgh.set_args(num_elem, impl->d_interp_1d, d_grad_1d, d_u, d_v);
95bd882c8aSJames Wright         cgh.parallel_for(kernel_range, *grad_kernel);
96bd882c8aSJames Wright       });
97bd882c8aSJames Wright     } break;
98bd882c8aSJames Wright     case CEED_EVAL_WEIGHT: {
99bd882c8aSJames Wright       CeedInt       *lrange         = impl->weight_local_range;
100bd882c8aSJames Wright       const CeedInt &elem_per_group = lrange[2];
101bd882c8aSJames Wright       const CeedInt  group_count    = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
102bd882c8aSJames Wright       //-----------
103bd882c8aSJames Wright       sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
104bd882c8aSJames Wright       sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
105bd882c8aSJames Wright       sycl::nd_range<3> kernel_range(global_range, local_range);
106bd882c8aSJames Wright       //-----------
1071f4b1b45SUmesh Unnikrishnan       std::vector<sycl::event> e;
1081f4b1b45SUmesh Unnikrishnan 
109097cc795SJames Wright       CeedCheck(impl->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight_1d not set", CeedEvalModes[eval_mode]);
1101f4b1b45SUmesh Unnikrishnan       if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
111bd882c8aSJames Wright 
112bd882c8aSJames Wright       ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
113bd882c8aSJames Wright         cgh.depends_on(e);
114bd882c8aSJames Wright         cgh.set_args(num_elem, impl->d_q_weight_1d, d_v);
115bd882c8aSJames Wright         cgh.parallel_for(kernel_range, *(impl->weight_kernel));
116bd882c8aSJames Wright       });
117bd882c8aSJames Wright     } break;
1180ae60fd3SJeremy L Thompson     case CEED_EVAL_NONE: /* handled separately below */
1190ae60fd3SJeremy L Thompson       break;
120bd882c8aSJames Wright     // LCOV_EXCL_START
121bd882c8aSJames Wright     case CEED_EVAL_DIV:
122bd882c8aSJames Wright     case CEED_EVAL_CURL:
1234e3038a5SJeremy L Thompson       return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
124bd882c8aSJames Wright       // LCOV_EXCL_STOP
125bd882c8aSJames Wright   }
126bd882c8aSJames Wright 
1270ae60fd3SJeremy L Thompson   // Restore vectors, cover CEED_EVAL_NONE
128bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
1290ae60fd3SJeremy L Thompson   if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
1300ae60fd3SJeremy L Thompson   if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
1319bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
132bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
133bd882c8aSJames Wright }
134bd882c8aSJames Wright 
135bd882c8aSJames Wright //------------------------------------------------------------------------------
136bd882c8aSJames Wright // Destroy basis
137bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisDestroy_Sycl_shared(CeedBasis basis)138bd882c8aSJames Wright static int CeedBasisDestroy_Sycl_shared(CeedBasis basis) {
139bd882c8aSJames Wright   Ceed                   ceed;
140bd882c8aSJames Wright   Ceed_Sycl             *data;
141dd64fc84SJeremy L Thompson   CeedBasis_Sycl_shared *impl;
142bd882c8aSJames Wright 
143dd64fc84SJeremy L Thompson   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
144dd64fc84SJeremy L Thompson   CeedCallBackend(CeedBasisGetData(basis, &impl));
145dd64fc84SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
146bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
147097cc795SJames Wright   if (impl->d_q_weight_1d) CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context));
148bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context));
149bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context));
150bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_collo_grad_1d, data->sycl_context));
151bd882c8aSJames Wright 
152bd882c8aSJames Wright   delete impl->interp_kernel;
153bd882c8aSJames Wright   delete impl->interp_transpose_kernel;
154bd882c8aSJames Wright   delete impl->grad_kernel;
155bd882c8aSJames Wright   delete impl->grad_transpose_kernel;
156bd882c8aSJames Wright   delete impl->weight_kernel;
157bd882c8aSJames Wright   delete impl->sycl_module;
158bd882c8aSJames Wright 
159bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
1609bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
161bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
162bd882c8aSJames Wright }
163bd882c8aSJames Wright 
164bd882c8aSJames Wright //------------------------------------------------------------------------------
165bd882c8aSJames Wright // Create tensor basis
166bd882c8aSJames Wright // TODO: Refactor
167bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedBasisCreateTensorH1_Sycl_shared(CeedInt dim,CeedInt P_1d,CeedInt Q_1d,const CeedScalar * interp_1d,const CeedScalar * grad_1d,const CeedScalar * q_ref_1d,const CeedScalar * q_weight_1d,CeedBasis basis)168bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
169bd882c8aSJames Wright                                         const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
170bd882c8aSJames Wright   Ceed                   ceed;
171bd882c8aSJames Wright   Ceed_Sycl             *data;
17222070f95SJeremy L Thompson   char                  *basis_kernel_source;
17322070f95SJeremy L Thompson   const char            *basis_kernel_path;
174bd882c8aSJames Wright   CeedInt                num_comp;
175dd64fc84SJeremy L Thompson   CeedBasis_Sycl_shared *impl;
176dd64fc84SJeremy L Thompson 
177dd64fc84SJeremy L Thompson   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
178dd64fc84SJeremy L Thompson   CeedCallBackend(CeedCalloc(1, &impl));
179dd64fc84SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &data));
180bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
181bd882c8aSJames Wright 
182bd882c8aSJames Wright   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
183bd882c8aSJames Wright   const CeedInt num_nodes = CeedIntPow(P_1d, dim);
184bd882c8aSJames Wright   const CeedInt num_qpts  = CeedIntPow(Q_1d, dim);
185bd882c8aSJames Wright 
186bd882c8aSJames Wright   CeedInt *interp_lrange = impl->interp_local_range;
187dd64fc84SJeremy L Thompson 
188bd882c8aSJames Wright   CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, interp_lrange));
189bd882c8aSJames Wright   const CeedInt interp_group_size = interp_lrange[0] * interp_lrange[1] * interp_lrange[2];
190bd882c8aSJames Wright 
191bd882c8aSJames Wright   CeedInt *grad_lrange = impl->grad_local_range;
192dd64fc84SJeremy L Thompson 
193bd882c8aSJames Wright   CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, grad_lrange));
194bd882c8aSJames Wright   const CeedInt grad_group_size = grad_lrange[0] * grad_lrange[1] * grad_lrange[2];
195bd882c8aSJames Wright 
196bd882c8aSJames Wright   CeedCallBackend(ComputeLocalRange(ceed, dim, Q_1d, impl->weight_local_range));
197bd882c8aSJames Wright 
1981f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
1991f4b1b45SUmesh Unnikrishnan 
2001f4b1b45SUmesh Unnikrishnan   if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()};
2011f4b1b45SUmesh Unnikrishnan 
202bd882c8aSJames Wright   // Copy basis data to GPU
203097cc795SJames Wright   std::vector<sycl::event> copy_events;
204097cc795SJames Wright   if (q_weight_1d) {
205bd882c8aSJames Wright     CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context));
2061f4b1b45SUmesh Unnikrishnan     sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, e);
207097cc795SJames Wright     copy_events.push_back(copy_weight);
208097cc795SJames Wright   }
209bd882c8aSJames Wright 
210bd882c8aSJames Wright   const CeedInt interp_length = Q_1d * P_1d;
211bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
2121f4b1b45SUmesh Unnikrishnan   sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, e);
213097cc795SJames Wright   copy_events.push_back(copy_interp);
214bd882c8aSJames Wright 
215bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
2161f4b1b45SUmesh Unnikrishnan   sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, e);
217097cc795SJames Wright   copy_events.push_back(copy_grad);
218bd882c8aSJames Wright 
219097cc795SJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));
220bd882c8aSJames Wright 
221bd882c8aSJames Wright   // Compute collocated gradient and copy to GPU
222bd882c8aSJames Wright   impl->d_collo_grad_1d          = NULL;
223bd882c8aSJames Wright   const bool has_collocated_grad = (dim == 3) && (Q_1d >= P_1d);
224dd64fc84SJeremy L Thompson 
225bd882c8aSJames Wright   if (has_collocated_grad) {
226bd882c8aSJames Wright     CeedScalar   *collo_grad_1d;
227dd64fc84SJeremy L Thompson     const CeedInt cgrad_length = Q_1d * Q_1d;
228dd64fc84SJeremy L Thompson 
229bd882c8aSJames Wright     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d));
230bd882c8aSJames Wright     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d));
231bd882c8aSJames Wright     CeedCallSycl(ceed, impl->d_collo_grad_1d = sycl::malloc_device<CeedScalar>(cgrad_length, data->sycl_device, data->sycl_context));
2321f4b1b45SUmesh Unnikrishnan     CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(collo_grad_1d, impl->d_collo_grad_1d, cgrad_length, e).wait_and_throw());
233bd882c8aSJames Wright     CeedCallBackend(CeedFree(&collo_grad_1d));
234bd882c8aSJames Wright   }
235bd882c8aSJames Wright 
236bd882c8aSJames Wright   // ---[Refactor into separate function]------>
237bd882c8aSJames Wright   // Define compile-time constants
238bd882c8aSJames Wright   std::map<std::string, CeedInt> jit_constants;
239bd882c8aSJames Wright   jit_constants["BASIS_DIM"]                 = dim;
240bd882c8aSJames Wright   jit_constants["BASIS_Q_1D"]                = Q_1d;
241bd882c8aSJames Wright   jit_constants["BASIS_P_1D"]                = P_1d;
242bd882c8aSJames Wright   jit_constants["T_1D"]                      = thread_1d;
243bd882c8aSJames Wright   jit_constants["BASIS_NUM_COMP"]            = num_comp;
244bd882c8aSJames Wright   jit_constants["BASIS_NUM_NODES"]           = num_nodes;
245bd882c8aSJames Wright   jit_constants["BASIS_NUM_QPTS"]            = num_qpts;
246bd882c8aSJames Wright   jit_constants["BASIS_HAS_COLLOCATED_GRAD"] = has_collocated_grad;
247bd882c8aSJames Wright   jit_constants["BASIS_INTERP_SCRATCH_SIZE"] = interp_group_size;
248bd882c8aSJames Wright   jit_constants["BASIS_GRAD_SCRATCH_SIZE"]   = grad_group_size;
249bd882c8aSJames Wright 
250bd882c8aSJames Wright   // Load kernel source
251bd882c8aSJames Wright   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor.h", &basis_kernel_path));
25223d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
25322070f95SJeremy L Thompson   {
25422070f95SJeremy L Thompson     char *source;
25522070f95SJeremy L Thompson 
25622070f95SJeremy L Thompson     CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &source));
25722070f95SJeremy L Thompson     basis_kernel_source = source;
25822070f95SJeremy L Thompson   }
25923d4529eSJeremy L Thompson   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete -----\n");
260bd882c8aSJames Wright 
261bd882c8aSJames Wright   // Compile kernels into a kernel bundle
262eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedBuildModule_Sycl(ceed, basis_kernel_source, &impl->sycl_module, jit_constants));
263bd882c8aSJames Wright 
264bd882c8aSJames Wright   // Load kernel functions
265eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Interp", &impl->interp_kernel));
266eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "InterpTranspose", &impl->interp_transpose_kernel));
267eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Grad", &impl->grad_kernel));
268eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "GradTranspose", &impl->grad_transpose_kernel));
269eb7e6cafSJeremy L Thompson   CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Weight", &impl->weight_kernel));
270bd882c8aSJames Wright 
271bd882c8aSJames Wright   // Clean-up
272bd882c8aSJames Wright   CeedCallBackend(CeedFree(&basis_kernel_path));
273bd882c8aSJames Wright   CeedCallBackend(CeedFree(&basis_kernel_source));
274bd882c8aSJames Wright   // <---[Refactor into separate function]------
275bd882c8aSJames Wright 
276bd882c8aSJames Wright   CeedCallBackend(CeedBasisSetData(basis, impl));
277bd882c8aSJames Wright 
278bd882c8aSJames Wright   // Register backend functions
279bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Sycl_shared));
280bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl_shared));
2819bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
282bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
283bd882c8aSJames Wright }
284ff1e7120SSebastian Grimberg 
285bd882c8aSJames Wright //------------------------------------------------------------------------------
286