xref: /libCEED/rust/libceed-sys/c-src/backends/sycl-gen/ceed-sycl-gen-operator.sycl.cpp (revision 6ca0f394dabdca92269b68ec74be8bebae3befa4)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <stddef.h>
11 
12 #include "../sycl/ceed-sycl-compile.hpp"
13 #include "ceed-sycl-gen-operator-build.hpp"
14 #include "ceed-sycl-gen.hpp"
15 
16 //------------------------------------------------------------------------------
17 // Destroy operator
18 //------------------------------------------------------------------------------
19 static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) {
20   CeedOperator_Sycl_gen *impl;
21   CeedCallBackend(CeedOperatorGetData(op, &impl));
22   CeedCallBackend(CeedFree(&impl));
23   return CEED_ERROR_SUCCESS;
24 }
25 
26 //------------------------------------------------------------------------------
27 // Apply and add to output
28 //------------------------------------------------------------------------------
29 static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
30   Ceed ceed;
31   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
32   Ceed_Sycl *ceed_Sycl;
33   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
34   CeedOperator_Sycl_gen *impl;
35   CeedCallBackend(CeedOperatorGetData(op, &impl));
36   CeedQFunction           qf;
37   CeedQFunction_Sycl_gen *qf_impl;
38   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
39   CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
40   CeedInt num_elem, num_input_fields, num_output_fields;
41   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
42   CeedOperatorField *op_input_fields, *op_output_fields;
43   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
44   CeedQFunctionField *qf_input_fields, *qf_output_fields;
45   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
46   CeedEvalMode eval_mode;
47   CeedVector   vec, output_vecs[CEED_FIELD_MAX] = {};
48 
49   // Creation of the operator
50   CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op));
51 
52   // Input vectors
53   for (CeedInt i = 0; i < num_input_fields; i++) {
54     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
55     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
56       impl->fields->inputs[i] = NULL;
57     } else {
58       // Get input vector
59       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
60       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
61       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i]));
62     }
63   }
64 
65   // Output vectors
66   for (CeedInt i = 0; i < num_output_fields; i++) {
67     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
68     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
69       impl->fields->outputs[i] = NULL;
70     } else {
71       // Get output vector
72       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
73       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
74       output_vecs[i] = vec;
75       // Check for multiple output modes
76       CeedInt index = -1;
77       for (CeedInt j = 0; j < i; j++) {
78         if (vec == output_vecs[j]) {
79           index = j;
80           break;
81         }
82       }
83       if (index == -1) {
84         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i]));
85       } else {
86         impl->fields->outputs[i] = impl->fields->outputs[index];
87       }
88     }
89   }
90 
91   // Get context data
92   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c));
93 
94   // Apply operator
95   const CeedInt dim  = impl->dim;
96   const CeedInt Q_1d = impl->Q_1d;
97   const CeedInt P_1d = impl->max_P_1d;
98   CeedInt       block_sizes[3], grid = 0;
99   CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
100   if (dim == 1) {
101     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
102     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
103   } else if (dim == 2) {
104     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
105     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
106   } else if (dim == 3) {
107     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
108     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
109   }
110 
111   sycl::range<3>    local_range(block_sizes[2], block_sizes[1], block_sizes[0]);
112   sycl::range<3>    global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]);
113   sycl::nd_range<3> kernel_range(global_range, local_range);
114 
115   //-----------
116   // Order queue
117   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
118 
119   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
120     cgh.depends_on(e);
121     cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W);
122     cgh.parallel_for(kernel_range, *(impl->op));
123   }));
124   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw());
125 
126   // Restore input arrays
127   for (CeedInt i = 0; i < num_input_fields; i++) {
128     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
129     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
130     } else {
131       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
132       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
133       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i]));
134     }
135   }
136 
137   // Restore output arrays
138   for (CeedInt i = 0; i < num_output_fields; i++) {
139     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
140     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
141     } else {
142       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
143       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
144       // Check for multiple output modes
145       CeedInt index = -1;
146       for (CeedInt j = 0; j < i; j++) {
147         if (vec == output_vecs[j]) {
148           index = j;
149           break;
150         }
151       }
152       if (index == -1) {
153         CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i]));
154       }
155     }
156   }
157 
158   // Restore context data
159   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c));
160 
161   return CEED_ERROR_SUCCESS;
162 }
163 
164 //------------------------------------------------------------------------------
165 // Create operator
166 //------------------------------------------------------------------------------
167 int CeedOperatorCreate_Sycl_gen(CeedOperator op) {
168   Ceed ceed;
169   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
170   Ceed_Sycl *sycl_data;
171   CeedCallBackend(CeedGetData(ceed, &sycl_data));
172 
173   CeedOperator_Sycl_gen *impl;
174   CeedCallBackend(CeedCalloc(1, &impl));
175   CeedCallBackend(CeedOperatorSetData(op, impl));
176 
177   impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
178   impl->fields  = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context);
179   impl->B       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
180   impl->G       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
181   impl->W       = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context);
182 
183   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen));
184   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen));
185   return CEED_ERROR_SUCCESS;
186 }
187 
188 //------------------------------------------------------------------------------
189