xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-ref/ceed-sycl-ref-qfunction.sycl.cpp (revision bd882c8a454763a096666645dc9a6229d5263694)
1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other
2*bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3*bd882c8aSJames Wright // files for details.
4*bd882c8aSJames Wright //
5*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
6*bd882c8aSJames Wright //
7*bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
8*bd882c8aSJames Wright 
9*bd882c8aSJames Wright #include <ceed/backend.h>
10*bd882c8aSJames Wright #include <ceed/ceed.h>
11*bd882c8aSJames Wright 
12*bd882c8aSJames Wright #include <string>
13*bd882c8aSJames Wright #include <sycl/sycl.hpp>
14*bd882c8aSJames Wright #include <vector>
15*bd882c8aSJames Wright 
16*bd882c8aSJames Wright #include "../sycl/ceed-sycl-common.hpp"
17*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
18*bd882c8aSJames Wright #include "ceed-sycl-ref-qfunction-load.hpp"
19*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
20*bd882c8aSJames Wright 
21*bd882c8aSJames Wright #define WG_SIZE_QF 384
22*bd882c8aSJames Wright 
23*bd882c8aSJames Wright //------------------------------------------------------------------------------
24*bd882c8aSJames Wright // Apply QFunction
25*bd882c8aSJames Wright //------------------------------------------------------------------------------
26*bd882c8aSJames Wright static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27*bd882c8aSJames Wright   CeedQFunction_Sycl *impl;
28*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
29*bd882c8aSJames Wright 
30*bd882c8aSJames Wright   // Build and compile kernel, if not done
31*bd882c8aSJames Wright   if (!impl->QFunction) CeedCallBackend(CeedBuildQFunction_Sycl(qf));
32*bd882c8aSJames Wright 
33*bd882c8aSJames Wright   Ceed ceed;
34*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
35*bd882c8aSJames Wright   Ceed_Sycl *ceed_Sycl;
36*bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
37*bd882c8aSJames Wright 
38*bd882c8aSJames Wright   CeedInt num_input_fields, num_output_fields;
39*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
40*bd882c8aSJames Wright 
41*bd882c8aSJames Wright   // Read vectors
42*bd882c8aSJames Wright   std::vector<const CeedScalar *> inputs(num_input_fields);
43*bd882c8aSJames Wright   const CeedVector               *U_i = U;
44*bd882c8aSJames Wright   for (auto &input_i : inputs) {
45*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
46*bd882c8aSJames Wright     ++U_i;
47*bd882c8aSJames Wright   }
48*bd882c8aSJames Wright 
49*bd882c8aSJames Wright   std::vector<CeedScalar *> outputs(num_output_fields);
50*bd882c8aSJames Wright   CeedVector               *V_i = V;
51*bd882c8aSJames Wright   for (auto &output_i : outputs) {
52*bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
53*bd882c8aSJames Wright     ++V_i;
54*bd882c8aSJames Wright   }
55*bd882c8aSJames Wright 
56*bd882c8aSJames Wright   // Get context data
57*bd882c8aSJames Wright   void *context_data;
58*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
59*bd882c8aSJames Wright 
60*bd882c8aSJames Wright   // Order queue
61*bd882c8aSJames Wright   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
62*bd882c8aSJames Wright 
63*bd882c8aSJames Wright   // Launch as a basic parallel_for over Q quadrature points
64*bd882c8aSJames Wright   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
65*bd882c8aSJames Wright     cgh.depends_on({e});
66*bd882c8aSJames Wright 
67*bd882c8aSJames Wright     int iarg{};
68*bd882c8aSJames Wright     cgh.set_arg(iarg, context_data);
69*bd882c8aSJames Wright     ++iarg;
70*bd882c8aSJames Wright     cgh.set_arg(iarg, Q);
71*bd882c8aSJames Wright     ++iarg;
72*bd882c8aSJames Wright     for (auto &input_i : inputs) {
73*bd882c8aSJames Wright       cgh.set_arg(iarg, input_i);
74*bd882c8aSJames Wright       ++iarg;
75*bd882c8aSJames Wright     }
76*bd882c8aSJames Wright     for (auto &output_i : outputs) {
77*bd882c8aSJames Wright       cgh.set_arg(iarg, output_i);
78*bd882c8aSJames Wright       ++iarg;
79*bd882c8aSJames Wright     }
80*bd882c8aSJames Wright     // Hard-coding the work-group size for now
81*bd882c8aSJames Wright     // We could use the Level Zero API to query and set an appropriate size in future
82*bd882c8aSJames Wright     // Equivalent of CUDA Occupancy Calculator
83*bd882c8aSJames Wright     int               wg_size   = WG_SIZE_QF;
84*bd882c8aSJames Wright     sycl::range<1>    rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
85*bd882c8aSJames Wright     sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
86*bd882c8aSJames Wright     cgh.parallel_for(kernel_range, *(impl->QFunction));
87*bd882c8aSJames Wright   });
88*bd882c8aSJames Wright 
89*bd882c8aSJames Wright   // Restore vectors
90*bd882c8aSJames Wright   U_i = U;
91*bd882c8aSJames Wright   for (auto &input_i : inputs) {
92*bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
93*bd882c8aSJames Wright     ++U_i;
94*bd882c8aSJames Wright   }
95*bd882c8aSJames Wright 
96*bd882c8aSJames Wright   V_i = V;
97*bd882c8aSJames Wright   for (auto &output_i : outputs) {
98*bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
99*bd882c8aSJames Wright     ++V_i;
100*bd882c8aSJames Wright   }
101*bd882c8aSJames Wright 
102*bd882c8aSJames Wright   // Restore context
103*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
104*bd882c8aSJames Wright 
105*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
106*bd882c8aSJames Wright }
107*bd882c8aSJames Wright 
108*bd882c8aSJames Wright //------------------------------------------------------------------------------
109*bd882c8aSJames Wright // Destroy QFunction
110*bd882c8aSJames Wright //------------------------------------------------------------------------------
111*bd882c8aSJames Wright static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
112*bd882c8aSJames Wright   CeedQFunction_Sycl *impl;
113*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
114*bd882c8aSJames Wright 
115*bd882c8aSJames Wright   Ceed ceed;
116*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
117*bd882c8aSJames Wright 
118*bd882c8aSJames Wright   delete impl->QFunction;
119*bd882c8aSJames Wright   delete impl->sycl_module;
120*bd882c8aSJames Wright 
121*bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
122*bd882c8aSJames Wright 
123*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
124*bd882c8aSJames Wright }
125*bd882c8aSJames Wright 
126*bd882c8aSJames Wright //------------------------------------------------------------------------------
127*bd882c8aSJames Wright // Create QFunction
128*bd882c8aSJames Wright //------------------------------------------------------------------------------
129*bd882c8aSJames Wright int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
130*bd882c8aSJames Wright   Ceed ceed;
131*bd882c8aSJames Wright   CeedQFunctionGetCeed(qf, &ceed);
132*bd882c8aSJames Wright   CeedQFunction_Sycl *impl;
133*bd882c8aSJames Wright 
134*bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
135*bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionSetData(qf, impl));
136*bd882c8aSJames Wright 
137*bd882c8aSJames Wright   // Register backend functions
138*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
139*bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
140*bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
141*bd882c8aSJames Wright }
142*bd882c8aSJames Wright //------------------------------------------------------------------------------
143