// Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#include <ceed/backend.h>
#include <ceed/ceed.h>
#include <ceed/jit-tools.h>

#include <map>
#include <string_view>
#include <sycl/sycl.hpp>

#include "../sycl/ceed-sycl-compile.hpp"
#include "ceed-sycl-shared.hpp"

//------------------------------------------------------------------------------
// Compute the local range of for basis kernels
//
// Fills local_range = {thread_1d, thread_1d or 1, elements-per-group}:
//   - local_range[0]/[1] span the tensor thread layout (the y extent collapses
//     to 1 when dim == 1),
//   - local_range[2] is how many elements share one work-group, i.e.
//     max_group_size divided by the per-element thread footprint.
// Errors if a single element already needs more than max_group_size threads.
//------------------------------------------------------------------------------
static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 256) {
  local_range[0] = thread_1d;
  local_range[1] = (dim > 1) ? thread_1d : 1;
  const CeedInt min_group_size = local_range[0] * local_range[1];

  CeedCheck(min_group_size <= max_group_size, ceed, CEED_ERROR_BACKEND, "Requested group size is smaller than the required minimum.");

  local_range[2] = max_group_size / min_group_size;  // elements per group
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Apply basis
//
// Launches the JIT-compiled interp/grad/weight kernel for num_elem elements.
// Kernel submissions are asynchronous; when the queue is out-of-order an
// ext_oneapi_submit_barrier event is used so the kernel waits on previously
// submitted work. CEED_EVAL_NONE launches no kernel: v is populated below by
// copying u's device array.
//------------------------------------------------------------------------------
int CeedBasisApplyTensor_Sycl_shared(CeedBasis basis, const CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode eval_mode, CeedVector u,
                                     CeedVector v) {
  Ceed                   ceed;
  Ceed_Sycl             *ceed_Sycl;
  const CeedScalar      *d_u;
  CeedScalar            *d_v;
  CeedBasis_Sycl_shared *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &ceed_Sycl));
  CeedCallBackend(CeedBasisGetData(basis, &impl));

  // Get read/write access to u, v
  // Only CEED_EVAL_WEIGHT is allowed to run without an input vector
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(eval_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Apply basis operation
  switch (eval_mode) {
    case CEED_EVAL_INTERP: {
      CeedInt       *lrange         = impl->interp_local_range;
      const CeedInt &elem_per_group = lrange[2];
      // Ceiling divide: number of work-groups needed to cover all elements
      const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
      //-----------
      // SYCL ranges are listed slowest-varying first, hence the reversed order
      sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
      sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
      sycl::nd_range<3> kernel_range(global_range, local_range);
      //-----------
      sycl::kernel *interp_kernel = (t_mode == CEED_TRANSPOSE) ? impl->interp_transpose_kernel : impl->interp_kernel;

      std::vector<sycl::event> e;

      // Order after in-flight work when the queue does not guarantee ordering
      if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};
      ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(e);
        cgh.set_args(num_elem, impl->d_interp_1d, d_u, d_v);
        cgh.parallel_for(kernel_range, *interp_kernel);
      });

    } break;
    case CEED_EVAL_GRAD: {
      CeedInt       *lrange         = impl->grad_local_range;
      const CeedInt &elem_per_group = lrange[2];
      // Ceiling divide: number of work-groups needed to cover all elements
      const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
      //-----------
      sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
      sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
      sycl::nd_range<3> kernel_range(global_range, local_range);
      //-----------
      sycl::kernel *grad_kernel = (t_mode == CEED_TRANSPOSE) ? impl->grad_transpose_kernel : impl->grad_kernel;
      // Prefer the collocated gradient matrix when one was built at create time
      const CeedScalar *d_grad_1d = (impl->d_collo_grad_1d) ? impl->d_collo_grad_1d : impl->d_grad_1d;

      std::vector<sycl::event> e;

      if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};

      ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(e);
        cgh.set_args(num_elem, impl->d_interp_1d, d_grad_1d, d_u, d_v);
        cgh.parallel_for(kernel_range, *grad_kernel);
      });
    } break;
    case CEED_EVAL_WEIGHT: {
      CeedInt       *lrange         = impl->weight_local_range;
      const CeedInt &elem_per_group = lrange[2];
      // Ceiling divide: number of work-groups needed to cover all elements
      const CeedInt group_count = (num_elem / elem_per_group) + !!(num_elem % elem_per_group);
      //-----------
      sycl::range<3>    local_range(lrange[2], lrange[1], lrange[0]);
      sycl::range<3>    global_range(group_count * lrange[2], lrange[1], lrange[0]);
      sycl::nd_range<3> kernel_range(global_range, local_range);
      //-----------
      std::vector<sycl::event> e;

      // q_weight_1d is optional at create time; the weight kernel requires it
      CeedCheck(impl->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight_1d not set", CeedEvalModes[eval_mode]);
      if (!ceed_Sycl->sycl_queue.is_in_order()) e = {ceed_Sycl->sycl_queue.ext_oneapi_submit_barrier()};

      ceed_Sycl->sycl_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(e);
        cgh.set_args(num_elem, impl->d_q_weight_1d, d_v);
        cgh.parallel_for(kernel_range, *(impl->weight_kernel));
      });
    } break;
    case CEED_EVAL_NONE: /* handled separately below */
      break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[eval_mode]);
      // LCOV_EXCL_STOP
  }

  // Restore vectors, cover CEED_EVAL_NONE
  // For CEED_EVAL_NONE the "basis action" is the identity: copy u's device data into v
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  if (eval_mode == CEED_EVAL_NONE) CeedCallBackend(CeedVectorSetArray(v, CEED_MEM_DEVICE, CEED_COPY_VALUES, (CeedScalar *)d_u));
  if (eval_mode != CEED_EVAL_WEIGHT) CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy basis
