xref: /libCEED/backends/sycl-gen/ceed-sycl-gen-operator.sycl.cpp (revision 2d42b1df6545af94031e4cf6a69853d29f68d801)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
76ca0f394SUmesh Unnikrishnan 
86ca0f394SUmesh Unnikrishnan #include <ceed/backend.h>
96ca0f394SUmesh Unnikrishnan #include <ceed/ceed.h>
106ca0f394SUmesh Unnikrishnan #include <stddef.h>
116ca0f394SUmesh Unnikrishnan 
126ca0f394SUmesh Unnikrishnan #include "../sycl/ceed-sycl-compile.hpp"
136ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen-operator-build.hpp"
146ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen.hpp"
156ca0f394SUmesh Unnikrishnan 
166ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
176ca0f394SUmesh Unnikrishnan // Destroy operator
186ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
196ca0f394SUmesh Unnikrishnan static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) {
206ca0f394SUmesh Unnikrishnan   CeedOperator_Sycl_gen *impl;
21dd64fc84SJeremy L Thompson 
226ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetData(op, &impl));
236ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedFree(&impl));
246ca0f394SUmesh Unnikrishnan   return CEED_ERROR_SUCCESS;
256ca0f394SUmesh Unnikrishnan }
266ca0f394SUmesh Unnikrishnan 
276ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
286ca0f394SUmesh Unnikrishnan // Apply and add to output
296ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
306ca0f394SUmesh Unnikrishnan static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) {
316ca0f394SUmesh Unnikrishnan   Ceed                    ceed;
326ca0f394SUmesh Unnikrishnan   Ceed_Sycl              *ceed_Sycl;
33dd64fc84SJeremy L Thompson   CeedInt                 num_elem, num_input_fields, num_output_fields;
34dd64fc84SJeremy L Thompson   CeedEvalMode            eval_mode;
35dd64fc84SJeremy L Thompson   CeedVector              output_vecs[CEED_FIELD_MAX] = {};
36dd64fc84SJeremy L Thompson   CeedQFunctionField     *qf_input_fields, *qf_output_fields;
376ca0f394SUmesh Unnikrishnan   CeedQFunction_Sycl_gen *qf_impl;
38dd64fc84SJeremy L Thompson   CeedQFunction           qf;
39dd64fc84SJeremy L Thompson   CeedOperatorField      *op_input_fields, *op_output_fields;
40dd64fc84SJeremy L Thompson   CeedOperator_Sycl_gen  *impl;
41dd64fc84SJeremy L Thompson 
42dd64fc84SJeremy L Thompson   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
43dd64fc84SJeremy L Thompson   CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
44dd64fc84SJeremy L Thompson   CeedCallBackend(CeedOperatorGetData(op, &impl));
456ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
466ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl));
476ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
486ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
496ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
506ca0f394SUmesh Unnikrishnan 
51*2d42b1dfSJeremy L Thompson   // Check for tensor-product bases
52*2d42b1dfSJeremy L Thompson   {
53*2d42b1dfSJeremy L Thompson     bool has_tensor_bases;
54*2d42b1dfSJeremy L Thompson 
55*2d42b1dfSJeremy L Thompson     CeedCallBackend(CeedOperatorHasTensorBases(op, &has_tensor_bases));
56*2d42b1dfSJeremy L Thompson     // -- Fallback to ref if not all bases are tensor-product
57*2d42b1dfSJeremy L Thompson     if (!has_tensor_bases) {
58*2d42b1dfSJeremy L Thompson       CeedOperator op_fallback;
59*2d42b1dfSJeremy L Thompson 
60*2d42b1dfSJeremy L Thompson       CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "Falling back to sycl/ref CeedOperator due to non-tensor bases");
61*2d42b1dfSJeremy L Thompson       CeedCallBackend(CeedOperatorGetFallback(op, &op_fallback));
62*2d42b1dfSJeremy L Thompson       CeedCallBackend(CeedOperatorApplyAdd(op_fallback, input_vec, output_vec, request));
63*2d42b1dfSJeremy L Thompson       return CEED_ERROR_SUCCESS;
64*2d42b1dfSJeremy L Thompson     }
65*2d42b1dfSJeremy L Thompson   }
66*2d42b1dfSJeremy L Thompson 
676ca0f394SUmesh Unnikrishnan   // Creation of the operator
686ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op));
696ca0f394SUmesh Unnikrishnan 
706ca0f394SUmesh Unnikrishnan   // Input vectors
716ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
726ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
736ca0f394SUmesh Unnikrishnan     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
746ca0f394SUmesh Unnikrishnan       impl->fields->inputs[i] = NULL;
756ca0f394SUmesh Unnikrishnan     } else {
76dd64fc84SJeremy L Thompson       CeedVector vec;
77dd64fc84SJeremy L Thompson 
786ca0f394SUmesh Unnikrishnan       // Get input vector
796ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
806ca0f394SUmesh Unnikrishnan       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
816ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i]));
826ca0f394SUmesh Unnikrishnan     }
836ca0f394SUmesh Unnikrishnan   }
846ca0f394SUmesh Unnikrishnan 
856ca0f394SUmesh Unnikrishnan   // Output vectors
866ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
876ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
886ca0f394SUmesh Unnikrishnan     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
896ca0f394SUmesh Unnikrishnan       impl->fields->outputs[i] = NULL;
906ca0f394SUmesh Unnikrishnan     } else {
91dd64fc84SJeremy L Thompson       CeedVector vec;
92dd64fc84SJeremy L Thompson 
936ca0f394SUmesh Unnikrishnan       // Get output vector
946ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
956ca0f394SUmesh Unnikrishnan       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
966ca0f394SUmesh Unnikrishnan       output_vecs[i] = vec;
976ca0f394SUmesh Unnikrishnan       // Check for multiple output modes
986ca0f394SUmesh Unnikrishnan       CeedInt index = -1;
996ca0f394SUmesh Unnikrishnan       for (CeedInt j = 0; j < i; j++) {
1006ca0f394SUmesh Unnikrishnan         if (vec == output_vecs[j]) {
1016ca0f394SUmesh Unnikrishnan           index = j;
1026ca0f394SUmesh Unnikrishnan           break;
1036ca0f394SUmesh Unnikrishnan         }
1046ca0f394SUmesh Unnikrishnan       }
1056ca0f394SUmesh Unnikrishnan       if (index == -1) {
1066ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i]));
1076ca0f394SUmesh Unnikrishnan       } else {
1086ca0f394SUmesh Unnikrishnan         impl->fields->outputs[i] = impl->fields->outputs[index];
1096ca0f394SUmesh Unnikrishnan       }
1106ca0f394SUmesh Unnikrishnan     }
1116ca0f394SUmesh Unnikrishnan   }
1126ca0f394SUmesh Unnikrishnan 
1136ca0f394SUmesh Unnikrishnan   // Get context data
1146ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c));
1156ca0f394SUmesh Unnikrishnan 
1166ca0f394SUmesh Unnikrishnan   // Apply operator
1176ca0f394SUmesh Unnikrishnan   const CeedInt dim  = impl->dim;
1186ca0f394SUmesh Unnikrishnan   const CeedInt Q_1d = impl->Q_1d;
1196ca0f394SUmesh Unnikrishnan   const CeedInt P_1d = impl->max_P_1d;
1206ca0f394SUmesh Unnikrishnan   CeedInt       block_sizes[3], grid = 0;
121dd64fc84SJeremy L Thompson 
1226ca0f394SUmesh Unnikrishnan   CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes));
1236ca0f394SUmesh Unnikrishnan   if (dim == 1) {
1246ca0f394SUmesh Unnikrishnan     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
1256ca0f394SUmesh Unnikrishnan     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
1266ca0f394SUmesh Unnikrishnan   } else if (dim == 2) {
1276ca0f394SUmesh Unnikrishnan     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
1286ca0f394SUmesh Unnikrishnan     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
1296ca0f394SUmesh Unnikrishnan   } else if (dim == 3) {
1306ca0f394SUmesh Unnikrishnan     grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0);
1316ca0f394SUmesh Unnikrishnan     // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs));
1326ca0f394SUmesh Unnikrishnan   }
1336ca0f394SUmesh Unnikrishnan 
1346ca0f394SUmesh Unnikrishnan   sycl::range<3>    local_range(block_sizes[2], block_sizes[1], block_sizes[0]);
1356ca0f394SUmesh Unnikrishnan   sycl::range<3>    global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]);
1366ca0f394SUmesh Unnikrishnan   sycl::nd_range<3> kernel_range(global_range, local_range);
1376ca0f394SUmesh Unnikrishnan 
1386ca0f394SUmesh Unnikrishnan   //-----------
1396ca0f394SUmesh Unnikrishnan   // Order queue
1406ca0f394SUmesh Unnikrishnan   sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier();
1416ca0f394SUmesh Unnikrishnan 
1426ca0f394SUmesh Unnikrishnan   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
1436ca0f394SUmesh Unnikrishnan     cgh.depends_on(e);
1446ca0f394SUmesh Unnikrishnan     cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W);
1456ca0f394SUmesh Unnikrishnan     cgh.parallel_for(kernel_range, *(impl->op));
1466ca0f394SUmesh Unnikrishnan   }));
1476ca0f394SUmesh Unnikrishnan   CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw());
1486ca0f394SUmesh Unnikrishnan 
1496ca0f394SUmesh Unnikrishnan   // Restore input arrays
1506ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_input_fields; i++) {
1516ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1526ca0f394SUmesh Unnikrishnan     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
1536ca0f394SUmesh Unnikrishnan     } else {
154dd64fc84SJeremy L Thompson       CeedVector vec;
155dd64fc84SJeremy L Thompson 
1566ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
1576ca0f394SUmesh Unnikrishnan       if (vec == CEED_VECTOR_ACTIVE) vec = input_vec;
1586ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i]));
1596ca0f394SUmesh Unnikrishnan     }
1606ca0f394SUmesh Unnikrishnan   }
1616ca0f394SUmesh Unnikrishnan 
1626ca0f394SUmesh Unnikrishnan   // Restore output arrays
1636ca0f394SUmesh Unnikrishnan   for (CeedInt i = 0; i < num_output_fields; i++) {
1646ca0f394SUmesh Unnikrishnan     CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
1656ca0f394SUmesh Unnikrishnan     if (eval_mode == CEED_EVAL_WEIGHT) {  // Skip
1666ca0f394SUmesh Unnikrishnan     } else {
167dd64fc84SJeremy L Thompson       CeedVector vec;
168dd64fc84SJeremy L Thompson 
1696ca0f394SUmesh Unnikrishnan       CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
1706ca0f394SUmesh Unnikrishnan       if (vec == CEED_VECTOR_ACTIVE) vec = output_vec;
1716ca0f394SUmesh Unnikrishnan       // Check for multiple output modes
1726ca0f394SUmesh Unnikrishnan       CeedInt index = -1;
173dd64fc84SJeremy L Thompson 
1746ca0f394SUmesh Unnikrishnan       for (CeedInt j = 0; j < i; j++) {
1756ca0f394SUmesh Unnikrishnan         if (vec == output_vecs[j]) {
1766ca0f394SUmesh Unnikrishnan           index = j;
1776ca0f394SUmesh Unnikrishnan           break;
1786ca0f394SUmesh Unnikrishnan         }
1796ca0f394SUmesh Unnikrishnan       }
1806ca0f394SUmesh Unnikrishnan       if (index == -1) {
1816ca0f394SUmesh Unnikrishnan         CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i]));
1826ca0f394SUmesh Unnikrishnan       }
1836ca0f394SUmesh Unnikrishnan     }
1846ca0f394SUmesh Unnikrishnan   }
1856ca0f394SUmesh Unnikrishnan 
1866ca0f394SUmesh Unnikrishnan   // Restore context data
1876ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c));
1886ca0f394SUmesh Unnikrishnan   return CEED_ERROR_SUCCESS;
1896ca0f394SUmesh Unnikrishnan }
1906ca0f394SUmesh Unnikrishnan 
1916ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1926ca0f394SUmesh Unnikrishnan // Create operator
1936ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
1946ca0f394SUmesh Unnikrishnan int CeedOperatorCreate_Sycl_gen(CeedOperator op) {
1956ca0f394SUmesh Unnikrishnan   Ceed                   ceed;
1966ca0f394SUmesh Unnikrishnan   Ceed_Sycl             *sycl_data;
197dd64fc84SJeremy L Thompson   CeedOperator_Sycl_gen *impl;
198dd64fc84SJeremy L Thompson 
199dd64fc84SJeremy L Thompson   CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
2006ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedGetData(ceed, &sycl_data));
2016ca0f394SUmesh Unnikrishnan 
2026ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedCalloc(1, &impl));
2036ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedOperatorSetData(op, impl));
2046ca0f394SUmesh Unnikrishnan 
2056ca0f394SUmesh Unnikrishnan   impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2066ca0f394SUmesh Unnikrishnan   impl->fields  = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context);
2076ca0f394SUmesh Unnikrishnan   impl->B       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2086ca0f394SUmesh Unnikrishnan   impl->G       = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2096ca0f394SUmesh Unnikrishnan   impl->W       = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context);
2106ca0f394SUmesh Unnikrishnan 
2116ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen));
2126ca0f394SUmesh Unnikrishnan   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen));
2136ca0f394SUmesh Unnikrishnan   return CEED_ERROR_SUCCESS;
2146ca0f394SUmesh Unnikrishnan }
2156ca0f394SUmesh Unnikrishnan 
2166ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------
217