xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction.sycl.cpp (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other
2bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3bd882c8aSJames Wright // files for details.
4bd882c8aSJames Wright //
5bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
6bd882c8aSJames Wright //
7bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
8bd882c8aSJames Wright 
9bd882c8aSJames Wright #include <ceed/backend.h>
10bd882c8aSJames Wright #include <ceed/ceed.h>
11bd882c8aSJames Wright 
12bd882c8aSJames Wright #include <string>
13bd882c8aSJames Wright #include <sycl/sycl.hpp>
14bd882c8aSJames Wright #include <vector>
15bd882c8aSJames Wright 
16bd882c8aSJames Wright #include "../sycl/ceed-sycl-common.hpp"
17bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp"
18bd882c8aSJames Wright #include "ceed-sycl-ref-qfunction-load.hpp"
19bd882c8aSJames Wright #include "ceed-sycl-ref.hpp"
20bd882c8aSJames Wright 
21bd882c8aSJames Wright #define WG_SIZE_QF 384
22bd882c8aSJames Wright 
23bd882c8aSJames Wright //------------------------------------------------------------------------------
24bd882c8aSJames Wright // Apply QFunction
25bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedQFunctionApply_Sycl(CeedQFunction qf,CeedInt Q,CeedVector * U,CeedVector * V)26bd882c8aSJames Wright static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27dd64fc84SJeremy L Thompson   Ceed                ceed;
28dd64fc84SJeremy L Thompson   Ceed_Sycl          *ceed_Sycl;
29dd64fc84SJeremy L Thompson   void               *context_data;
30dd64fc84SJeremy L Thompson   CeedInt             num_input_fields, num_output_fields;
31bd882c8aSJames Wright   CeedQFunction_Sycl *impl;
32dd64fc84SJeremy L Thompson 
33bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
34bd882c8aSJames Wright 
35bd882c8aSJames Wright   // Build and compile kernel, if not done
36eb7e6cafSJeremy L Thompson   if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf));
37bd882c8aSJames Wright 
38bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
39bd882c8aSJames Wright   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
409bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
41bd882c8aSJames Wright 
42bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
43bd882c8aSJames Wright 
44bd882c8aSJames Wright   // Read vectors
45bd882c8aSJames Wright   std::vector<const CeedScalar *> inputs(num_input_fields);
46bd882c8aSJames Wright   const CeedVector               *U_i = U;
47bd882c8aSJames Wright   for (auto &input_i : inputs) {
48bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
49bd882c8aSJames Wright     ++U_i;
50bd882c8aSJames Wright   }
51bd882c8aSJames Wright 
52bd882c8aSJames Wright   std::vector<CeedScalar *> outputs(num_output_fields);
53bd882c8aSJames Wright   CeedVector               *V_i = V;
54bd882c8aSJames Wright   for (auto &output_i : outputs) {
55bd882c8aSJames Wright     CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
56bd882c8aSJames Wright     ++V_i;
57bd882c8aSJames Wright   }
58bd882c8aSJames Wright 
59bd882c8aSJames Wright   // Get context data
60bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
61bd882c8aSJames Wright 
621f4b1b45SUmesh Unnikrishnan   std::vector<sycl::event> e;
631f4b1b45SUmesh Unnikrishnan 
641f4b1b45SUmesh Unnikrishnan   if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
65bd882c8aSJames Wright 
66bd882c8aSJames Wright   // Launch as a basic parallel_for over Q quadrature points
67bd882c8aSJames Wright   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
681f4b1b45SUmesh Unnikrishnan     cgh.depends_on(e);
69bd882c8aSJames Wright 
70bd882c8aSJames Wright     int iarg{};
71bd882c8aSJames Wright     cgh.set_arg(iarg, context_data);
72bd882c8aSJames Wright     ++iarg;
73bd882c8aSJames Wright     cgh.set_arg(iarg, Q);
74bd882c8aSJames Wright     ++iarg;
75bd882c8aSJames Wright     for (auto &input_i : inputs) {
76bd882c8aSJames Wright       cgh.set_arg(iarg, input_i);
77bd882c8aSJames Wright       ++iarg;
78bd882c8aSJames Wright     }
79bd882c8aSJames Wright     for (auto &output_i : outputs) {
80bd882c8aSJames Wright       cgh.set_arg(iarg, output_i);
81bd882c8aSJames Wright       ++iarg;
82bd882c8aSJames Wright     }
83bd882c8aSJames Wright     // Hard-coding the work-group size for now
84bd882c8aSJames Wright     // We could use the Level Zero API to query and set an appropriate size in future
85bd882c8aSJames Wright     // Equivalent of CUDA Occupancy Calculator
86bd882c8aSJames Wright     int               wg_size   = WG_SIZE_QF;
87bd882c8aSJames Wright     sycl::range<1>    rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
88bd882c8aSJames Wright     sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
89bd882c8aSJames Wright     cgh.parallel_for(kernel_range, *(impl->QFunction));
90bd882c8aSJames Wright   });
91bd882c8aSJames Wright 
92bd882c8aSJames Wright   // Restore vectors
93bd882c8aSJames Wright   U_i = U;
94bd882c8aSJames Wright   for (auto &input_i : inputs) {
95bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
96bd882c8aSJames Wright     ++U_i;
97bd882c8aSJames Wright   }
98bd882c8aSJames Wright 
99bd882c8aSJames Wright   V_i = V;
100bd882c8aSJames Wright   for (auto &output_i : outputs) {
101bd882c8aSJames Wright     CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
102bd882c8aSJames Wright     ++V_i;
103bd882c8aSJames Wright   }
104bd882c8aSJames Wright 
105bd882c8aSJames Wright   // Restore context
106bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
107bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
108bd882c8aSJames Wright }
109bd882c8aSJames Wright 
110bd882c8aSJames Wright //------------------------------------------------------------------------------
111bd882c8aSJames Wright // Destroy QFunction
112bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedQFunctionDestroy_Sycl(CeedQFunction qf)113bd882c8aSJames Wright static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
114bd882c8aSJames Wright   Ceed                ceed;
115dd64fc84SJeremy L Thompson   CeedQFunction_Sycl *impl;
116bd882c8aSJames Wright 
117dd64fc84SJeremy L Thompson   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
118dd64fc84SJeremy L Thompson   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
119bd882c8aSJames Wright   delete impl->QFunction;
120bd882c8aSJames Wright   delete impl->sycl_module;
121bd882c8aSJames Wright   CeedCallBackend(CeedFree(&impl));
1229bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
123bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
124bd882c8aSJames Wright }
125bd882c8aSJames Wright 
126bd882c8aSJames Wright //------------------------------------------------------------------------------
127bd882c8aSJames Wright // Create QFunction
128bd882c8aSJames Wright //------------------------------------------------------------------------------
CeedQFunctionCreate_Sycl(CeedQFunction qf)129bd882c8aSJames Wright int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
130bd882c8aSJames Wright   Ceed                ceed;
131bd882c8aSJames Wright   CeedQFunction_Sycl *impl;
132bd882c8aSJames Wright 
1336e536b99SJeremy L Thompson   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
134bd882c8aSJames Wright   CeedCallBackend(CeedCalloc(1, &impl));
135bd882c8aSJames Wright   CeedCallBackend(CeedQFunctionSetData(qf, impl));
136bd882c8aSJames Wright   // Register backend functions
137bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
138bd882c8aSJames Wright   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
1399bc66399SJeremy L Thompson   CeedCallBackend(CeedDestroy(&ceed));
140bd882c8aSJames Wright   return CEED_ERROR_SUCCESS;
141bd882c8aSJames Wright }
142ff1e7120SSebastian Grimberg 
143bd882c8aSJames Wright //------------------------------------------------------------------------------
144