// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed
76ca0f394SUmesh Unnikrishnan
86ca0f394SUmesh Unnikrishnan #include <ceed/backend.h>
96ca0f394SUmesh Unnikrishnan #include <ceed/ceed.h>
106ca0f394SUmesh Unnikrishnan #include <stddef.h>
116ca0f394SUmesh Unnikrishnan
126ca0f394SUmesh Unnikrishnan #include "../sycl/ceed-sycl-compile.hpp"
136ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen-operator-build.hpp"
146ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen.hpp"
156ca0f394SUmesh Unnikrishnan
//------------------------------------------------------------------------------
// Destroy operator
//------------------------------------------------------------------------------
CeedOperatorDestroy_Sycl_gen(CeedOperator op)196ca0f394SUmesh Unnikrishnan static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) {
206ca0f394SUmesh Unnikrishnan CeedOperator_Sycl_gen *impl;
21dd64fc84SJeremy L Thompson
226ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetData(op, &impl));
236ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&impl));
246ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS;
256ca0f394SUmesh Unnikrishnan }
266ca0f394SUmesh Unnikrishnan
//------------------------------------------------------------------------------
// Apply and add to output
//------------------------------------------------------------------------------
CeedOperatorApplyAdd_Sycl_gen(CeedOperator op,CeedVector input_vec,CeedVector output_vec,CeedRequest * request)306ca0f394SUmesh Unnikrishnan static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
316ca0f394SUmesh Unnikrishnan Ceed ceed;
326ca0f394SUmesh Unnikrishnan Ceed_Sycl *ceed_Sycl;
33dd64fc84SJeremy L Thompson CeedInt num_elem, num_input_fields, num_output_fields;
34dd64fc84SJeremy L Thompson CeedEvalMode eval_mode;
35dd64fc84SJeremy L Thompson CeedVector output_vecs[CEED_FIELD_MAX] = {};
36dd64fc84SJeremy L Thompson CeedQFunctionField *qf_input_fields, *qf_output_fields;
376ca0f394SUmesh Unnikrishnan CeedQFunction_Sycl_gen *qf_impl;
38dd64fc84SJeremy L Thompson CeedQFunction qf;
39dd64fc84SJeremy L Thompson CeedOperatorField *op_input_fields, *op_output_fields;
40dd64fc84SJeremy L Thompson CeedOperator_Sycl_gen *impl;
41dd64fc84SJeremy L Thompson
422d42b1dfSJeremy L Thompson // Check for tensor-product bases
432d42b1dfSJeremy L Thompson {
442d42b1dfSJeremy L Thompson bool has_tensor_bases;
452d42b1dfSJeremy L Thompson
462d42b1dfSJeremy L Thompson CeedCallBackend(CeedOperatorHasTensorBases(op, &has_tensor_bases));
472d42b1dfSJeremy L Thompson // -- Fallback to ref if not all bases are tensor-product
482d42b1dfSJeremy L Thompson if (!has_tensor_bases) {
492d42b1dfSJeremy L Thompson CeedOperator op_fallback;
502d42b1dfSJeremy L Thompson
51c11e12f4SJeremy L Thompson CeedDebug256(CeedOperatorReturnCeed(op), CEED_DEBUG_COLOR_SUCCESS, "Falling back to sycl/ref CeedOperator due to non-tensor bases");
522d42b1dfSJeremy L Thompson CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
532d42b1dfSJeremy L Thompson CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
542d42b1dfSJeremy L Thompson return CEED_ERROR_SUCCESS;
552d42b1dfSJeremy L Thompson }
562d42b1dfSJeremy L Thompson }
572d42b1dfSJeremy L Thompson
58c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
59c11e12f4SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
60c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetData(op, &impl));
61c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
62c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
63c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
64c11e12f4SJeremy L Thompson CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
65c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
66c11e12f4SJeremy L Thompson
676ca0f394SUmesh Unnikrishnan // Creation of the operator
686ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op));
696ca0f394SUmesh Unnikrishnan
706ca0f394SUmesh Unnikrishnan // Input vectors
716ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
726ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
736ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
746ca0f394SUmesh Unnikrishnan impl->fields->inputs[i] = NULL;
756ca0f394SUmesh Unnikrishnan } else {
76681d0ea7SJeremy L Thompson bool is_active;
77dd64fc84SJeremy L Thompson CeedVector vec;
78dd64fc84SJeremy L Thompson
796ca0f394SUmesh Unnikrishnan // Get input vector
806ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
81681d0ea7SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
82681d0ea7SJeremy L Thompson if (is_active) vec = input_vec;
836ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i]));
84681d0ea7SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
856ca0f394SUmesh Unnikrishnan }
866ca0f394SUmesh Unnikrishnan }
876ca0f394SUmesh Unnikrishnan
886ca0f394SUmesh Unnikrishnan // Output vectors
896ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
906ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
916ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
926ca0f394SUmesh Unnikrishnan impl->fields->outputs[i] = NULL;
936ca0f394SUmesh Unnikrishnan } else {
94681d0ea7SJeremy L Thompson bool is_active;
95dd64fc84SJeremy L Thompson CeedVector vec;
96dd64fc84SJeremy L Thompson
976ca0f394SUmesh Unnikrishnan // Get output vector
986ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
99681d0ea7SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
100681d0ea7SJeremy L Thompson if (is_active) vec = output_vec;
1016ca0f394SUmesh Unnikrishnan output_vecs[i] = vec;
1026ca0f394SUmesh Unnikrishnan // Check for multiple output modes
1036ca0f394SUmesh Unnikrishnan CeedInt index = -1;
1046ca0f394SUmesh Unnikrishnan for (CeedInt j = 0; j < i; j++) {
1056ca0f394SUmesh Unnikrishnan if (vec == output_vecs[j]) {
1066ca0f394SUmesh Unnikrishnan index = j;
1076ca0f394SUmesh Unnikrishnan break;
1086ca0f394SUmesh Unnikrishnan }
1096ca0f394SUmesh Unnikrishnan }
1106ca0f394SUmesh Unnikrishnan if (index == -1) {
1116ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i]));
1126ca0f394SUmesh Unnikrishnan } else {
1136ca0f394SUmesh Unnikrishnan impl->fields->outputs[i] = impl->fields->outputs[index];
1146ca0f394SUmesh Unnikrishnan }
11585938a6dSJames Wright if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
1166ca0f394SUmesh Unnikrishnan }
1176ca0f394SUmesh Unnikrishnan }
1186ca0f394SUmesh Unnikrishnan
1196ca0f394SUmesh Unnikrishnan // Get context data
1206ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c));
1216ca0f394SUmesh Unnikrishnan
1226ca0f394SUmesh Unnikrishnan // Apply operator
1236ca0f394SUmesh Unnikrishnan const CeedInt dim = impl->dim;
1246ca0f394SUmesh Unnikrishnan const CeedInt Q_1d = impl->Q_1d;
1256ca0f394SUmesh Unnikrishnan const CeedInt P_1d = impl->max_P_1d;
1266ca0f394SUmesh Unnikrishnan CeedInt block_sizes[3], grid = 0;
127dd64fc84SJeremy L Thompson
1286ca0f394SUmesh Unnikrishnan CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
1296ca0f394SUmesh Unnikrishnan if (dim == 1) {
1306ca0f394SUmesh Unnikrishnan grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
1316ca0f394SUmesh Unnikrishnan // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
1326ca0f394SUmesh Unnikrishnan } else if (dim == 2) {
1336ca0f394SUmesh Unnikrishnan grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
1346ca0f394SUmesh Unnikrishnan // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
1356ca0f394SUmesh Unnikrishnan } else if (dim == 3) {
1366ca0f394SUmesh Unnikrishnan grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
1376ca0f394SUmesh Unnikrishnan // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
1386ca0f394SUmesh Unnikrishnan }
1396ca0f394SUmesh Unnikrishnan
1406ca0f394SUmesh Unnikrishnan sycl::range<3> local_range(block_sizes[2], block_sizes[1], block_sizes[0]);
1416ca0f394SUmesh Unnikrishnan sycl::range<3> global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]);
1426ca0f394SUmesh Unnikrishnan sycl::nd_range<3> kernel_range(global_range, local_range);
1436ca0f394SUmesh Unnikrishnan
1446ca0f394SUmesh Unnikrishnan //-----------
1451f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e;
1461f4b1b45SUmesh Unnikrishnan
1471f4b1b45SUmesh Unnikrishnan if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
1486ca0f394SUmesh Unnikrishnan
1496ca0f394SUmesh Unnikrishnan CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
1506ca0f394SUmesh Unnikrishnan cgh.depends_on(e);
1516ca0f394SUmesh Unnikrishnan cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W);
1526ca0f394SUmesh Unnikrishnan cgh.parallel_for(kernel_range, *(impl->op));
1536ca0f394SUmesh Unnikrishnan }));
1546ca0f394SUmesh Unnikrishnan CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw());
1556ca0f394SUmesh Unnikrishnan
1566ca0f394SUmesh Unnikrishnan // Restore input arrays
1576ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) {
1586ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1596ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
1606ca0f394SUmesh Unnikrishnan } else {
161681d0ea7SJeremy L Thompson bool is_active;
162dd64fc84SJeremy L Thompson CeedVector vec;
163dd64fc84SJeremy L Thompson
1646ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
165681d0ea7SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
166681d0ea7SJeremy L Thompson if (is_active) vec = input_vec;
1676ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i]));
168681d0ea7SJeremy L Thompson if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
1696ca0f394SUmesh Unnikrishnan }
1706ca0f394SUmesh Unnikrishnan }
1716ca0f394SUmesh Unnikrishnan
1726ca0f394SUmesh Unnikrishnan // Restore output arrays
1736ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) {
1746ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
1756ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip
1766ca0f394SUmesh Unnikrishnan } else {
177681d0ea7SJeremy L Thompson bool is_active;
178dd64fc84SJeremy L Thompson CeedVector vec;
179dd64fc84SJeremy L Thompson
1806ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
181681d0ea7SJeremy L Thompson is_active = vec == CEED_VECTOR_ACTIVE;
182681d0ea7SJeremy L Thompson if (is_active) vec = output_vec;
1836ca0f394SUmesh Unnikrishnan // Check for multiple output modes
1846ca0f394SUmesh Unnikrishnan CeedInt index = -1;
185dd64fc84SJeremy L Thompson
1866ca0f394SUmesh Unnikrishnan for (CeedInt j = 0; j < i; j++) {
1876ca0f394SUmesh Unnikrishnan if (vec == output_vecs[j]) {
1886ca0f394SUmesh Unnikrishnan index = j;
1896ca0f394SUmesh Unnikrishnan break;
1906ca0f394SUmesh Unnikrishnan }
1916ca0f394SUmesh Unnikrishnan }
1926ca0f394SUmesh Unnikrishnan if (index == -1) {
1936ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i]));
1946ca0f394SUmesh Unnikrishnan }
19585938a6dSJames Wright if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
1966ca0f394SUmesh Unnikrishnan }
1976ca0f394SUmesh Unnikrishnan }
1986ca0f394SUmesh Unnikrishnan
1996ca0f394SUmesh Unnikrishnan // Restore context data
2006ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c));
2019bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
202c11e12f4SJeremy L Thompson CeedCallBackend(CeedQFunctionDestroy(&qf));
2036ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS;
2046ca0f394SUmesh Unnikrishnan }
2056ca0f394SUmesh Unnikrishnan
//------------------------------------------------------------------------------
// Create operator
//------------------------------------------------------------------------------
CeedOperatorCreate_Sycl_gen(CeedOperator op)2096ca0f394SUmesh Unnikrishnan int CeedOperatorCreate_Sycl_gen(CeedOperator op) {
2106ca0f394SUmesh Unnikrishnan Ceed ceed;
2116ca0f394SUmesh Unnikrishnan Ceed_Sycl *sycl_data;
212dd64fc84SJeremy L Thompson CeedOperator_Sycl_gen *impl;
213dd64fc84SJeremy L Thompson
214dd64fc84SJeremy L Thompson CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
2156ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetData(ceed, &sycl_data));
2166ca0f394SUmesh Unnikrishnan
2176ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedCalloc(1, &impl));
2186ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorSetData(op, impl));
2196ca0f394SUmesh Unnikrishnan
2206ca0f394SUmesh Unnikrishnan impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2216ca0f394SUmesh Unnikrishnan impl->fields = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context);
2226ca0f394SUmesh Unnikrishnan impl->B = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2236ca0f394SUmesh Unnikrishnan impl->G = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2246ca0f394SUmesh Unnikrishnan impl->W = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2256ca0f394SUmesh Unnikrishnan
2266ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen));
2276ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen));
2289bc66399SJeremy L Thompson CeedCallBackend(CeedDestroy(&ceed));
2296ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS;
2306ca0f394SUmesh Unnikrishnan }
2316ca0f394SUmesh Unnikrishnan
//------------------------------------------------------------------------------
233