xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction.sycl.cpp (revision 07d5b98a8feba68a643190b8ea9bcdac5c3e6570)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other
2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3 // files for details.
4 //
5 // SPDX-License-Identifier: BSD-2-Clause
6 //
7 // This file is part of CEED:  http://github.com/ceed
8 
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 
12 #include <string>
13 #include <sycl/sycl.hpp>
14 #include <vector>
15 
16 #include "../sycl/ceed-sycl-common.hpp"
17 #include "../sycl/ceed-sycl-compile.hpp"
18 #include "ceed-sycl-ref-qfunction-load.hpp"
19 #include "ceed-sycl-ref.hpp"
20 
21 #define WG_SIZE_QF 384
22 
23 //------------------------------------------------------------------------------
24 // Apply QFunction
25 //------------------------------------------------------------------------------
26 static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27   Ceed                ceed;
28   Ceed_Sycl          *ceed_Sycl;
29   void               *context_data;
30   CeedInt             num_input_fields, num_output_fields;
31   CeedQFunction_Sycl *impl;
32 
33   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
34 
35   // Build and compile kernel, if not done
36   if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf));
37 
38   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
39   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
40 
41   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
42 
43   // Read vectors
44   std::vector<const CeedScalar *> inputs(num_input_fields);
45   const CeedVector               *U_i = U;
46   for (auto &input_i : inputs) {
47     CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
48     ++U_i;
49   }
50 
51   std::vector<CeedScalar *> outputs(num_output_fields);
52   CeedVector               *V_i = V;
53   for (auto &output_i : outputs) {
54     CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
55     ++V_i;
56   }
57 
58   // Get context data
59   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
60 
61   // Order queue
62   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
63 
64   // Launch as a basic parallel_for over Q quadrature points
65   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
66     cgh.depends_on({e});
67 
68     int iarg{};
69     cgh.set_arg(iarg, context_data);
70     ++iarg;
71     cgh.set_arg(iarg, Q);
72     ++iarg;
73     for (auto &input_i : inputs) {
74       cgh.set_arg(iarg, input_i);
75       ++iarg;
76     }
77     for (auto &output_i : outputs) {
78       cgh.set_arg(iarg, output_i);
79       ++iarg;
80     }
81     // Hard-coding the work-group size for now
82     // We could use the Level Zero API to query and set an appropriate size in future
83     // Equivalent of CUDA Occupancy Calculator
84     int               wg_size   = WG_SIZE_QF;
85     sycl::range<1>    rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
86     sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
87     cgh.parallel_for(kernel_range, *(impl->QFunction));
88   });
89 
90   // Restore vectors
91   U_i = U;
92   for (auto &input_i : inputs) {
93     CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
94     ++U_i;
95   }
96 
97   V_i = V;
98   for (auto &output_i : outputs) {
99     CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
100     ++V_i;
101   }
102 
103   // Restore context
104   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
105   return CEED_ERROR_SUCCESS;
106 }
107 
108 //------------------------------------------------------------------------------
109 // Destroy QFunction
110 //------------------------------------------------------------------------------
111 static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
112   Ceed                ceed;
113   CeedQFunction_Sycl *impl;
114 
115   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
116   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
117   delete impl->QFunction;
118   delete impl->sycl_module;
119   CeedCallBackend(CeedFree(&impl));
120   return CEED_ERROR_SUCCESS;
121 }
122 
123 //------------------------------------------------------------------------------
124 // Create QFunction
125 //------------------------------------------------------------------------------
126 int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
127   Ceed                ceed;
128   CeedQFunction_Sycl *impl;
129 
130   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
131   CeedCallBackend(CeedCalloc(1, &impl));
132   CeedCallBackend(CeedQFunctionSetData(qf, impl));
133   // Register backend functions
134   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
135   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
136   return CEED_ERROR_SUCCESS;
137 }
138 
139 //------------------------------------------------------------------------------
140