1*6ca0f394SUmesh Unnikrishnan // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*6ca0f394SUmesh Unnikrishnan // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*6ca0f394SUmesh Unnikrishnan // 4*6ca0f394SUmesh Unnikrishnan // SPDX-License-Identifier: BSD-2-Clause 5*6ca0f394SUmesh Unnikrishnan // 6*6ca0f394SUmesh Unnikrishnan // This file is part of CEED: http://github.com/ceed 7*6ca0f394SUmesh Unnikrishnan 8*6ca0f394SUmesh Unnikrishnan #include <ceed/backend.h> 9*6ca0f394SUmesh Unnikrishnan #include <ceed/ceed.h> 10*6ca0f394SUmesh Unnikrishnan #include <stddef.h> 11*6ca0f394SUmesh Unnikrishnan 12*6ca0f394SUmesh Unnikrishnan #include "../sycl/ceed-sycl-compile.hpp" 13*6ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen-operator-build.hpp" 14*6ca0f394SUmesh Unnikrishnan #include "ceed-sycl-gen.hpp" 15*6ca0f394SUmesh Unnikrishnan 16*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 17*6ca0f394SUmesh Unnikrishnan // Destroy operator 18*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 19*6ca0f394SUmesh Unnikrishnan static int CeedOperatorDestroy_Sycl_gen(CeedOperator op) { 20*6ca0f394SUmesh Unnikrishnan CeedOperator_Sycl_gen *impl; 21*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetData(op, &impl)); 22*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedFree(&impl)); 23*6ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS; 24*6ca0f394SUmesh Unnikrishnan } 25*6ca0f394SUmesh Unnikrishnan 26*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 27*6ca0f394SUmesh Unnikrishnan // Apply and add to output 28*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 29*6ca0f394SUmesh Unnikrishnan static int CeedOperatorApplyAdd_Sycl_gen(CeedOperator op, CeedVector input_vec, CeedVector output_vec, CeedRequest *request) { 30*6ca0f394SUmesh Unnikrishnan Ceed ceed; 31*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 32*6ca0f394SUmesh Unnikrishnan Ceed_Sycl *ceed_Sycl; 33*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetData(ceed, &ceed_Sycl)); 34*6ca0f394SUmesh Unnikrishnan CeedOperator_Sycl_gen *impl; 35*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetData(op, &impl)); 36*6ca0f394SUmesh Unnikrishnan CeedQFunction qf; 37*6ca0f394SUmesh Unnikrishnan CeedQFunction_Sycl_gen *qf_impl; 38*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetQFunction(op, &qf)); 39*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetData(qf, &qf_impl)); 40*6ca0f394SUmesh Unnikrishnan CeedInt num_elem, num_input_fields, num_output_fields; 41*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem)); 42*6ca0f394SUmesh Unnikrishnan CeedOperatorField *op_input_fields, *op_output_fields; 43*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields)); 44*6ca0f394SUmesh Unnikrishnan CeedQFunctionField *qf_input_fields, *qf_output_fields; 45*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields)); 46*6ca0f394SUmesh Unnikrishnan CeedEvalMode eval_mode; 47*6ca0f394SUmesh Unnikrishnan CeedVector vec, output_vecs[CEED_FIELD_MAX] = {}; 48*6ca0f394SUmesh Unnikrishnan 49*6ca0f394SUmesh Unnikrishnan // Creation of the operator 50*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorBuildKernel_Sycl_gen(op)); 51*6ca0f394SUmesh Unnikrishnan 52*6ca0f394SUmesh Unnikrishnan // Input vectors 53*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 54*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 55*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 56*6ca0f394SUmesh Unnikrishnan impl->fields->inputs[i] = NULL; 57*6ca0f394SUmesh Unnikrishnan } else { 58*6ca0f394SUmesh Unnikrishnan // Get input vector 59*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 60*6ca0f394SUmesh Unnikrishnan if (vec == CEED_VECTOR_ACTIVE) vec = input_vec; 61*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &impl->fields->inputs[i])); 62*6ca0f394SUmesh Unnikrishnan } 63*6ca0f394SUmesh Unnikrishnan } 64*6ca0f394SUmesh Unnikrishnan 65*6ca0f394SUmesh Unnikrishnan // Output vectors 66*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 67*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 68*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 69*6ca0f394SUmesh Unnikrishnan impl->fields->outputs[i] = NULL; 70*6ca0f394SUmesh Unnikrishnan } else { 71*6ca0f394SUmesh Unnikrishnan // Get output vector 72*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 73*6ca0f394SUmesh Unnikrishnan if (vec == CEED_VECTOR_ACTIVE) vec = output_vec; 74*6ca0f394SUmesh Unnikrishnan output_vecs[i] = vec; 75*6ca0f394SUmesh Unnikrishnan // Check for multiple output modes 76*6ca0f394SUmesh Unnikrishnan CeedInt index = -1; 77*6ca0f394SUmesh Unnikrishnan for (CeedInt j = 0; j < i; j++) { 78*6ca0f394SUmesh Unnikrishnan if (vec == output_vecs[j]) { 79*6ca0f394SUmesh Unnikrishnan index = j; 80*6ca0f394SUmesh Unnikrishnan break; 81*6ca0f394SUmesh Unnikrishnan } 82*6ca0f394SUmesh Unnikrishnan } 83*6ca0f394SUmesh Unnikrishnan if (index == -1) { 84*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorGetArray(vec, CEED_MEM_DEVICE, &impl->fields->outputs[i])); 85*6ca0f394SUmesh Unnikrishnan } else { 86*6ca0f394SUmesh Unnikrishnan impl->fields->outputs[i] = impl->fields->outputs[index]; 87*6ca0f394SUmesh Unnikrishnan } 88*6ca0f394SUmesh Unnikrishnan } 89*6ca0f394SUmesh Unnikrishnan } 90*6ca0f394SUmesh Unnikrishnan 91*6ca0f394SUmesh Unnikrishnan // Get context data 92*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &qf_impl->d_c)); 93*6ca0f394SUmesh Unnikrishnan 94*6ca0f394SUmesh Unnikrishnan // Apply operator 95*6ca0f394SUmesh Unnikrishnan const CeedInt dim = impl->dim; 96*6ca0f394SUmesh Unnikrishnan const CeedInt Q_1d = impl->Q_1d; 97*6ca0f394SUmesh Unnikrishnan const CeedInt P_1d = impl->max_P_1d; 98*6ca0f394SUmesh Unnikrishnan CeedInt block_sizes[3], grid = 0; 99*6ca0f394SUmesh Unnikrishnan CeedCallBackend(BlockGridCalculate_Sycl_gen(dim, P_1d, Q_1d, block_sizes)); 100*6ca0f394SUmesh Unnikrishnan if (dim == 1) { 101*6ca0f394SUmesh Unnikrishnan grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 102*6ca0f394SUmesh Unnikrishnan // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs)); 103*6ca0f394SUmesh Unnikrishnan } else if (dim == 2) { 104*6ca0f394SUmesh Unnikrishnan grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 105*6ca0f394SUmesh Unnikrishnan // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs)); 106*6ca0f394SUmesh Unnikrishnan } else if (dim == 3) { 107*6ca0f394SUmesh Unnikrishnan grid = num_elem / block_sizes[2] + ((num_elem / block_sizes[2] * block_sizes[2] < num_elem) ? 1 : 0); 108*6ca0f394SUmesh Unnikrishnan // CeedCallBackend(CeedRunKernelDimSharedSycl(ceed, impl->op, grid, block_sizes[0], block_sizes[1], block_sizes[2], sharedMem, opargs)); 109*6ca0f394SUmesh Unnikrishnan } 110*6ca0f394SUmesh Unnikrishnan 111*6ca0f394SUmesh Unnikrishnan sycl::range<3> local_range(block_sizes[2], block_sizes[1], block_sizes[0]); 112*6ca0f394SUmesh Unnikrishnan sycl::range<3> global_range(grid * block_sizes[2], block_sizes[1], block_sizes[0]); 113*6ca0f394SUmesh Unnikrishnan sycl::nd_range<3> kernel_range(global_range, local_range); 114*6ca0f394SUmesh Unnikrishnan 115*6ca0f394SUmesh Unnikrishnan //----------- 116*6ca0f394SUmesh Unnikrishnan // Order queue 117*6ca0f394SUmesh Unnikrishnan sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier(); 118*6ca0f394SUmesh Unnikrishnan 119*6ca0f394SUmesh Unnikrishnan CeedCallSycl(ceed, ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 120*6ca0f394SUmesh Unnikrishnan cgh.depends_on(e); 121*6ca0f394SUmesh Unnikrishnan cgh.set_args(num_elem, qf_impl->d_c, impl->indices, impl->fields, impl->B, impl->G, impl->W); 122*6ca0f394SUmesh Unnikrishnan cgh.parallel_for(kernel_range, *(impl->op)); 123*6ca0f394SUmesh Unnikrishnan })); 124*6ca0f394SUmesh Unnikrishnan CeedCallSycl(ceed, ceed_Sycl->sycl_queue.wait_and_throw()); 125*6ca0f394SUmesh Unnikrishnan 126*6ca0f394SUmesh Unnikrishnan // Restore input arrays 127*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_input_fields; i++) { 128*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode)); 129*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 130*6ca0f394SUmesh Unnikrishnan } else { 131*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec)); 132*6ca0f394SUmesh Unnikrishnan if (vec == CEED_VECTOR_ACTIVE) vec = input_vec; 133*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorRestoreArrayRead(vec, &impl->fields->inputs[i])); 134*6ca0f394SUmesh Unnikrishnan } 135*6ca0f394SUmesh Unnikrishnan } 136*6ca0f394SUmesh Unnikrishnan 137*6ca0f394SUmesh Unnikrishnan // Restore output arrays 138*6ca0f394SUmesh Unnikrishnan for (CeedInt i = 0; i < num_output_fields; i++) { 139*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode)); 140*6ca0f394SUmesh Unnikrishnan if (eval_mode == CEED_EVAL_WEIGHT) { // Skip 141*6ca0f394SUmesh Unnikrishnan } else { 142*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec)); 143*6ca0f394SUmesh Unnikrishnan if (vec == CEED_VECTOR_ACTIVE) vec = output_vec; 144*6ca0f394SUmesh Unnikrishnan // Check for multiple output modes 145*6ca0f394SUmesh Unnikrishnan CeedInt index = -1; 146*6ca0f394SUmesh Unnikrishnan for (CeedInt j = 0; j < i; j++) { 147*6ca0f394SUmesh Unnikrishnan if (vec == output_vecs[j]) { 148*6ca0f394SUmesh Unnikrishnan index = j; 149*6ca0f394SUmesh Unnikrishnan break; 150*6ca0f394SUmesh Unnikrishnan } 151*6ca0f394SUmesh Unnikrishnan } 152*6ca0f394SUmesh Unnikrishnan if (index == -1) { 153*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedVectorRestoreArray(vec, &impl->fields->outputs[i])); 154*6ca0f394SUmesh Unnikrishnan } 155*6ca0f394SUmesh Unnikrishnan } 156*6ca0f394SUmesh Unnikrishnan } 157*6ca0f394SUmesh Unnikrishnan 158*6ca0f394SUmesh Unnikrishnan // Restore context data 159*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &qf_impl->d_c)); 160*6ca0f394SUmesh Unnikrishnan 161*6ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS; 162*6ca0f394SUmesh Unnikrishnan } 163*6ca0f394SUmesh Unnikrishnan 164*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 165*6ca0f394SUmesh Unnikrishnan // Create operator 166*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 167*6ca0f394SUmesh Unnikrishnan int CeedOperatorCreate_Sycl_gen(CeedOperator op) { 168*6ca0f394SUmesh Unnikrishnan Ceed ceed; 169*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorGetCeed(op, &ceed)); 170*6ca0f394SUmesh Unnikrishnan Ceed_Sycl *sycl_data; 171*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedGetData(ceed, &sycl_data)); 172*6ca0f394SUmesh Unnikrishnan 173*6ca0f394SUmesh Unnikrishnan CeedOperator_Sycl_gen *impl; 174*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedCalloc(1, &impl)); 175*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedOperatorSetData(op, impl)); 176*6ca0f394SUmesh Unnikrishnan 177*6ca0f394SUmesh Unnikrishnan impl->indices = sycl::malloc_device<FieldsInt_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context); 178*6ca0f394SUmesh Unnikrishnan impl->fields = sycl::malloc_host<Fields_Sycl>(1, sycl_data->sycl_context); 179*6ca0f394SUmesh Unnikrishnan impl->B = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context); 180*6ca0f394SUmesh Unnikrishnan impl->G = sycl::malloc_device<Fields_Sycl>(1, sycl_data->sycl_device, sycl_data->sycl_context); 181*6ca0f394SUmesh Unnikrishnan impl->W = sycl::malloc_device<CeedScalar>(1, sycl_data->sycl_device, sycl_data->sycl_context); 182*6ca0f394SUmesh Unnikrishnan 183*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Sycl_gen)); 184*6ca0f394SUmesh Unnikrishnan CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Sycl_gen)); 185*6ca0f394SUmesh Unnikrishnan return CEED_ERROR_SUCCESS; 186*6ca0f394SUmesh Unnikrishnan } 187*6ca0f394SUmesh Unnikrishnan 188*6ca0f394SUmesh Unnikrishnan //------------------------------------------------------------------------------ 189