//
// Drains the queue, frees all device allocations, and deletes the kernel and
// module objects created at basis-create time.
//------------------------------------------------------------------------------
static int CeedBasisDestroy_Sycl_shared(CeedBasis basis) {
  Ceed                   ceed;
  Ceed_Sycl             *data;
  CeedBasis_Sycl_shared *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedGetData(ceed, &data));
  // Ensure no kernel still reads these buffers before freeing them
  CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
  if (impl->d_q_weight_1d) CeedCallSycl(ceed, sycl::free(impl->d_q_weight_1d, data->sycl_context));
  CeedCallSycl(ceed, sycl::free(impl->d_interp_1d, data->sycl_context));
  CeedCallSycl(ceed, sycl::free(impl->d_grad_1d, data->sycl_context));
  // NOTE(review): d_collo_grad_1d is NULL unless a collocated gradient was
  // built at create time — this relies on sycl::free(NULL) being a no-op
  // (free-like semantics); verify against the SYCL spec
  CeedCallSycl(ceed, sycl::free(impl->d_collo_grad_1d, data->sycl_context));

  delete impl->interp_kernel;
  delete impl->interp_transpose_kernel;
  delete impl->grad_kernel;
  delete impl->grad_transpose_kernel;
  delete impl->weight_kernel;
  delete impl->sycl_module;

  CeedCallBackend(CeedFree(&impl));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create tensor basis
