xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-shared/ceed-sycl-shared-basis.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*bd882c8aSJames Wright //
4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5*bd882c8aSJames Wright //
6*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7*bd882c8aSJames Wright 
8*bd882c8aSJames Wright #include <ceed/backend.h>
9*bd882c8aSJames Wright #include <ceed/ceed.h>
10*bd882c8aSJames Wright #include <ceed/jit-tools.h>
11*bd882c8aSJames Wright 
12*bd882c8aSJames Wright #include <map>
13*bd882c8aSJames Wright #include <string_view>
14*bd882c8aSJames Wright #include <sycl/sycl.hpp>
15*bd882c8aSJames Wright 
16*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
17*bd882c8aSJames Wright #include "ceed-sycl-shared.hpp"
18*bd882c8aSJames Wright 
19*bd882c8aSJames Wright //------------------------------------------------------------------------------
20*bd882c8aSJames Wright // Compute the local range of for basis kernels
21*bd882c8aSJames Wright //------------------------------------------------------------------------------
22*bd882c8aSJames Wright static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 256) {
23*bd882c8aSJames Wright   local_range[0] = thread_1d;
24*bd882c8aSJames Wright   local_range[1] = (dim > 1) ? thread_1d : 1;
25*bd882c8aSJames Wright 
26*bd882c8aSJames Wright   const CeedInt min_group_size = local_range[0] * local_range[1];
27*bd882c8aSJames Wright   CeedCheck(min_group_size <= max_group_size, ceed, CEED_ERROR_BACKEND, "Requested group size is smaller than the required minimum.");
28*bd882c8aSJames Wright 
29*bd882c8aSJames Wright   local_range[2] = max_group_size / min_group_size;  // elements per group
30*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
31*bd882c8aSJames Wright }
32*bd882c8aSJames Wright 
33*bd882c8aSJames Wright //------------------------------------------------------------------------------
34*bd882c8aSJames Wright // Apply basis
35*bd882c8aSJames Wright //------------------------------------------------------------------------------
36*bd882c8aSJames Wright int CeedBasisApplyTensor_Sycl_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
37*bd882c8aSJames Wright                                      CeedVector v) {
38*bd882c8aSJames Wright   Ceed ceed;
39*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
40*bd882c8aSJames Wright   Ceed_Sycl *ceed_Sycl;
41*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
42*bd882c8aSJames Wright   CeedBasis_Sycl_shared *impl;
43*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
44*bd882c8aSJames Wright 
45*bd882c8aSJames Wright   // Read vectors
46*bd882c8aSJames Wright   const CeedScalar *d_u;
47*bd882c8aSJames Wright   CeedScalar       *d_v;
48*bd882c8aSJames Wright   if (eval_mode != CEED_EVAL_WEIGHT) {
49*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
50*bd882c8aSJames Wright   }
51*bd882c8aSJames Wright   CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
52*bd882c8aSJames Wright 
53*bd882c8aSJames Wright   // Apply basis operation
54*bd882c8aSJames Wright   switch (eval_mode) {
55*bd882c8aSJames Wright     case CEED_EVAL_INTERP: {
56*bd882c8aSJames Wright       CeedInt       *lrange         = impl->interp_local_range;
57*bd882c8aSJames Wright       const CeedInt &elem_per_group = lrange[2];
58*bd882c8aSJames Wright       const CeedInt  group_count    = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
59*bd882c8aSJames Wright       //-----------
60*bd882c8aSJames Wright       sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
61*bd882c8aSJames Wright       sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
62*bd882c8aSJames Wright       sycl::nd_range<3> kernel_range(global_range, local_range);
63*bd882c8aSJames Wright       //-----------
64*bd882c8aSJames Wright       sycl::kernel *interp_kernel = (t_mode == CEED_TRANSPOSE) ? impl->interp_transpose_kernel : impl->interp_kernel;
65*bd882c8aSJames Wright 
66*bd882c8aSJames Wright       // Order queue
67*bd882c8aSJames Wright       sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
68*bd882c8aSJames Wright 
69*bd882c8aSJames Wright       ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
70*bd882c8aSJames Wright         cgh.depends_on(e);
71*bd882c8aSJames Wright         cgh.set_args(num_elem, impl->d_interp_1d, d_u, d_v);
72*bd882c8aSJames Wright         cgh.parallel_for(kernel_range, *interp_kernel);
73*bd882c8aSJames Wright       });
74*bd882c8aSJames Wright 
75*bd882c8aSJames Wright     } break;
76*bd882c8aSJames Wright     case CEED_EVAL_GRAD: {
77*bd882c8aSJames Wright       CeedInt       *lrange         = impl->grad_local_range;
78*bd882c8aSJames Wright       const CeedInt &elem_per_group = lrange[2];
79*bd882c8aSJames Wright       const CeedInt  group_count    = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
80*bd882c8aSJames Wright       //-----------
81*bd882c8aSJames Wright       sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
82*bd882c8aSJames Wright       sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
83*bd882c8aSJames Wright       sycl::nd_range<3> kernel_range(global_range, local_range);
84*bd882c8aSJames Wright       //-----------
85*bd882c8aSJames Wright       sycl::kernel     *grad_kernel = (t_mode == CEED_TRANSPOSE) ? impl->grad_transpose_kernel : impl->grad_kernel;
86*bd882c8aSJames Wright       const CeedScalar *d_grad_1d   = (impl->d_collo_grad_1d) ? impl->d_collo_grad_1d : impl->d_grad_1d;
87*bd882c8aSJames Wright       // Order queue
88*bd882c8aSJames Wright       sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
89*bd882c8aSJames Wright 
90*bd882c8aSJames Wright       ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
91*bd882c8aSJames Wright         cgh.depends_on(e);
92*bd882c8aSJames Wright         cgh.set_args(num_elem, impl->d_interp_1d, d_grad_1d, d_u, d_v);
93*bd882c8aSJames Wright         cgh.parallel_for(kernel_range, *grad_kernel);
94*bd882c8aSJames Wright       });
95*bd882c8aSJames Wright     } break;
96*bd882c8aSJames Wright     case CEED_EVAL_WEIGHT: {
97*bd882c8aSJames Wright       CeedInt       *lrange         = impl->weight_local_range;
98*bd882c8aSJames Wright       const CeedInt &elem_per_group = lrange[2];
99*bd882c8aSJames Wright       const CeedInt  group_count    = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
100*bd882c8aSJames Wright       //-----------
101*bd882c8aSJames Wright       sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
102*bd882c8aSJames Wright       sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
103*bd882c8aSJames Wright       sycl::nd_range<3> kernel_range(global_range, local_range);
104*bd882c8aSJames Wright       //-----------
105*bd882c8aSJames Wright       // Order queue
106*bd882c8aSJames Wright       sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
107*bd882c8aSJames Wright 
108*bd882c8aSJames Wright       ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
109*bd882c8aSJames Wright         cgh.depends_on(e);
110*bd882c8aSJames Wright         cgh.set_args(num_elem, impl->d_q_weight_1d, d_v);
111*bd882c8aSJames Wright         cgh.parallel_for(kernel_range, *(impl->weight_kernel));
112*bd882c8aSJames Wright       });
113*bd882c8aSJames Wright     } break;
114*bd882c8aSJames Wright     // LCOV_EXCL_START
115*bd882c8aSJames Wright     // Evaluate the divergence to/from the quadrature points
116*bd882c8aSJames Wright     case CEED_EVAL_DIV:
117*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
118*bd882c8aSJames Wright     // Evaluate the curl to/from the quadrature points
119*bd882c8aSJames Wright     case CEED_EVAL_CURL:
120*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
121*bd882c8aSJames Wright     // Take no action, BasisApply should not have been called
122*bd882c8aSJames Wright     case CEED_EVAL_NONE:
123*bd882c8aSJames Wright       return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
124*bd882c8aSJames Wright       // LCOV_EXCL_STOP
125*bd882c8aSJames Wright   }
126*bd882c8aSJames Wright 
127*bd882c8aSJames Wright   // Restore vectors
128*bd882c8aSJames Wright   if (eval_mode != CEED_EVAL_WEIGHT) {
129*bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
130*bd882c8aSJames Wright   }
131*bd882c8aSJames Wright   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
132*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
133*bd882c8aSJames Wright }
134*bd882c8aSJames Wright 
135*bd882c8aSJames Wright //------------------------------------------------------------------------------
136*bd882c8aSJames Wright // Destroy basis
137*bd882c8aSJames Wright //------------------------------------------------------------------------------
138*bd882c8aSJames Wright static int CeedBasisDestroy_Sycl_shared(CeedBasis basis) {
139*bd882c8aSJames Wright   Ceed ceed;
140*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
141*bd882c8aSJames Wright   CeedBasis_Sycl_shared *impl;
142*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetData(basis, &impl));
143*bd882c8aSJames Wright   Ceed_Sycl *data;
144*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
145*bd882c8aSJames Wright 
146*bd882c8aSJames Wright   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
147*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context));
148*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context));
149*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context));
150*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::free(impl->d_collo_grad_1d, data->sycl_context));
151*bd882c8aSJames Wright 
152*bd882c8aSJames Wright   delete impl->interp_kernel;
153*bd882c8aSJames Wright   delete impl->interp_transpose_kernel;
154*bd882c8aSJames Wright   delete impl->grad_kernel;
155*bd882c8aSJames Wright   delete impl->grad_transpose_kernel;
156*bd882c8aSJames Wright   delete impl->weight_kernel;
157*bd882c8aSJames Wright   delete impl->sycl_module;
158*bd882c8aSJames Wright 
159*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
160*bd882c8aSJames Wright 
161*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
162*bd882c8aSJames Wright }
163*bd882c8aSJames Wright 
164*bd882c8aSJames Wright //------------------------------------------------------------------------------
165*bd882c8aSJames Wright // Create tensor basis
166*bd882c8aSJames Wright // TODO: Refactor
167*bd882c8aSJames Wright //------------------------------------------------------------------------------
168*bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
169*bd882c8aSJames Wright                                         const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
170*bd882c8aSJames Wright   Ceed ceed;
171*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
172*bd882c8aSJames Wright   CeedBasis_Sycl_shared *impl;
173*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
174*bd882c8aSJames Wright   Ceed_Sycl *data;
175*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &data));
176*bd882c8aSJames Wright 
177*bd882c8aSJames Wright   CeedInt num_comp;
178*bd882c8aSJames Wright   CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
179*bd882c8aSJames Wright 
180*bd882c8aSJames Wright   const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
181*bd882c8aSJames Wright   const CeedInt num_nodes = CeedIntPow(P_1d, dim);
182*bd882c8aSJames Wright   const CeedInt num_qpts  = CeedIntPow(Q_1d, dim);
183*bd882c8aSJames Wright 
184*bd882c8aSJames Wright   CeedInt *interp_lrange = impl->interp_local_range;
185*bd882c8aSJames Wright   CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, interp_lrange));
186*bd882c8aSJames Wright   const CeedInt interp_group_size = interp_lrange[0] * interp_lrange[1] * interp_lrange[2];
187*bd882c8aSJames Wright 
188*bd882c8aSJames Wright   CeedInt *grad_lrange = impl->grad_local_range;
189*bd882c8aSJames Wright   CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, grad_lrange));
190*bd882c8aSJames Wright   const CeedInt grad_group_size = grad_lrange[0] * grad_lrange[1] * grad_lrange[2];
191*bd882c8aSJames Wright 
192*bd882c8aSJames Wright   CeedCallBackend(ComputeLocalRange(ceed, dim, Q_1d, impl->weight_local_range));
193*bd882c8aSJames Wright 
194*bd882c8aSJames Wright   // Copy basis data to GPU
195*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context));
196*bd882c8aSJames Wright   sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d);
197*bd882c8aSJames Wright 
198*bd882c8aSJames Wright   const CeedInt interp_length = Q_1d * P_1d;
199*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
200*bd882c8aSJames Wright   sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length);
201*bd882c8aSJames Wright 
202*bd882c8aSJames Wright   CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
203*bd882c8aSJames Wright   sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length);
204*bd882c8aSJames Wright 
205*bd882c8aSJames Wright   CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad}));
206*bd882c8aSJames Wright 
207*bd882c8aSJames Wright   // Compute collocated gradient and copy to GPU
208*bd882c8aSJames Wright   impl->d_collo_grad_1d          = NULL;
209*bd882c8aSJames Wright   const bool has_collocated_grad = (dim == 3) && (Q_1d >= P_1d);
210*bd882c8aSJames Wright   if (has_collocated_grad) {
211*bd882c8aSJames Wright     CeedScalar *collo_grad_1d;
212*bd882c8aSJames Wright     CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d));
213*bd882c8aSJames Wright     CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d));
214*bd882c8aSJames Wright     const CeedInt cgrad_length = Q_1d * Q_1d;
215*bd882c8aSJames Wright     CeedCallSycl(ceed, impl->d_collo_grad_1d = sycl::malloc_device<CeedScalar>(cgrad_length, data->sycl_device, data->sycl_context));
216*bd882c8aSJames Wright     CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(collo_grad_1d, impl->d_collo_grad_1d, cgrad_length).wait_and_throw());
217*bd882c8aSJames Wright     CeedCallBackend(CeedFree(&collo_grad_1d));
218*bd882c8aSJames Wright   }
219*bd882c8aSJames Wright 
220*bd882c8aSJames Wright   // ---[Refactor into separate function]------>
221*bd882c8aSJames Wright   // Define compile-time constants
222*bd882c8aSJames Wright   std::map<std::string, CeedInt> jit_constants;
223*bd882c8aSJames Wright   jit_constants["BASIS_DIM"]                 = dim;
224*bd882c8aSJames Wright   jit_constants["BASIS_Q_1D"]                = Q_1d;
225*bd882c8aSJames Wright   jit_constants["BASIS_P_1D"]                = P_1d;
226*bd882c8aSJames Wright   jit_constants["T_1D"]                      = thread_1d;
227*bd882c8aSJames Wright   jit_constants["BASIS_NUM_COMP"]            = num_comp;
228*bd882c8aSJames Wright   jit_constants["BASIS_NUM_NODES"]           = num_nodes;
229*bd882c8aSJames Wright   jit_constants["BASIS_NUM_QPTS"]            = num_qpts;
230*bd882c8aSJames Wright   jit_constants["BASIS_HAS_COLLOCATED_GRAD"] = has_collocated_grad;
231*bd882c8aSJames Wright   jit_constants["BASIS_INTERP_SCRATCH_SIZE"] = interp_group_size;
232*bd882c8aSJames Wright   jit_constants["BASIS_GRAD_SCRATCH_SIZE"]   = grad_group_size;
233*bd882c8aSJames Wright 
234*bd882c8aSJames Wright   // Load kernel source
235*bd882c8aSJames Wright   char *basis_kernel_path, *basis_kernel_source;
236*bd882c8aSJames Wright   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor.h", &basis_kernel_path));
237*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source -----\n");
238*bd882c8aSJames Wright   CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &basis_kernel_source));
239*bd882c8aSJames Wright   CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source Complete -----\n");
240*bd882c8aSJames Wright 
241*bd882c8aSJames Wright   // Compile kernels into a kernel bundle
242*bd882c8aSJames Wright   CeedCallBackend(CeedJitBuildModule_Sycl(ceed, basis_kernel_source, &impl->sycl_module, jit_constants));
243*bd882c8aSJames Wright 
244*bd882c8aSJames Wright   // Load kernel functions
245*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "Interp", &impl->interp_kernel));
246*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "InterpTranspose", &impl->interp_transpose_kernel));
247*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "Grad", &impl->grad_kernel));
248*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "GradTranspose", &impl->grad_transpose_kernel));
249*bd882c8aSJames Wright   CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "Weight", &impl->weight_kernel));
250*bd882c8aSJames Wright 
251*bd882c8aSJames Wright   // Clean-up
252*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&basis_kernel_path));
253*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&basis_kernel_source));
254*bd882c8aSJames Wright   // <---[Refactor into separate function]------
255*bd882c8aSJames Wright 
256*bd882c8aSJames Wright   CeedCallBackend(CeedBasisSetData(basis, impl));
257*bd882c8aSJames Wright 
258*bd882c8aSJames Wright   // Register backend functions
259*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Sycl_shared));
260*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl_shared));
261*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
262*bd882c8aSJames Wright }
263*bd882c8aSJames Wright //------------------------------------------------------------------------------
264