xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction.sycl.cpp (revision c7b67790f45ae71043f73ea7aa7684f20189bfc2)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other
2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3 // files for details.
4 //
5 // SPDX-License-Identifier: BSD-2-Clause
6 //
7 // This file is part of CEED:  http://github.com/ceed
8 
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 
12 #include <string>
13 #include <sycl/sycl.hpp>
14 #include <vector>
15 
16 #include "../sycl/ceed-sycl-common.hpp"
17 #include "../sycl/ceed-sycl-compile.hpp"
18 #include "ceed-sycl-ref-qfunction-load.hpp"
19 #include "ceed-sycl-ref.hpp"
20 
// Fixed work-group size used when launching QFunction kernels; the global
// launch range (Q quadrature points) is padded up to a multiple of this value.
static constexpr int WG_SIZE_QF = 384;
22 
23 //------------------------------------------------------------------------------
24 // Apply QFunction
25 //------------------------------------------------------------------------------
26 static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27   Ceed                ceed;
28   Ceed_Sycl          *ceed_Sycl;
29   void               *context_data;
30   CeedInt             num_input_fields, num_output_fields;
31   CeedQFunction_Sycl *impl;
32 
33   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
34 
35   // Build and compile kernel, if not done
36   if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf));
37 
38   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
39   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
40 
41   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
42 
43   // Read vectors
44   std::vector<const CeedScalar *> inputs(num_input_fields);
45   const CeedVector               *U_i = U;
46   for (auto &input_i : inputs) {
47     CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
48     ++U_i;
49   }
50 
51   std::vector<CeedScalar *> outputs(num_output_fields);
52   CeedVector               *V_i = V;
53   for (auto &output_i : outputs) {
54     CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
55     ++V_i;
56   }
57 
58   // Get context data
59   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
60 
61   std::vector<sycl::event> e;
62 
63   if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
64 
65   // Launch as a basic parallel_for over Q quadrature points
66   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
67     cgh.depends_on(e);
68 
69     int iarg{};
70     cgh.set_arg(iarg, context_data);
71     ++iarg;
72     cgh.set_arg(iarg, Q);
73     ++iarg;
74     for (auto &input_i : inputs) {
75       cgh.set_arg(iarg, input_i);
76       ++iarg;
77     }
78     for (auto &output_i : outputs) {
79       cgh.set_arg(iarg, output_i);
80       ++iarg;
81     }
82     // Hard-coding the work-group size for now
83     // We could use the Level Zero API to query and set an appropriate size in future
84     // Equivalent of CUDA Occupancy Calculator
85     int               wg_size   = WG_SIZE_QF;
86     sycl::range<1>    rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
87     sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
88     cgh.parallel_for(kernel_range, *(impl->QFunction));
89   });
90 
91   // Restore vectors
92   U_i = U;
93   for (auto &input_i : inputs) {
94     CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
95     ++U_i;
96   }
97 
98   V_i = V;
99   for (auto &output_i : outputs) {
100     CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
101     ++V_i;
102   }
103 
104   // Restore context
105   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
106   return CEED_ERROR_SUCCESS;
107 }
108 
109 //------------------------------------------------------------------------------
110 // Destroy QFunction
111 //------------------------------------------------------------------------------
112 static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
113   Ceed                ceed;
114   CeedQFunction_Sycl *impl;
115 
116   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
117   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
118   delete impl->QFunction;
119   delete impl->sycl_module;
120   CeedCallBackend(CeedFree(&impl));
121   return CEED_ERROR_SUCCESS;
122 }
123 
124 //------------------------------------------------------------------------------
125 // Create QFunction
126 //------------------------------------------------------------------------------
127 int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
128   Ceed                ceed;
129   CeedQFunction_Sycl *impl;
130 
131   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
132   CeedCallBackend(CeedCalloc(1, &impl));
133   CeedCallBackend(CeedQFunctionSetData(qf, impl));
134   // Register backend functions
135   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
136   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141