//
// Computes kernel launch geometry, copies the 1D basis matrices to the device,
// optionally builds a collocated gradient (3D with Q_1d >= P_1d), then
// JIT-compiles the shared-memory basis kernels with dimension/order constants
// baked in and registers Apply/Destroy with the Ceed object.
// TODO: Refactor
//------------------------------------------------------------------------------
int CeedBasisCreateTensorH1_Sycl_shared(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
                                        const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
  Ceed                   ceed;
  Ceed_Sycl             *data;
  char                  *basis_kernel_source;
  const char            *basis_kernel_path;
  CeedInt                num_comp;
  CeedBasis_Sycl_shared *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedCalloc(1, &impl));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));

  // One thread per 1D dof/qpt; kernels are sized for the larger of the two
  const CeedInt thread_1d = CeedIntMax(Q_1d, P_1d);
  const CeedInt num_nodes = CeedIntPow(P_1d, dim);
  const CeedInt num_qpts  = CeedIntPow(Q_1d, dim);

  CeedInt *interp_lrange = impl->interp_local_range;

  CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, interp_lrange));
  // Total work-items per group — also the per-group scratch size passed to JIT
  const CeedInt interp_group_size = interp_lrange[0] * interp_lrange[1] * interp_lrange[2];

  CeedInt *grad_lrange = impl->grad_local_range;

  CeedCallBackend(ComputeLocalRange(ceed, dim, thread_1d, grad_lrange));
  const CeedInt grad_group_size = grad_lrange[0] * grad_lrange[1] * grad_lrange[2];

  // Weight kernel only touches quadrature points, so size by Q_1d
  CeedCallBackend(ComputeLocalRange(ceed, dim, Q_1d, impl->weight_local_range));

  std::vector<sycl::event> e;

  if (!data->sycl_queue.is_in_order()) e = {data->sycl_queue.ext_oneapi_submit_barrier()};

  // Copy basis data to GPU
  std::vector<sycl::event> copy_events;
  if (q_weight_1d) {
    CeedCallSycl(ceed, impl->d_q_weight_1d = sycl::malloc_device<CeedScalar>(Q_1d, data->sycl_device, data->sycl_context));
    sycl::event copy_weight = data->sycl_queue.copy<CeedScalar>(q_weight_1d, impl->d_q_weight_1d, Q_1d, e);
    copy_events.push_back(copy_weight);
  }

  const CeedInt interp_length = Q_1d * P_1d;
  CeedCallSycl(ceed, impl->d_interp_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
  sycl::event copy_interp = data->sycl_queue.copy<CeedScalar>(interp_1d, impl->d_interp_1d, interp_length, e);
  copy_events.push_back(copy_interp);

  CeedCallSycl(ceed, impl->d_grad_1d = sycl::malloc_device<CeedScalar>(interp_length, data->sycl_device, data->sycl_context));
  sycl::event copy_grad = data->sycl_queue.copy<CeedScalar>(grad_1d, impl->d_grad_1d, interp_length, e);
  copy_events.push_back(copy_grad);

  // Block until all three copies land before the host buffers go out of scope
  CeedCallSycl(ceed, sycl::event::wait_and_throw(copy_events));

  // Compute collocated gradient and copy to GPU
  impl->d_collo_grad_1d         = NULL;
  const bool has_collocated_grad = (dim == 3) && (Q_1d >= P_1d);

  if (has_collocated_grad) {
    CeedScalar   *collo_grad_1d;
    const CeedInt cgrad_length = Q_1d * Q_1d;

    CeedCallBackend(CeedMalloc(Q_1d * Q_1d, &collo_grad_1d));
    CeedCallBackend(CeedBasisGetCollocatedGrad(basis, collo_grad_1d));
    CeedCallSycl(ceed, impl->d_collo_grad_1d = sycl::malloc_device<CeedScalar>(cgrad_length, data->sycl_device, data->sycl_context));
    // Synchronous copy: collo_grad_1d is freed immediately after
    CeedCallSycl(ceed, data->sycl_queue.copy<CeedScalar>(collo_grad_1d, impl->d_collo_grad_1d, cgrad_length, e).wait_and_throw());
    CeedCallBackend(CeedFree(&collo_grad_1d));
  }

  // ---[Refactor into separate function]------>
  // Define compile-time constants baked into the JIT-compiled kernel source
  std::map<std::string, CeedInt> jit_constants;
  jit_constants["BASIS_DIM"]                 = dim;
  jit_constants["BASIS_Q_1D"]                = Q_1d;
  jit_constants["BASIS_P_1D"]                = P_1d;
  jit_constants["T_1D"]                      = thread_1d;
  jit_constants["BASIS_NUM_COMP"]            = num_comp;
  jit_constants["BASIS_NUM_NODES"]           = num_nodes;
  jit_constants["BASIS_NUM_QPTS"]            = num_qpts;
  jit_constants["BASIS_HAS_COLLOCATED_GRAD"] = has_collocated_grad;
  jit_constants["BASIS_INTERP_SCRATCH_SIZE"] = interp_group_size;
  jit_constants["BASIS_GRAD_SCRATCH_SIZE"]   = grad_group_size;

  // Load kernel source
  CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/sycl/sycl-shared-basis-tensor.h", &basis_kernel_path));
  CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
  {
    char *source;

    CeedCallBackend(CeedLoadSourceToBuffer(ceed, basis_kernel_path, &source));
    basis_kernel_source = source;
  }
  CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete -----\n");

  // Compile kernels into a kernel bundle
  CeedCallBackend(CeedBuildModule_Sycl(ceed, basis_kernel_source, &impl->sycl_module, jit_constants));

  // Load kernel functions
  CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Interp", &impl->interp_kernel));
  CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "InterpTranspose", &impl->interp_transpose_kernel));
  CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Grad", &impl->grad_kernel));
  CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "GradTranspose", &impl->grad_transpose_kernel));
  CeedCallBackend(CeedGetKernel_Sycl(ceed, impl->sycl_module, "Weight", &impl->weight_kernel));

  // Clean-up
  CeedCallBackend(CeedFree(&basis_kernel_path));
  CeedCallBackend(CeedFree(&basis_kernel_source));
  // <---[Refactor into separate function]------

  CeedCallBackend(CeedBasisSetData(basis, impl));

  // Register backend functions
  CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Apply", CeedBasisApplyTensor_Sycl_shared));
  CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Sycl_shared));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------