1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other 2*bd882c8aSJames Wright // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE 3*bd882c8aSJames Wright // files for details. 4*bd882c8aSJames Wright // 5*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 6*bd882c8aSJames Wright // 7*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 8*bd882c8aSJames Wright 9*bd882c8aSJames Wright #include <ceed/backend.h> 10*bd882c8aSJames Wright #include <ceed/ceed.h> 11*bd882c8aSJames Wright 12*bd882c8aSJames Wright #include <string> 13*bd882c8aSJames Wright #include <sycl/sycl.hpp> 14*bd882c8aSJames Wright #include <vector> 15*bd882c8aSJames Wright 16*bd882c8aSJames Wright #include "../sycl/ceed-sycl-common.hpp" 17*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 18*bd882c8aSJames Wright #include "ceed-sycl-ref-qfunction-load.hpp" 19*bd882c8aSJames Wright #include "ceed-sycl-ref.hpp" 20*bd882c8aSJames Wright 21*bd882c8aSJames Wright #define WG_SIZE_QF 384 22*bd882c8aSJames Wright 23*bd882c8aSJames Wright //------------------------------------------------------------------------------ 24*bd882c8aSJames Wright // Apply QFunction 25*bd882c8aSJames Wright //------------------------------------------------------------------------------ 26*bd882c8aSJames Wright static int CeedQFunctionApply_Sycl(CeedQFunction qf, CeedInt Q, CeedVector *U, CeedVector *V) { 27*bd882c8aSJames Wright CeedQFunction_Sycl *impl; 28*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetData(qf, &impl)); 29*bd882c8aSJames Wright 30*bd882c8aSJames Wright // Build and compile kernel, if not done 31*bd882c8aSJames Wright if (!impl->QFunction) CeedCallBackend(CeedBuildQFunction_Sycl(qf)); 32*bd882c8aSJames Wright 33*bd882c8aSJames Wright Ceed ceed; 34*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 35*bd882c8aSJames Wright Ceed_Sycl *ceed_Sycl; 36*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &ceed_Sycl)); 37*bd882c8aSJames Wright 38*bd882c8aSJames Wright CeedInt num_input_fields, num_output_fields; 39*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetNumArgs(qf, &num_input_fields, &num_output_fields)); 40*bd882c8aSJames Wright 41*bd882c8aSJames Wright // Read vectors 42*bd882c8aSJames Wright std::vector<const CeedScalar *> inputs(num_input_fields); 43*bd882c8aSJames Wright const CeedVector *U_i = U; 44*bd882c8aSJames Wright for (auto &input_i : inputs) { 45*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(*U_i, CEED_MEM_DEVICE, &input_i)); 46*bd882c8aSJames Wright ++U_i; 47*bd882c8aSJames Wright } 48*bd882c8aSJames Wright 49*bd882c8aSJames Wright std::vector<CeedScalar *> outputs(num_output_fields); 50*bd882c8aSJames Wright CeedVector *V_i = V; 51*bd882c8aSJames Wright for (auto &output_i : outputs) { 52*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(*V_i, CEED_MEM_DEVICE, &output_i)); 53*bd882c8aSJames Wright ++V_i; 54*bd882c8aSJames Wright } 55*bd882c8aSJames Wright 56*bd882c8aSJames Wright // Get context data 57*bd882c8aSJames Wright void *context_data; 58*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetInnerContextData(qf, CEED_MEM_DEVICE, &context_data)); 59*bd882c8aSJames Wright 60*bd882c8aSJames Wright // Order queue 61*bd882c8aSJames Wright sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier(); 62*bd882c8aSJames Wright 63*bd882c8aSJames Wright // Launch as a basic parallel_for over Q quadrature points 64*bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 65*bd882c8aSJames Wright cgh.depends_on({e}); 66*bd882c8aSJames Wright 67*bd882c8aSJames Wright int iarg{}; 68*bd882c8aSJames Wright cgh.set_arg(iarg, context_data); 69*bd882c8aSJames Wright ++iarg; 70*bd882c8aSJames Wright cgh.set_arg(iarg, Q); 71*bd882c8aSJames Wright ++iarg; 72*bd882c8aSJames Wright for (auto &input_i : inputs) { 73*bd882c8aSJames Wright cgh.set_arg(iarg, input_i); 74*bd882c8aSJames Wright ++iarg; 75*bd882c8aSJames Wright } 76*bd882c8aSJames Wright for (auto &output_i : outputs) { 77*bd882c8aSJames Wright cgh.set_arg(iarg, output_i); 78*bd882c8aSJames Wright ++iarg; 79*bd882c8aSJames Wright } 80*bd882c8aSJames Wright // Hard-coding the work-group size for now 81*bd882c8aSJames Wright // We could use the Level Zero API to query and set an appropriate size in future 82*bd882c8aSJames Wright // Equivalent of CUDA Occupancy Calculator 83*bd882c8aSJames Wright int wg_size = WG_SIZE_QF; 84*bd882c8aSJames Wright sycl::range<1> rounded_Q = ((Q + (wg_size - 1)) / wg_size) * wg_size; 85*bd882c8aSJames Wright sycl::nd_range<1> kernel_range(rounded_Q, wg_size); 86*bd882c8aSJames Wright cgh.parallel_for(kernel_range, *(impl->QFunction)); 87*bd882c8aSJames Wright }); 88*bd882c8aSJames Wright 89*bd882c8aSJames Wright // Restore vectors 90*bd882c8aSJames Wright U_i = U; 91*bd882c8aSJames Wright for (auto &input_i : inputs) { 92*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(*U_i, &input_i)); 93*bd882c8aSJames Wright ++U_i; 94*bd882c8aSJames Wright } 95*bd882c8aSJames Wright 96*bd882c8aSJames Wright V_i = V; 97*bd882c8aSJames Wright for (auto &output_i : outputs) { 98*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(*V_i, &output_i)); 99*bd882c8aSJames Wright ++V_i; 100*bd882c8aSJames Wright } 101*bd882c8aSJames Wright 102*bd882c8aSJames Wright // Restore context 103*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionRestoreInnerContextData(qf, &context_data)); 104*bd882c8aSJames Wright 105*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 106*bd882c8aSJames Wright } 107*bd882c8aSJames Wright 108*bd882c8aSJames Wright //------------------------------------------------------------------------------ 109*bd882c8aSJames Wright // Destroy QFunction 110*bd882c8aSJames Wright //------------------------------------------------------------------------------ 111*bd882c8aSJames Wright static int CeedQFunctionDestroy_Sycl(CeedQFunction qf) { 112*bd882c8aSJames Wright CeedQFunction_Sycl *impl; 113*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetData(qf, &impl)); 114*bd882c8aSJames Wright 115*bd882c8aSJames Wright Ceed ceed; 116*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionGetCeed(qf, &ceed)); 117*bd882c8aSJames Wright 118*bd882c8aSJames Wright delete impl->QFunction; 119*bd882c8aSJames Wright delete impl->sycl_module; 120*bd882c8aSJames Wright 121*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 122*bd882c8aSJames Wright 123*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 124*bd882c8aSJames Wright } 125*bd882c8aSJames Wright 126*bd882c8aSJames Wright //------------------------------------------------------------------------------ 127*bd882c8aSJames Wright // Create QFunction 128*bd882c8aSJames Wright //------------------------------------------------------------------------------ 129*bd882c8aSJames Wright int CeedQFunctionCreate_Sycl(CeedQFunction qf) { 130*bd882c8aSJames Wright Ceed ceed; 131*bd882c8aSJames Wright CeedQFunctionGetCeed(qf, &ceed); 132*bd882c8aSJames Wright CeedQFunction_Sycl *impl; 133*bd882c8aSJames Wright 134*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 135*bd882c8aSJames Wright CeedCallBackend(CeedQFunctionSetData(qf, impl)); 136*bd882c8aSJames Wright 137*bd882c8aSJames Wright // Register backend functions 138*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Apply", CeedQFunctionApply_Sycl)); 139*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "QFunction", qf, "Destroy", CeedQFunctionDestroy_Sycl)); 140*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 141*bd882c8aSJames Wright } 142*bd882c8aSJames Wright //------------------------------------------------------------------------------ 143