1*bd882c8aSJames Wright // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*bd882c8aSJames Wright // 4*bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5*bd882c8aSJames Wright // 6*bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7*bd882c8aSJames Wright 8*bd882c8aSJames Wright #include <ceed/backend.h> 9*bd882c8aSJames Wright #include <ceed/ceed.h> 10*bd882c8aSJames Wright #include <ceed/jit-tools.h> 11*bd882c8aSJames Wright 12*bd882c8aSJames Wright #include <map> 13*bd882c8aSJames Wright #include <string_view> 14*bd882c8aSJames Wright #include <sycl/sycl.hpp> 15*bd882c8aSJames Wright 16*bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17*bd882c8aSJames Wright #include "ceed-sycl-shared.hpp" 18*bd882c8aSJames Wright 19*bd882c8aSJames Wright //------------------------------------------------------------------------------ 20*bd882c8aSJames Wright // Compute the local range of for basis kernels 21*bd882c8aSJames Wright //------------------------------------------------------------------------------ 22*bd882c8aSJames Wright static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 256) { 23*bd882c8aSJames Wright local_range[0] = thread_1d; 24*bd882c8aSJames Wright local_range[1] = (dim > 1) ? thread_1d : 1; 25*bd882c8aSJames Wright 26*bd882c8aSJames Wright const CeedInt min_group_size = local_range[0] * local_range[1]; 27*bd882c8aSJames Wright CeedCheck(min_group_size <= max_group_size, ceed, CEED_ERROR_BACKEND, "Requested group size is smaller than the required minimum."); 28*bd882c8aSJames Wright 29*bd882c8aSJames Wright local_range[2] = max_group_size / min_group_size; // elements per group 30*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 31*bd882c8aSJames Wright } 32*bd882c8aSJames Wright 33*bd882c8aSJames Wright //------------------------------------------------------------------------------ 34*bd882c8aSJames Wright // Apply basis 35*bd882c8aSJames Wright //------------------------------------------------------------------------------ 36*bd882c8aSJames Wright int CeedBasisApplyTensor_Sycl_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, 37*bd882c8aSJames Wright CeedVector v) { 38*bd882c8aSJames Wright Ceed ceed; 39*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 40*bd882c8aSJames Wright Ceed_Sycl *ceed_Sycl; 41*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &ceed_Sycl)); 42*bd882c8aSJames Wright CeedBasis_Sycl_shared *impl; 43*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 44*bd882c8aSJames Wright 45*bd882c8aSJames Wright // Read vectors 46*bd882c8aSJames Wright const CeedScalar *d_u; 47*bd882c8aSJames Wright CeedScalar *d_v; 48*bd882c8aSJames Wright if (eval_mode != CEED_EVAL_WEIGHT) { 49*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 50*bd882c8aSJames Wright } 51*bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 52*bd882c8aSJames Wright 53*bd882c8aSJames Wright // Apply basis operation 54*bd882c8aSJames Wright switch (eval_mode) { 55*bd882c8aSJames Wright case CEED_EVAL_INTERP: { 56*bd882c8aSJames Wright CeedInt *lrange = impl->interp_local_range; 57*bd882c8aSJames Wright const CeedInt &elem_per_group = lrange[2]; 58*bd882c8aSJames Wright const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group); 59*bd882c8aSJames Wright //----------- 60*bd882c8aSJames Wright sycl::range<3> local_range(lrange[2], lrange[1], lrange[0]); 61*bd882c8aSJames Wright sycl::range<3> global_range(group_count * lrange[2], lrange[1], lrange[0]); 62*bd882c8aSJames Wright sycl::nd_range<3> kernel_range(global_range, local_range); 63*bd882c8aSJames Wright //----------- 64*bd882c8aSJames Wright sycl::kernel *interp_kernel = (t_mode == CEED_TRANSPOSE) ? impl->interp_transpose_kernel : impl->interp_kernel; 65*bd882c8aSJames Wright 66*bd882c8aSJames Wright // Order queue 67*bd882c8aSJames Wright sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier(); 68*bd882c8aSJames Wright 69*bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 70*bd882c8aSJames Wright cgh.depends_on(e); 71*bd882c8aSJames Wright cgh.set_args(num_elem, impl->d_interp_1d, d_u, d_v); 72*bd882c8aSJames Wright cgh.parallel_for(kernel_range, *interp_kernel); 73*bd882c8aSJames Wright }); 74*bd882c8aSJames Wright 75*bd882c8aSJames Wright } break; 76*bd882c8aSJames Wright case CEED_EVAL_GRAD: { 77*bd882c8aSJames Wright CeedInt *lrange = impl->grad_local_range; 78*bd882c8aSJames Wright const CeedInt &elem_per_group = lrange[2]; 79*bd882c8aSJames Wright const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group); 80*bd882c8aSJames Wright //----------- 81*bd882c8aSJames Wright sycl::range<3> local_range(lrange[2], lrange[1], lrange[0]); 82*bd882c8aSJames Wright sycl::range<3> global_range(group_count * lrange[2], lrange[1], lrange[0]); 83*bd882c8aSJames Wright sycl::nd_range<3> kernel_range(global_range, local_range); 84*bd882c8aSJames Wright //----------- 85*bd882c8aSJames Wright sycl::kernel *grad_kernel = (t_mode == CEED_TRANSPOSE) ? impl->grad_transpose_kernel : impl->grad_kernel; 86*bd882c8aSJames Wright const CeedScalar *d_grad_1d = (impl->d_collo_grad_1d) ? impl->d_collo_grad_1d : impl->d_grad_1d; 87*bd882c8aSJames Wright // Order queue 88*bd882c8aSJames Wright sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier(); 89*bd882c8aSJames Wright 90*bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 91*bd882c8aSJames Wright cgh.depends_on(e); 92*bd882c8aSJames Wright cgh.set_args(num_elem, impl->d_interp_1d, d_grad_1d, d_u, d_v); 93*bd882c8aSJames Wright cgh.parallel_for(kernel_range, *grad_kernel); 94*bd882c8aSJames Wright }); 95*bd882c8aSJames Wright } break; 96*bd882c8aSJames Wright case CEED_EVAL_WEIGHT: { 97*bd882c8aSJames Wright CeedInt *lrange = impl->weight_local_range; 98*bd882c8aSJames Wright const CeedInt &elem_per_group = lrange[2]; 99*bd882c8aSJames Wright const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group); 100*bd882c8aSJames Wright //----------- 101*bd882c8aSJames Wright sycl::range<3> local_range(lrange[2], lrange[1], lrange[0]); 102*bd882c8aSJames Wright sycl::range<3> global_range(group_count * lrange[2], lrange[1], lrange[0]); 103*bd882c8aSJames Wright sycl::nd_range<3> kernel_range(global_range, local_range); 104*bd882c8aSJames Wright //----------- 105*bd882c8aSJames Wright // Order queue 106*bd882c8aSJames Wright sycl::event e = ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier(); 107*bd882c8aSJames Wright 108*bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 109*bd882c8aSJames Wright cgh.depends_on(e); 110*bd882c8aSJames Wright cgh.set_args(num_elem, impl->d_q_weight_1d, d_v); 111*bd882c8aSJames Wright cgh.parallel_for(kernel_range, *(impl->weight_kernel)); 112*bd882c8aSJames Wright }); 113*bd882c8aSJames Wright } break; 114*bd882c8aSJames Wright // LCOV_EXCL_START 115*bd882c8aSJames Wright // Evaluate the divergence to/from the quadrature points 116*bd882c8aSJames Wright case CEED_EVAL_DIV: 117*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 118*bd882c8aSJames Wright // Evaluate the curl to/from the quadrature points 119*bd882c8aSJames Wright case CEED_EVAL_CURL: 120*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 121*bd882c8aSJames Wright // Take no action, BasisApply should not have been called 122*bd882c8aSJames Wright case CEED_EVAL_NONE: 123*bd882c8aSJames Wright return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context"); 124*bd882c8aSJames Wright // LCOV_EXCL_STOP 125*bd882c8aSJames Wright } 126*bd882c8aSJames Wright 127*bd882c8aSJames Wright // Restore vectors 128*bd882c8aSJames Wright if (eval_mode != CEED_EVAL_WEIGHT) { 129*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 130*bd882c8aSJames Wright } 131*bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 132*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 133*bd882c8aSJames Wright } 134*bd882c8aSJames Wright 135*bd882c8aSJames Wright //------------------------------------------------------------------------------ 136*bd882c8aSJames Wright // Destroy basis 137*bd882c8aSJames Wright //------------------------------------------------------------------------------ 138*bd882c8aSJames Wright static int CeedBasisDestroy_Sycl_shared(CeedBasis basis) { 139*bd882c8aSJames Wright Ceed ceed; 140*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 141*bd882c8aSJames Wright CeedBasis_Sycl_shared *impl; 142*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 143*bd882c8aSJames Wright Ceed_Sycl *data; 144*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 145*bd882c8aSJames Wright 146*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 147*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context)); 148*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context)); 149*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context)); 150*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_collo_grad_1d, data->sycl_context)); 151*bd882c8aSJames Wright 152*bd882c8aSJames Wright delete impl->interp_kernel; 153*bd882c8aSJames Wright delete impl->interp_transpose_kernel; 154*bd882c8aSJames Wright delete impl->grad_kernel; 155*bd882c8aSJames Wright delete impl->grad_transpose_kernel; 156*bd882c8aSJames Wright delete impl->weight_kernel; 157*bd882c8aSJames Wright delete impl->sycl_module; 158*bd882c8aSJames Wright 159*bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 160*bd882c8aSJames Wright 161*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 162*bd882c8aSJames Wright } 163*bd882c8aSJames Wright 164*bd882c8aSJames Wright //------------------------------------------------------------------------------ 165*bd882c8aSJames Wright // Create tensor basis 166*bd882c8aSJames Wright // TODO: Refactor 167*bd882c8aSJames Wright //------------------------------------------------------------------------------ 168*bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d, 169*bd882c8aSJames Wright const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) { 170*bd882c8aSJames Wright Ceed ceed; 171*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 172*bd882c8aSJames Wright CeedBasis_Sycl_shared *impl; 173*bd882c8aSJames Wright CeedCallBackend(CeedCalloc(1, &impl)); 174*bd882c8aSJames Wright Ceed_Sycl *data; 175*bd882c8aSJames Wright CeedCallBackend(CeedGetData(ceed, &data)); 176*bd882c8aSJames Wright 177*bd882c8aSJames Wright CeedInt num_comp; 178*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 179*bd882c8aSJames Wright 180*bd882c8aSJames Wright const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d); 181*bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 182*bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 183*bd882c8aSJames Wright 184*bd882c8aSJames Wright CeedInt *interp_lrange = impl->interp_local_range; 185*bd882c8aSJames Wright CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, interp_lrange)); 186*bd882c8aSJames Wright const CeedInt interp_group_size = interp_lrange[0] * interp_lrange[1] * interp_lrange[2]; 187*bd882c8aSJames Wright 188*bd882c8aSJames Wright CeedInt *grad_lrange = impl->grad_local_range; 189*bd882c8aSJames Wright CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, grad_lrange)); 190*bd882c8aSJames Wright const CeedInt grad_group_size = grad_lrange[0] * grad_lrange[1] * grad_lrange[2]; 191*bd882c8aSJames Wright 192*bd882c8aSJames Wright CeedCallBackend(ComputeLocalRange(ceed, dim, Q_1d, impl->weight_local_range)); 193*bd882c8aSJames Wright 194*bd882c8aSJames Wright // Copy basis data to GPU 195*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context)); 196*bd882c8aSJames Wright sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d); 197*bd882c8aSJames Wright 198*bd882c8aSJames Wright const CeedInt interp_length = Q_1d * P_1d; 199*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 200*bd882c8aSJames Wright sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length); 201*bd882c8aSJames Wright 202*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 203*bd882c8aSJames Wright sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length); 204*bd882c8aSJames Wright 205*bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad})); 206*bd882c8aSJames Wright 207*bd882c8aSJames Wright // Compute collocated gradient and copy to GPU 208*bd882c8aSJames Wright impl->d_collo_grad_1d = NULL; 209*bd882c8aSJames Wright const bool has_collocated_grad = (dim == 3) && (Q_1d >= P_1d); 210*bd882c8aSJames Wright if (has_collocated_grad) { 211*bd882c8aSJames Wright CeedScalar *collo_grad_1d; 212*bd882c8aSJames Wright CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d)); 213*bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d)); 214*bd882c8aSJames Wright const CeedInt cgrad_length = Q_1d * Q_1d; 215*bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_collo_grad_1d = sycl::malloc_device<CeedScalar>(cgrad_length, data->sycl_device, data->sycl_context)); 216*bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(collo_grad_1d, impl->d_collo_grad_1d, cgrad_length).wait_and_throw()); 217*bd882c8aSJames Wright CeedCallBackend(CeedFree(&collo_grad_1d)); 218*bd882c8aSJames Wright } 219*bd882c8aSJames Wright 220*bd882c8aSJames Wright // ---[Refactor into separate function]------> 221*bd882c8aSJames Wright // Define compile-time constants 222*bd882c8aSJames Wright std::map<std::string, CeedInt> jit_constants; 223*bd882c8aSJames Wright jit_constants["BASIS_DIM"] = dim; 224*bd882c8aSJames Wright jit_constants["BASIS_Q_1D"] = Q_1d; 225*bd882c8aSJames Wright jit_constants["BASIS_P_1D"] = P_1d; 226*bd882c8aSJames Wright jit_constants["T_1D"] = thread_1d; 227*bd882c8aSJames Wright jit_constants["BASIS_NUM_COMP"] = num_comp; 228*bd882c8aSJames Wright jit_constants["BASIS_NUM_NODES"] = num_nodes; 229*bd882c8aSJames Wright jit_constants["BASIS_NUM_QPTS"] = num_qpts; 230*bd882c8aSJames Wright jit_constants["BASIS_HAS_COLLOCATED_GRAD"] = has_collocated_grad; 231*bd882c8aSJames Wright jit_constants["BASIS_INTERP_SCRATCH_SIZE"] = interp_group_size; 232*bd882c8aSJames Wright jit_constants["BASIS_GRAD_SCRATCH_SIZE"] = grad_group_size; 233*bd882c8aSJames Wright 234*bd882c8aSJames Wright // Load kernel source 235*bd882c8aSJames Wright char *basis_kernel_path, *basis_kernel_source; 236*bd882c8aSJames Wright CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor.h", &basis_kernel_path)); 237*bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source -----\n"); 238*bd882c8aSJames Wright CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &basis_kernel_source)); 239*bd882c8aSJames Wright CeedDebug256(ceed, 2, "----- Loading Basis Kernel Source Complete -----\n"); 240*bd882c8aSJames Wright 241*bd882c8aSJames Wright // Compile kernels into a kernel bundle 242*bd882c8aSJames Wright CeedCallBackend(CeedJitBuildModule_Sycl(ceed, basis_kernel_source, &impl->sycl_module, jit_constants)); 243*bd882c8aSJames Wright 244*bd882c8aSJames Wright // Load kernel functions 245*bd882c8aSJames Wright CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "Interp", &impl->interp_kernel)); 246*bd882c8aSJames Wright CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "InterpTranspose", &impl->interp_transpose_kernel)); 247*bd882c8aSJames Wright CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "Grad", &impl->grad_kernel)); 248*bd882c8aSJames Wright CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "GradTranspose", &impl->grad_transpose_kernel)); 249*bd882c8aSJames Wright CeedCallBackend(CeedJitGetKernel_Sycl(ceed, impl->sycl_module, "Weight", &impl->weight_kernel)); 250*bd882c8aSJames Wright 251*bd882c8aSJames Wright // Clean-up 252*bd882c8aSJames Wright CeedCallBackend(CeedFree(&basis_kernel_path)); 253*bd882c8aSJames Wright CeedCallBackend(CeedFree(&basis_kernel_source)); 254*bd882c8aSJames Wright // <---[Refactor into separate function]------ 255*bd882c8aSJames Wright 256*bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl)); 257*bd882c8aSJames Wright 258*bd882c8aSJames Wright // Register backend functions 259*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Sycl_shared)); 260*bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl_shared)); 261*bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 262*bd882c8aSJames Wright } 263*bd882c8aSJames Wright //------------------------------------------------------------------------------ 264