xref: /libCEED/backends/sycl-gen/ceed-sycl-gen-operator.sycl.cpp (revision 0930e4e786557ee92f6edf1a51db869c03c3f660)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <stddef.h>
11 
12 #include "../sycl/ceed-sycl-compile.hpp"
13 #include "ceed-sycl-gen-operator-build.hpp"
14 #include "ceed-sycl-gen.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Destroy operator
18 //------------------------------------------------------------------------------
19 static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) {
20   CeedOperator_Sycl_gen *impl;
21 
22   CeedCallBackend(CeedOperatorGetData(op, &impl));
23   CeedCallBackend(CeedFree(&impl));
24   return CEED_ERROR_SUCCESS;
25 }
26 
27 //------------------------------------------------------------------------------
28 // Apply and add to output
29 //------------------------------------------------------------------------------
30 static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
31   Ceed                    ceed;
32   Ceed_Sycl              *ceed_Sycl;
33   CeedInt                 num_elem, num_input_fields, num_output_fields;
34   CeedEvalMode            eval_mode;
35   CeedVector              output_vecs[CEED_FIELD_MAX] = {};
36   CeedQFunctionField     *qf_input_fields, *qf_output_fields;
37   CeedQFunction_Sycl_gen *qf_impl;
38   CeedQFunction           qf;
39   CeedOperatorField      *op_input_fields, *op_output_fields;
40   CeedOperator_Sycl_gen  *impl;
41 
42   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
43   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
44   CeedCallBackend(CeedOperatorGetData(op, &impl));
45   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
46   CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
47   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
48   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
49   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
50 
51   // Creation of the operator
52   CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op));
53 
54   // Input vectors
55   for (CeedInt i = 0; i < num_input_fields; i++) {
56     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
57     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
58       impl->fields->inputs[i] = NULL;
59     } else {
60       CeedVector vec;
61 
62       // Get input vector
63       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
64       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
65       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i]));
66     }
67   }
68 
69   // Output vectors
70   for (CeedInt i = 0; i < num_output_fields; i++) {
71     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
72     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
73       impl->fields->outputs[i] = NULL;
74     } else {
75       CeedVector vec;
76 
77       // Get output vector
78       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
79       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
80       output_vecs[i] = vec;
81       // Check for multiple output modes
82       CeedInt index = -1;
83       for (CeedInt j = 0; j < i; j++) {
84         if (vec == output_vecs[j]) {
85           index = j;
86           break;
87         }
88       }
89       if (index == -1) {
90         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i]));
91       } else {
92         impl->fields->outputs[i] = impl->fields->outputs[index];
93       }
94     }
95   }
96 
97   // Get context data
98   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c));
99 
100   // Apply operator
101   const CeedInt dim  = impl->dim;
102   const CeedInt Q_1d = impl->Q_1d;
103   const CeedInt P_1d = impl->max_P_1d;
104   CeedInt       block_sizes[3], grid = 0;
105 
106   CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
107   if (dim == 1) {
108     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
109     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
110   } else if (dim == 2) {
111     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
112     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
113   } else if (dim == 3) {
114     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
115     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
116   }
117 
118   sycl::range<3>    local_range(block_sizes[2], block_sizes[1], block_sizes[0]);
119   sycl::range<3>    global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]);
120   sycl::nd_range<3> kernel_range(global_range, local_range);
121 
122   //-----------
123   // Order queue
124   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
125 
126   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
127     cgh.depends_on(e);
128     cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W);
129     cgh.parallel_for(kernel_range, *(impl->op));
130   }));
131   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw());
132 
133   // Restore input arrays
134   for (CeedInt i = 0; i < num_input_fields; i++) {
135     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
136     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
137     } else {
138       CeedVector vec;
139 
140       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
141       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
142       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i]));
143     }
144   }
145 
146   // Restore output arrays
147   for (CeedInt i = 0; i < num_output_fields; i++) {
148     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
149     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
150     } else {
151       CeedVector vec;
152 
153       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
154       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
155       // Check for multiple output modes
156       CeedInt index = -1;
157 
158       for (CeedInt j = 0; j < i; j++) {
159         if (vec == output_vecs[j]) {
160           index = j;
161           break;
162         }
163       }
164       if (index == -1) {
165         CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i]));
166       }
167     }
168   }
169 
170   // Restore context data
171   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c));
172   return CEED_ERROR_SUCCESS;
173 }
174 
175 //------------------------------------------------------------------------------
176 // Create operator
177 //------------------------------------------------------------------------------
178 int CeedOperatorCreate_Sycl_gen(CeedOperator op) {
179   Ceed                   ceed;
180   Ceed_Sycl             *sycl_data;
181   CeedOperator_Sycl_gen *impl;
182 
183   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
184   CeedCallBackend(CeedGetData(ceed, &sycl_data));
185 
186   CeedCallBackend(CeedCalloc(1, &impl));
187   CeedCallBackend(CeedOperatorSetData(op, impl));
188 
189   impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
190   impl->fields  = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context);
191   impl->B       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
192   impl->G       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
193   impl->W       = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context);
194 
195   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen));
196   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen));
197   return CEED_ERROR_SUCCESS;
198 }
199 
200 //------------------------------------------------------------------------------
201