xref: /libCEED/backends/sycl-ref/ceed-sycl-ref-qfunction.sycl.cpp (revision 8e6aa226c2c84e58dd7feb551fd506c4f25986db)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other
2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3 // files for details.
4 //
5 // SPDX-License-Identifier: BSD-2-Clause
6 //
7 // This file is part of CEED:  http://github.com/ceed
8 
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 
12 #include <string>
13 #include <sycl/sycl.hpp>
14 #include <vector>
15 
16 #include "../sycl/ceed-sycl-common.hpp"
17 #include "../sycl/ceed-sycl-compile.hpp"
18 #include "ceed-sycl-ref-qfunction-load.hpp"
19 #include "ceed-sycl-ref.hpp"
20 
21 #define WG_SIZE_QF 384
22 
23 //------------------------------------------------------------------------------
24 // Apply QFunction
25 //------------------------------------------------------------------------------
26 static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) {
27   CeedQFunction_Sycl *impl;
28   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
29 
30   // Build and compile kernel, if not done
31   if (!impl->QFunction) CeedCallBackend(CeedQFunctionBuildKernel_Sycl(qf));
32 
33   Ceed ceed;
34   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
35   Ceed_Sycl *ceed_Sycl;
36   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
37 
38   CeedInt num_input_fields, num_output_fields;
39   CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields));
40 
41   // Read vectors
42   std::vector<const CeedScalar *> inputs(num_input_fields);
43   const CeedVector               *U_i = U;
44   for (auto &input_i : inputs) {
45     CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i));
46     ++U_i;
47   }
48 
49   std::vector<CeedScalar *> outputs(num_output_fields);
50   CeedVector               *V_i = V;
51   for (auto &output_i : outputs) {
52     CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i));
53     ++V_i;
54   }
55 
56   // Get context data
57   void *context_data;
58   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data));
59 
60   // Order queue
61   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
62 
63   // Launch as a basic parallel_for over Q quadrature points
64   ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
65     cgh.depends_on({e});
66 
67     int iarg{};
68     cgh.set_arg(iarg, context_data);
69     ++iarg;
70     cgh.set_arg(iarg, Q);
71     ++iarg;
72     for (auto &input_i : inputs) {
73       cgh.set_arg(iarg, input_i);
74       ++iarg;
75     }
76     for (auto &output_i : outputs) {
77       cgh.set_arg(iarg, output_i);
78       ++iarg;
79     }
80     // Hard-coding the work-group size for now
81     // We could use the Level Zero API to query and set an appropriate size in future
82     // Equivalent of CUDA Occupancy Calculator
83     int               wg_size   = WG_SIZE_QF;
84     sycl::range<1>    rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size;
85     sycl::nd_range<1> kernel_range(rounded_Q, wg_size);
86     cgh.parallel_for(kernel_range, *(impl->QFunction));
87   });
88 
89   // Restore vectors
90   U_i = U;
91   for (auto &input_i : inputs) {
92     CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i));
93     ++U_i;
94   }
95 
96   V_i = V;
97   for (auto &output_i : outputs) {
98     CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i));
99     ++V_i;
100   }
101 
102   // Restore context
103   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data));
104 
105   return CEED_ERROR_SUCCESS;
106 }
107 
108 //------------------------------------------------------------------------------
109 // Destroy QFunction
110 //------------------------------------------------------------------------------
111 static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) {
112   CeedQFunction_Sycl *impl;
113   CeedCallBackend(CeedQFunctionGetData(qf, &impl));
114 
115   Ceed ceed;
116   CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed));
117 
118   delete impl->QFunction;
119   delete impl->sycl_module;
120 
121   CeedCallBackend(CeedFree(&impl));
122 
123   return CEED_ERROR_SUCCESS;
124 }
125 
126 //------------------------------------------------------------------------------
127 // Create QFunction
128 //------------------------------------------------------------------------------
129 int CeedQFunctionCreate_Sycl(CeedQFunction qf) {
130   Ceed ceed;
131   CeedQFunctionGetCeed(qf, &ceed);
132   CeedQFunction_Sycl *impl;
133 
134   CeedCallBackend(CeedCalloc(1, &impl));
135   CeedCallBackend(CeedQFunctionSetData(qf, impl));
136 
137   // Register backend functions
138   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl));
139   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl));
140   return CEED_ERROR_SUCCESS;
141 }
142 //------------------------------------------------------------------------------
143