xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction.sycl.cpp (revision 82138112808ac45c6722ef2bfe52ea5cd96df80f)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other
2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3 // files for details.
4 //
5 // SPDX-License-Identifier: BSD-2-Clause
6 //
7 // This file is part of CEED:  http://github.com/ceed
8 
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 
12 #include <string>
13 #include <sycl/sycl.hpp>
14 #include <vector>
15 
16 #include "../sycl/ceed-sycl-common.hpp"
17 #include "../sycl/ceed-sycl-compile.hpp"
18 #include "ceed-sycl-ref-qfunction-load.hpp"
19 #include "ceed-sycl-ref.hpp"
20 
// Work-group size used when launching QFunction kernels (hard-coded; see CeedQFunctionApply_Sycl)
static constexpr int WG_SIZE_QF = 384;
22 
23 //------------------------------------------------------------------------------
24 // Apply QFunction
25 //------------------------------------------------------------------------------
26 static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27   Ceed                ceed;
28   Ceed_Sycl          *ceed_Sycl;
29   void               *context_data;
30   CeedInt             num_input_fields, num_output_fields;
31   CeedQFunction_Sycl *impl;
32 
33   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
34 
35   // Build and compile kernel, if not done
36   if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf));
37 
38   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
39   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
40   CeedCallBackend(CeedDestroy(&ceed));
41 
42   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
43 
44   // Read vectors
45   std::vector<const CeedScalar *> inputs(num_input_fields);
46   const CeedVector               *U_i = U;
47   for (auto &input_i : inputs) {
48     CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
49     ++U_i;
50   }
51 
52   std::vector<CeedScalar *> outputs(num_output_fields);
53   CeedVector               *V_i = V;
54   for (auto &output_i : outputs) {
55     CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
56     ++V_i;
57   }
58 
59   // Get context data
60   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
61 
62   std::vector<sycl::event> e;
63 
64   if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
65 
66   // Launch as a basic parallel_for over Q quadrature points
67   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
68     cgh.depends_on(e);
69 
70     int iarg{};
71     cgh.set_arg(iarg, context_data);
72     ++iarg;
73     cgh.set_arg(iarg, Q);
74     ++iarg;
75     for (auto &input_i : inputs) {
76       cgh.set_arg(iarg, input_i);
77       ++iarg;
78     }
79     for (auto &output_i : outputs) {
80       cgh.set_arg(iarg, output_i);
81       ++iarg;
82     }
83     // Hard-coding the work-group size for now
84     // We could use the Level Zero API to query and set an appropriate size in future
85     // Equivalent of CUDA Occupancy Calculator
86     int               wg_size   = WG_SIZE_QF;
87     sycl::range<1>    rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
88     sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
89     cgh.parallel_for(kernel_range, *(impl->QFunction));
90   });
91 
92   // Restore vectors
93   U_i = U;
94   for (auto &input_i : inputs) {
95     CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
96     ++U_i;
97   }
98 
99   V_i = V;
100   for (auto &output_i : outputs) {
101     CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
102     ++V_i;
103   }
104 
105   // Restore context
106   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
107   return CEED_ERROR_SUCCESS;
108 }
109 
110 //------------------------------------------------------------------------------
111 // Destroy QFunction
112 //------------------------------------------------------------------------------
113 static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
114   Ceed                ceed;
115   CeedQFunction_Sycl *impl;
116 
117   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
118   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
119   delete impl->QFunction;
120   delete impl->sycl_module;
121   CeedCallBackend(CeedFree(&impl));
122   CeedCallBackend(CeedDestroy(&ceed));
123   return CEED_ERROR_SUCCESS;
124 }
125 
126 //------------------------------------------------------------------------------
127 // Create QFunction
128 //------------------------------------------------------------------------------
129 int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
130   Ceed                ceed;
131   CeedQFunction_Sycl *impl;
132 
133   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
134   CeedCallBackend(CeedCalloc(1, &impl));
135   CeedCallBackend(CeedQFunctionSetData(qf, impl));
136   // Register backend functions
137   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
138   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
139   CeedCallBackend(CeedDestroy(&ceed));
140   return CEED_ERROR_SUCCESS;
141 }
142 
143 //------------------------------------------------------------------------------
144