15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright #include <ceed/backend.h> 9bd882c8aSJames Wright #include <ceed/ceed.h> 10bd882c8aSJames Wright #include <ceed/jit-tools.h> 11bd882c8aSJames Wright 12bd882c8aSJames Wright #include <map> 13bd882c8aSJames Wright #include <string_view> 14bd882c8aSJames Wright #include <sycl/sycl.hpp> 15bd882c8aSJames Wright 16bd882c8aSJames Wright #include "../sycl/ceed-sycl-compile.hpp" 17bd882c8aSJames Wright #include "ceed-sycl-shared.hpp" 18bd882c8aSJames Wright 19bd882c8aSJames Wright //------------------------------------------------------------------------------ 20bd882c8aSJames Wright // Compute the local range of for basis kernels 21bd882c8aSJames Wright //------------------------------------------------------------------------------ 22bd882c8aSJames Wright static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 256) { 23bd882c8aSJames Wright local_range[0] = thread_1d; 24bd882c8aSJames Wright local_range[1] = (dim > 1) ? thread_1d : 1; 25bd882c8aSJames Wright const CeedInt min_group_size = local_range[0] * local_range[1]; 26dd64fc84SJeremy L Thompson 27bd882c8aSJames Wright CeedCheck(min_group_size <= max_group_size, ceed, CEED_ERROR_BACKEND, "Requested group size is smaller than the required minimum."); 28bd882c8aSJames Wright 29bd882c8aSJames Wright local_range[2] = max_group_size / min_group_size; // elements per group 30bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 31bd882c8aSJames Wright } 32bd882c8aSJames Wright 33bd882c8aSJames Wright //------------------------------------------------------------------------------ 34bd882c8aSJames Wright // Apply basis 35bd882c8aSJames Wright //------------------------------------------------------------------------------ 36bd882c8aSJames Wright int CeedBasisApplyTensor_Sycl_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u, 37bd882c8aSJames Wright CeedVector v) { 38bd882c8aSJames Wright Ceed ceed; 39bd882c8aSJames Wright Ceed_Sycl *ceed_Sycl; 40dd64fc84SJeremy L Thompson const CeedScalar *d_u; 41dd64fc84SJeremy L Thompson CeedScalar *d_v; 42bd882c8aSJames Wright CeedBasis_Sycl_shared *impl; 43dd64fc84SJeremy L Thompson 44dd64fc84SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 45dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &ceed_Sycl)); 46bd882c8aSJames Wright CeedCallBackend(CeedBasisGetData(basis, &impl)); 47bd882c8aSJames Wright 480ae60fd3SJeremy L Thompson // Get read/write access to u, v 490ae60fd3SJeremy L Thompson if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 500ae60fd3SJeremy L Thompson else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); 51bd882c8aSJames Wright CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 52bd882c8aSJames Wright 53bd882c8aSJames Wright // Apply basis operation 54bd882c8aSJames Wright switch (eval_mode) { 55bd882c8aSJames Wright case CEED_EVAL_INTERP: { 56bd882c8aSJames Wright CeedInt *lrange = impl->interp_local_range; 57bd882c8aSJames Wright const CeedInt &elem_per_group = lrange[2]; 58bd882c8aSJames Wright const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group); 59bd882c8aSJames Wright //----------- 60bd882c8aSJames Wright sycl::range<3> local_range(lrange[2], lrange[1], lrange[0]); 61bd882c8aSJames Wright sycl::range<3> global_range(group_count * lrange[2], lrange[1], lrange[0]); 62bd882c8aSJames Wright sycl::nd_range<3> kernel_range(global_range, local_range); 63bd882c8aSJames Wright //----------- 64bd882c8aSJames Wright sycl::kernel *interp_kernel = (t_mode == CEED_TRANSPOSE) ? impl->interp_transpose_kernel : impl->interp_kernel; 65bd882c8aSJames Wright 66*1f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 67bd882c8aSJames Wright 68*1f4b1b45SUmesh Unnikrishnan if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()}; 69bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 70bd882c8aSJames Wright cgh.depends_on(e); 71bd882c8aSJames Wright cgh.set_args(num_elem, impl->d_interp_1d, d_u, d_v); 72bd882c8aSJames Wright cgh.parallel_for(kernel_range, *interp_kernel); 73bd882c8aSJames Wright }); 74bd882c8aSJames Wright 75bd882c8aSJames Wright } break; 76bd882c8aSJames Wright case CEED_EVAL_GRAD: { 77bd882c8aSJames Wright CeedInt *lrange = impl->grad_local_range; 78bd882c8aSJames Wright const CeedInt &elem_per_group = lrange[2]; 79bd882c8aSJames Wright const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group); 80bd882c8aSJames Wright //----------- 81bd882c8aSJames Wright sycl::range<3> local_range(lrange[2], lrange[1], lrange[0]); 82bd882c8aSJames Wright sycl::range<3> global_range(group_count * lrange[2], lrange[1], lrange[0]); 83bd882c8aSJames Wright sycl::nd_range<3> kernel_range(global_range, local_range); 84bd882c8aSJames Wright //----------- 85bd882c8aSJames Wright sycl::kernel *grad_kernel = (t_mode == CEED_TRANSPOSE) ? impl->grad_transpose_kernel : impl->grad_kernel; 86bd882c8aSJames Wright const CeedScalar *d_grad_1d = (impl->d_collo_grad_1d) ? impl->d_collo_grad_1d : impl->d_grad_1d; 87*1f4b1b45SUmesh Unnikrishnan 88*1f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 89*1f4b1b45SUmesh Unnikrishnan 90*1f4b1b45SUmesh Unnikrishnan if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()}; 91bd882c8aSJames Wright 92bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 93bd882c8aSJames Wright cgh.depends_on(e); 94bd882c8aSJames Wright cgh.set_args(num_elem, impl->d_interp_1d, d_grad_1d, d_u, d_v); 95bd882c8aSJames Wright cgh.parallel_for(kernel_range, *grad_kernel); 96bd882c8aSJames Wright }); 97bd882c8aSJames Wright } break; 98bd882c8aSJames Wright case CEED_EVAL_WEIGHT: { 99bd882c8aSJames Wright CeedInt *lrange = impl->weight_local_range; 100bd882c8aSJames Wright const CeedInt &elem_per_group = lrange[2]; 101bd882c8aSJames Wright const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group); 102bd882c8aSJames Wright //----------- 103bd882c8aSJames Wright sycl::range<3> local_range(lrange[2], lrange[1], lrange[0]); 104bd882c8aSJames Wright sycl::range<3> global_range(group_count * lrange[2], lrange[1], lrange[0]); 105bd882c8aSJames Wright sycl::nd_range<3> kernel_range(global_range, local_range); 106bd882c8aSJames Wright //----------- 107*1f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 108*1f4b1b45SUmesh Unnikrishnan 109*1f4b1b45SUmesh Unnikrishnan if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()}; 110bd882c8aSJames Wright 111bd882c8aSJames Wright ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) { 112bd882c8aSJames Wright cgh.depends_on(e); 113bd882c8aSJames Wright cgh.set_args(num_elem, impl->d_q_weight_1d, d_v); 114bd882c8aSJames Wright cgh.parallel_for(kernel_range, *(impl->weight_kernel)); 115bd882c8aSJames Wright }); 116bd882c8aSJames Wright } break; 1170ae60fd3SJeremy L Thompson case CEED_EVAL_NONE: /* handled separately below */ 1180ae60fd3SJeremy L Thompson break; 119bd882c8aSJames Wright // LCOV_EXCL_START 120bd882c8aSJames Wright case CEED_EVAL_DIV: 121bd882c8aSJames Wright case CEED_EVAL_CURL: 1224e3038a5SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]); 123bd882c8aSJames Wright // LCOV_EXCL_STOP 124bd882c8aSJames Wright } 125bd882c8aSJames Wright 1260ae60fd3SJeremy L Thompson // Restore vectors, cover CEED_EVAL_NONE 127bd882c8aSJames Wright CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 1280ae60fd3SJeremy L Thompson if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u)); 1290ae60fd3SJeremy L Thompson if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 1300ae60fd3SJeremy L Thompson 131bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 132bd882c8aSJames Wright } 133bd882c8aSJames Wright 134bd882c8aSJames Wright //------------------------------------------------------------------------------ 135bd882c8aSJames Wright // Destroy basis 136bd882c8aSJames Wright //------------------------------------------------------------------------------ 137bd882c8aSJames Wright static int CeedBasisDestroy_Sycl_shared(CeedBasis basis) { 138bd882c8aSJames Wright Ceed ceed; 139bd882c8aSJames Wright Ceed_Sycl *data; 140dd64fc84SJeremy L Thompson CeedBasis_Sycl_shared *impl; 141bd882c8aSJames Wright 142dd64fc84SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 143dd64fc84SJeremy L Thompson CeedCallBackend(CeedBasisGetData(basis, &impl)); 144dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 145bd882c8aSJames Wright CeedCallSycl(ceed, data->sycl_queue.wait_and_throw()); 146bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context)); 147bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context)); 148bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context)); 149bd882c8aSJames Wright CeedCallSycl(ceed, sycl::free(impl->d_collo_grad_1d, data->sycl_context)); 150bd882c8aSJames Wright 151bd882c8aSJames Wright delete impl->interp_kernel; 152bd882c8aSJames Wright delete impl->interp_transpose_kernel; 153bd882c8aSJames Wright delete impl->grad_kernel; 154bd882c8aSJames Wright delete impl->grad_transpose_kernel; 155bd882c8aSJames Wright delete impl->weight_kernel; 156bd882c8aSJames Wright delete impl->sycl_module; 157bd882c8aSJames Wright 158bd882c8aSJames Wright CeedCallBackend(CeedFree(&impl)); 159bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 160bd882c8aSJames Wright } 161bd882c8aSJames Wright 162bd882c8aSJames Wright //------------------------------------------------------------------------------ 163bd882c8aSJames Wright // Create tensor basis 164bd882c8aSJames Wright // TODO: Refactor 165bd882c8aSJames Wright //------------------------------------------------------------------------------ 166bd882c8aSJames Wright int CeedBasisCreateTensorH1_Sycl_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d, 167bd882c8aSJames Wright const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) { 168bd882c8aSJames Wright Ceed ceed; 169bd882c8aSJames Wright Ceed_Sycl *data; 17022070f95SJeremy L Thompson char *basis_kernel_source; 17122070f95SJeremy L Thompson const char *basis_kernel_path; 172bd882c8aSJames Wright CeedInt num_comp; 173dd64fc84SJeremy L Thompson CeedBasis_Sycl_shared *impl; 174dd64fc84SJeremy L Thompson 175dd64fc84SJeremy L Thompson CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 176dd64fc84SJeremy L Thompson CeedCallBackend(CeedCalloc(1, &impl)); 177dd64fc84SJeremy L Thompson CeedCallBackend(CeedGetData(ceed, &data)); 178bd882c8aSJames Wright CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 179bd882c8aSJames Wright 180bd882c8aSJames Wright const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d); 181bd882c8aSJames Wright const CeedInt num_nodes = CeedIntPow(P_1d, dim); 182bd882c8aSJames Wright const CeedInt num_qpts = CeedIntPow(Q_1d, dim); 183bd882c8aSJames Wright 184bd882c8aSJames Wright CeedInt *interp_lrange = impl->interp_local_range; 185dd64fc84SJeremy L Thompson 186bd882c8aSJames Wright CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, interp_lrange)); 187bd882c8aSJames Wright const CeedInt interp_group_size = interp_lrange[0] * interp_lrange[1] * interp_lrange[2]; 188bd882c8aSJames Wright 189bd882c8aSJames Wright CeedInt *grad_lrange = impl->grad_local_range; 190dd64fc84SJeremy L Thompson 191bd882c8aSJames Wright CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, grad_lrange)); 192bd882c8aSJames Wright const CeedInt grad_group_size = grad_lrange[0] * grad_lrange[1] * grad_lrange[2]; 193bd882c8aSJames Wright 194bd882c8aSJames Wright CeedCallBackend(ComputeLocalRange(ceed, dim, Q_1d, impl->weight_local_range)); 195bd882c8aSJames Wright 196*1f4b1b45SUmesh Unnikrishnan std::vector<sycl::event> e; 197*1f4b1b45SUmesh Unnikrishnan 198*1f4b1b45SUmesh Unnikrishnan if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()}; 199*1f4b1b45SUmesh Unnikrishnan 200bd882c8aSJames Wright // Copy basis data to GPU 201bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context)); 202*1f4b1b45SUmesh Unnikrishnan sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, e); 203bd882c8aSJames Wright 204bd882c8aSJames Wright const CeedInt interp_length = Q_1d * P_1d; 205bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 206*1f4b1b45SUmesh Unnikrishnan sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, e); 207bd882c8aSJames Wright 208bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context)); 209*1f4b1b45SUmesh Unnikrishnan sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, e); 210bd882c8aSJames Wright 211bd882c8aSJames Wright CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_weight, copy_interp, copy_grad})); 212bd882c8aSJames Wright 213bd882c8aSJames Wright // Compute collocated gradient and copy to GPU 214bd882c8aSJames Wright impl->d_collo_grad_1d = NULL; 215bd882c8aSJames Wright const bool has_collocated_grad = (dim == 3) && (Q_1d >= P_1d); 216dd64fc84SJeremy L Thompson 217bd882c8aSJames Wright if (has_collocated_grad) { 218bd882c8aSJames Wright CeedScalar *collo_grad_1d; 219dd64fc84SJeremy L Thompson const CeedInt cgrad_length = Q_1d * Q_1d; 220dd64fc84SJeremy L Thompson 221bd882c8aSJames Wright CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d)); 222bd882c8aSJames Wright CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d)); 223bd882c8aSJames Wright CeedCallSycl(ceed, impl->d_collo_grad_1d = sycl::malloc_device<CeedScalar>(cgrad_length, data->sycl_device, data->sycl_context)); 224*1f4b1b45SUmesh Unnikrishnan CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(collo_grad_1d, impl->d_collo_grad_1d, cgrad_length, e).wait_and_throw()); 225bd882c8aSJames Wright CeedCallBackend(CeedFree(&collo_grad_1d)); 226bd882c8aSJames Wright } 227bd882c8aSJames Wright 228bd882c8aSJames Wright // ---[Refactor into separate function]------> 229bd882c8aSJames Wright // Define compile-time constants 230bd882c8aSJames Wright std::map<std::string, CeedInt> jit_constants; 231bd882c8aSJames Wright jit_constants["BASIS_DIM"] = dim; 232bd882c8aSJames Wright jit_constants["BASIS_Q_1D"] = Q_1d; 233bd882c8aSJames Wright jit_constants["BASIS_P_1D"] = P_1d; 234bd882c8aSJames Wright jit_constants["T_1D"] = thread_1d; 235bd882c8aSJames Wright jit_constants["BASIS_NUM_COMP"] = num_comp; 236bd882c8aSJames Wright jit_constants["BASIS_NUM_NODES"] = num_nodes; 237bd882c8aSJames Wright jit_constants["BASIS_NUM_QPTS"] = num_qpts; 238bd882c8aSJames Wright jit_constants["BASIS_HAS_COLLOCATED_GRAD"] = has_collocated_grad; 239bd882c8aSJames Wright jit_constants["BASIS_INTERP_SCRATCH_SIZE"] = interp_group_size; 240bd882c8aSJames Wright jit_constants["BASIS_GRAD_SCRATCH_SIZE"] = grad_group_size; 241bd882c8aSJames Wright 242bd882c8aSJames Wright // Load kernel source 243bd882c8aSJames Wright CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor.h", &basis_kernel_path)); 24423d4529eSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n"); 24522070f95SJeremy L Thompson { 24622070f95SJeremy L Thompson char *source; 24722070f95SJeremy L Thompson 24822070f95SJeremy L Thompson CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &source)); 24922070f95SJeremy L Thompson basis_kernel_source = source; 25022070f95SJeremy L Thompson } 25123d4529eSJeremy L Thompson CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete -----\n"); 252bd882c8aSJames Wright 253bd882c8aSJames Wright // Compile kernels into a kernel bundle 254eb7e6cafSJeremy L Thompson CeedCallBackend(CeedBuildModule_Sycl(ceed, basis_kernel_source, &impl->sycl_module, jit_constants)); 255bd882c8aSJames Wright 256bd882c8aSJames Wright // Load kernel functions 257eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Interp", &impl->interp_kernel)); 258eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "InterpTranspose", &impl->interp_transpose_kernel)); 259eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Grad", &impl->grad_kernel)); 260eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "GradTranspose", &impl->grad_transpose_kernel)); 261eb7e6cafSJeremy L Thompson CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Weight", &impl->weight_kernel)); 262bd882c8aSJames Wright 263bd882c8aSJames Wright // Clean-up 264bd882c8aSJames Wright CeedCallBackend(CeedFree(&basis_kernel_path)); 265bd882c8aSJames Wright CeedCallBackend(CeedFree(&basis_kernel_source)); 266bd882c8aSJames Wright // <---[Refactor into separate function]------ 267bd882c8aSJames Wright 268bd882c8aSJames Wright CeedCallBackend(CeedBasisSetData(basis, impl)); 269bd882c8aSJames Wright 270bd882c8aSJames Wright // Register backend functions 271bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Sycl_shared)); 272bd882c8aSJames Wright CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl_shared)); 273bd882c8aSJames Wright return CEED_ERROR_SUCCESS; 274bd882c8aSJames Wright } 275ff1e7120SSebastian Grimberg 276bd882c8aSJames Wright //------------------------------------------------------------------------------ 277