// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#include <ceed.h>
#include <ceed/backend.h>
#include <ceed/jit-tools.h>
#include <string.h>

#ifdef CEED_MAGMA_USE_HIP
#include "../hip/ceed-hip-common.h"
#include "../hip/ceed-hip-compile.h"
#else
#include "../cuda/ceed-cuda-common.h"
#include "../cuda/ceed-cuda-compile.h"
#endif
#include "ceed-magma-common.h"
#include "ceed-magma.h"

#include "ceed-magma-gemm-nontensor.h"
#include "ceed-magma-gemm-selector.h"

//------------------------------------------------------------------------------
// Basis apply - tensor
//------------------------------------------------------------------------------
// Shared implementation behind CeedBasisApply_Magma and CeedBasisApplyAdd_Magma.
// Acquires device arrays for u/v, launches the precompiled MAGMA kernel for the
// requested eval mode, synchronizes the MAGMA queue, then restores the arrays.
// When apply_add is true, v is opened read-write so the kernel output sums into
// the existing values; otherwise v is opened write-only.
static int CeedBasisApplyCore_Magma(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
                                    CeedVector v) {
  Ceed              ceed;
  Ceed_Magma       *data;
  CeedInt           dim, num_comp, num_nodes, P_1d, Q_1d, P, Q;
  const CeedScalar *d_u;
  CeedScalar       *d_v;
  CeedBasis_Magma  *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetDimension(basis, &dim));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d));
  // P is the 1D input size and Q the 1D output size of the contraction;
  // they swap roles in transpose mode
  P = P_1d;
  Q = Q_1d;
  if (t_mode == CEED_TRANSPOSE) {
    P = Q_1d;
    Q = P_1d;
  }

  // Read vectors (CEED_EVAL_WEIGHT is the only mode that takes no input vector)
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Apply basis operation
  switch (e_mode) {
    case CEED_EVAL_INTERP: {
      // Define element sizes for dofs/quad
      CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
      CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);

      // E-vector ordering -------------- Q-vector ordering
      //  component                        component
      //    elem                             elem
      //      node                             node

      // --- Define strides for NOTRANSPOSE mode: ---
      // Input (d_u) is E-vector, output (d_v) is Q-vector

      // Element strides
      CeedInt u_elem_stride = elem_dofs_size;
      CeedInt v_elem_stride = elem_qpts_size;
      // Component strides
      CeedInt u_comp_stride = num_elem * elem_dofs_size;
      CeedInt v_comp_stride = num_elem * elem_qpts_size;
      if (t_mode == CEED_TRANSPOSE) {
        // Input (d_u) is Q-vector, output (d_v) is E-vector
        // Element strides
        v_elem_stride = elem_dofs_size;
        u_elem_stride = elem_qpts_size;
        // Component strides
        v_comp_stride = num_elem * elem_dofs_size;
        u_comp_stride = num_elem * elem_qpts_size;
      }
      // Launch geometry: num_t_col elements are processed per thread block,
      // so shared_mem (bytes) scales the per-element workspaces by num_t_col
      CeedInt num_threads = 1;
      CeedInt num_t_col   = 1;
      CeedInt shared_mem  = 0;
      CeedInt max_P_Q     = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
          shared_mem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += P * Q * sizeof(CeedScalar);  // for sT
          // for reforming rU we need P x P, and for the intermediate output we need P x Q
          shared_mem += num_t_col * (P * max_P_Q * sizeof(CeedScalar));
          break;
        case 3:
          num_threads = max_P_Q * max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * (P * Q);  // for sT
          // rU needs P^2 x P, the intermediate output needs max(P^2 x Q, P x Q^2)
          shared_mem += sizeof(CeedScalar) * num_t_col * (CeedIntMax(P * P * max_P_Q, P * Q * Q));
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_interp_1d, &d_u, &u_elem_stride, &u_comp_stride, &d_v, &v_elem_stride, &v_comp_stride, &num_elem};

      if (t_mode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, apply_add ? impl->InterpTransposeAdd : impl->InterpTranspose, grid, num_threads, num_t_col,
                                                    1, shared_mem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp, grid, num_threads, num_t_col, 1, shared_mem, args));
      }
    } break;
    case CEED_EVAL_GRAD: {
      // Define element sizes for dofs/quad
      CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim);
      CeedInt elem_dofs_size = CeedIntPow(P_1d, dim);

      // In CEED_NOTRANSPOSE mode:
      // d_u is (P^dim x nc), column-major layout (nc = num_comp)
      // d_v is (Q^dim x nc x dim), column-major layout (nc = num_comp)
      // In CEED_TRANSPOSE mode, the sizes of d_u and d_v are switched.

      // E-vector ordering -------------- Q-vector ordering
      //                                  dim
      //  component                        component
      //    elem                             elem
      //      node                             node

      // --- Define strides for NOTRANSPOSE mode: ---
      // Input (d_u) is E-vector, output (d_v) is Q-vector

      // Element strides
      CeedInt u_elem_stride = elem_dofs_size;
      CeedInt v_elem_stride = elem_qpts_size;
      // Component strides
      CeedInt u_comp_stride = num_elem * elem_dofs_size;
      CeedInt v_comp_stride = num_elem * elem_qpts_size;
      // Dimension strides (the E-vector side has no dim axis, hence stride 0)
      CeedInt u_dim_stride = 0;
      CeedInt v_dim_stride = num_elem * elem_qpts_size * num_comp;
      if (t_mode == CEED_TRANSPOSE) {
        // Input (d_u) is Q-vector, output (d_v) is E-vector
        // Element strides
        v_elem_stride = elem_dofs_size;
        u_elem_stride = elem_qpts_size;
        // Component strides
        v_comp_stride = num_elem * elem_dofs_size;
        u_comp_stride = num_elem * elem_qpts_size;
        // Dimension strides
        v_dim_stride = 0;
        u_dim_stride = num_elem * elem_qpts_size * num_comp;
      }
      // Launch geometry; see CEED_EVAL_INTERP above for the num_t_col/shared_mem scheme
      CeedInt num_threads = 1;
      CeedInt num_t_col   = 1;
      CeedInt shared_mem  = 0;
      CeedInt max_P_Q     = CeedIntMax(P, Q);

      switch (dim) {
        case 1:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q));
          shared_mem += sizeof(CeedScalar) * (P * Q);
          break;
        case 2:
          num_threads = max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          // for reforming rU we need P x P, and for the intermediate output we need P x Q
          shared_mem += sizeof(CeedScalar) * num_t_col * (P * max_P_Q);
          break;
        case 3:
          num_threads = max_P_Q * max_P_Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * 2 * P * Q;  // for sTinterp and sTgrad
          // rU needs P^2 x P, the intermediate outputs need (P^2 x Q + P x Q^2)
          shared_mem += sizeof(CeedScalar) * num_t_col * CeedIntMax(P * P * P, (P * P * Q) + (P * Q * Q));
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_interp_1d, &impl->d_grad_1d, &d_u, &u_elem_stride, &u_comp_stride, &u_dim_stride, &d_v,
                        &v_elem_stride, &v_comp_stride, &v_dim_stride, &num_elem};

      if (t_mode == CEED_TRANSPOSE) {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, apply_add ? impl->GradTransposeAdd : impl->GradTranspose, grid, num_threads, num_t_col, 1,
                                                    shared_mem, args));
      } else {
        CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad, grid, num_threads, num_t_col, 1, shared_mem, args));
      }
    } break;
    case CEED_EVAL_WEIGHT: {
      CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
      CeedCheck(impl->d_q_weight_1d, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight_1d not set", CeedEvalModes[e_mode]);
      // NOTE(review): despite the name, this is Q^dim — the per-element number of
      // quadrature points that the weight kernel writes (weights live at qpts, not dofs)
      CeedInt elem_dofs_size = CeedIntPow(Q, dim);
      CeedInt num_threads    = 1;
      CeedInt num_t_col      = 1;
      CeedInt shared_mem     = 0;

      switch (dim) {
        case 1:
          num_threads = Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D);
          shared_mem += sizeof(CeedScalar) * Q;              // for d_q_weight_1d
          shared_mem += sizeof(CeedScalar) * num_t_col * Q;  // for output
          break;
        case 2:
          num_threads = Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D);
          shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
          break;
        case 3:
          num_threads = Q * Q;
          num_t_col   = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D);
          shared_mem += sizeof(CeedScalar) * Q;  // for d_q_weight_1d
          break;
      }
      CeedInt grid   = CeedDivUpInt(num_elem, num_t_col);
      void   *args[] = {&impl->d_q_weight_1d, &d_v, &elem_dofs_size, &num_elem};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, num_threads, num_t_col, 1, shared_mem, args));
    } break;
    // LCOV_EXCL_START
    case CEED_EVAL_DIV:
    case CEED_EVAL_CURL:
      return CeedError(ceed, CEED_ERROR_BACKEND, "%s not supported", CeedEvalModes[e_mode]);
    case CEED_EVAL_NONE:
      return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context");
    // LCOV_EXCL_STOP
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors (no input array was taken for CEED_EVAL_WEIGHT)
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  return CEED_ERROR_SUCCESS;
}

// Apply tensor basis, overwriting v
static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
  CeedCallBackend(CeedBasisApplyCore_Magma(basis, false, num_elem, t_mode, e_mode, u, v));
  return CEED_ERROR_SUCCESS;
}

// Apply tensor basis, summing into v
static int CeedBasisApplyAdd_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) {
  CeedCallBackend(CeedBasisApplyCore_Magma(basis, true, num_elem, t_mode, e_mode, u, v));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Basis apply - tensor AtPoints
//------------------------------------------------------------------------------
// AtPoints evaluation is not implemented by the MAGMA backend; always errors
int CeedBasisApplyAtPoints_Magma(CeedBasis basis, const CeedInt num_elem, const CeedInt *num_points, CeedTransposeMode t_mode, CeedEvalMode eval_mode,
                                 CeedVector x_ref, CeedVector u, CeedVector v) {
  return CeedError(CeedBasisReturnCeed(basis), CEED_ERROR_BACKEND, "Backend does not implement CeedBasisApplyAtPoints");
}

//------------------------------------------------------------------------------
// Basis apply - non-tensor
//------------------------------------------------------------------------------
// Shared implementation behind the non-tensor Apply/ApplyAdd backends.
// For small bases (P, Q within the custom-kernel limits) this lazily JIT-compiles
// a kernel instance tuned for the batch size N = num_elem * num_comp, caching it
// in impl->module[iN]; otherwise it falls back to MAGMA GEMM calls.
static int CeedBasisApplyNonTensorCore_Magma(CeedBasis basis, bool apply_add, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode,
                                             CeedVector u, CeedVector v) {
  Ceed                      ceed;
  Ceed_Magma               *data;
  CeedInt                   num_comp, num_nodes, num_qpts, P, Q, N;
  const CeedScalar         *d_u;
  CeedScalar               *d_v;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes));
  CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts));
  P = num_nodes;
  Q = num_qpts;
  N = num_elem * num_comp;

  // Read vectors (CEED_EVAL_WEIGHT is the only mode that takes no input vector)
  if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
  else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode");
  if (apply_add) CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
  else CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));

  // Compile kernels for N as needed
  // iN selects the cached kernel instance whose tuned N value is closest to this N
  CeedInt iN = 0;
  if (P <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q && (e_mode != CEED_EVAL_WEIGHT || !impl->Weight)) {
    CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_KERNEL_N_VALUES};
    CeedInt diff = abs(n_array[iN] - N), idiff;

    // Nearest-N search over the fixed instance table
    for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
      idiff = abs(n_array[in] - N);
      if (idiff < diff) {
        iN   = in;
        diff = idiff;
      }
    }

    // NB_interp[iN] == 0 means this instance has not been compiled yet
    if (!impl->NB_interp[iN]) {
      CeedFESpace fe_space;
      CeedInt     q_comp_interp, q_comp_deriv;
      Ceed        ceed_delegate;
      char       *basis_kernel_source;
      const char *basis_kernel_path, *weight_kernel_path;
      char      **file_paths     = NULL;
      CeedInt     num_file_paths = 0;
      magma_int_t arch           = magma_getdevice_arch();

      // Tuning parameters for NB (blocking factors per op/transpose combination)
      CeedCallBackend(CeedBasisGetFESpace(basis, &fe_space));
      CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
      // The "deriv" operator depends on the FE space: grad for H^1, div for H(div), curl for H(curl)
      switch (fe_space) {
        case CEED_FE_SPACE_H1:
          CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_deriv));
          break;
        case CEED_FE_SPACE_HDIV:
          CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_deriv));
          break;
        case CEED_FE_SPACE_HCURL:
          CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_deriv));
          break;
      }
      impl->NB_interp[iN]   = nontensor_rtc_get_nb(arch, 'n', q_comp_interp, P, Q, n_array[iN]);
      impl->NB_interp_t[iN] = nontensor_rtc_get_nb(arch, 't', q_comp_interp, P, Q, n_array[iN]);
      impl->NB_deriv[iN]    = nontensor_rtc_get_nb(arch, 'n', q_comp_deriv, P, Q, n_array[iN]);
      impl->NB_deriv_t[iN]  = nontensor_rtc_get_nb(arch, 't', q_comp_deriv, P, Q, n_array[iN]);

      // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
      CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));

      // Compile kernels
      CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-interp-deriv-nontensor.h", &basis_kernel_path));
      CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
      CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, basis_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
      // Append the weight kernel source only if it was not compiled previously
      if (!impl->Weight) {
        CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
        CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
      }
      CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
      CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[iN], 8, "BASIS_Q_COMP_INTERP", q_comp_interp,
                                       "BASIS_Q_COMP_DERIV", q_comp_deriv, "BASIS_P", P, "BASIS_Q", Q, "BASIS_NB_INTERP_N", impl->NB_interp[iN],
                                       "BASIS_NB_INTERP_T", impl->NB_interp_t[iN], "BASIS_NB_DERIV_N", impl->NB_deriv[iN], "BASIS_NB_DERIV_T",
                                       impl->NB_deriv_t[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_n", &impl->Interp[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_t", &impl->InterpTranspose[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_interp_nontensor_ta", &impl->InterpTransposeAdd[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_n", &impl->Deriv[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_t", &impl->DerivTranspose[iN]));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_deriv_nontensor_ta", &impl->DerivTransposeAdd[iN]));
      if (!impl->Weight) {
        CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[iN], "magma_weight_nontensor", &impl->Weight));
        CeedCallBackend(CeedFree(&weight_kernel_path));
      }
      CeedCallBackend(CeedFree(&basis_kernel_path));
      CeedCallBackend(CeedFree(&basis_kernel_source));
      for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i]));
      CeedCall(CeedFree(&file_paths));
    }
  }

  // Apply basis operation
  if (e_mode != CEED_EVAL_WEIGHT) {
    const CeedScalar *d_b = NULL;
    CeedInt           q_comp, NB, M, K;
    CeedMagmaFunction Kernel;

    // Select the device basis matrix for this eval mode
    switch (e_mode) {
      case CEED_EVAL_INTERP:
        d_b = impl->d_interp;
        break;
      case CEED_EVAL_GRAD:
        d_b = impl->d_grad;
        break;
      case CEED_EVAL_DIV:
        d_b = impl->d_div;
        break;
      case CEED_EVAL_CURL:
        d_b = impl->d_curl;
        break;
      // LCOV_EXCL_START
      case CEED_EVAL_WEIGHT:
      case CEED_EVAL_NONE:
        return CeedError(ceed, CEED_ERROR_BACKEND, "%s does not make sense in this context", CeedEvalModes[e_mode]);
      // LCOV_EXCL_STOP
    }
    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, e_mode, &q_comp));
    // GEMM-style dimensions: M rows of output, K rows of input per column
    M = (t_mode == CEED_TRANSPOSE) ? P : Q, K = (t_mode == CEED_TRANSPOSE) ? Q : P;

    if (P <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q <= MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
      // Custom RTC kernel path: pick the kernel and blocking factor for this mode
      if (e_mode == CEED_EVAL_INTERP) {
        if (t_mode == CEED_TRANSPOSE) {
          Kernel = apply_add ? impl->InterpTransposeAdd[iN] : impl->InterpTranspose[iN];
          NB     = impl->NB_interp_t[iN];
        } else {
          Kernel = impl->Interp[iN];
          NB     = impl->NB_interp[iN];
        }
      } else {
        if (t_mode == CEED_TRANSPOSE) {
          Kernel = apply_add ? impl->DerivTransposeAdd[iN] : impl->DerivTranspose[iN];
          NB     = impl->NB_deriv_t[iN];
        } else {
          Kernel = impl->Deriv[iN];
          NB     = impl->NB_deriv[iN];
        }
      }
      CeedInt num_t_col    = MAGMA_BASIS_NTCOL(M, MAGMA_MAXTHREADS_1D);
      CeedInt grid         = CeedDivUpInt(N, num_t_col * NB);
      CeedInt shared_mem_A = P * Q * sizeof(CeedScalar);
      CeedInt shared_mem_B = num_t_col * K * NB * sizeof(CeedScalar);
      CeedInt shared_mem   = (t_mode != CEED_TRANSPOSE && q_comp > 1) ? (shared_mem_A + shared_mem_B) : CeedIntMax(shared_mem_A, shared_mem_B);
      void   *args[]       = {&N, &d_b, &d_u, &d_v};

      CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, Kernel, grid, M, num_t_col, 1, shared_mem, args));
    } else {
      // GEMM fallback: one multiply per quadrature component
      for (CeedInt d = 0; d < q_comp; d++) {
        if (t_mode == CEED_TRANSPOSE) {
          // beta = 1.0 accumulates: always after the first component, or for apply_add
          const CeedScalar beta = (apply_add || (d > 0)) ? 1.0 : 0.0;
          magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, d_b + d * P * Q, P, d_u + d * N * Q, Q, beta, d_v, P, data->queue);
        } else {
          magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, d_b + d * P * Q, P, d_u, P, 0.0, d_v + d * N * Q, Q, data->queue);
        }
      }
    }
  } else {
    CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE");
    CeedCheck(impl->d_q_weight, ceed, CEED_ERROR_BACKEND, "%s not supported; q_weight not set", CeedEvalModes[e_mode]);
    CeedInt num_t_col  = MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D);
    CeedInt grid       = CeedDivUpInt(num_elem, num_t_col);
    CeedInt shared_mem = Q * sizeof(CeedScalar) + num_t_col * Q * sizeof(CeedScalar);
    void   *args[]     = {&num_elem, &impl->d_q_weight, &d_v};

    CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, Q, num_t_col, 1, shared_mem, args));
  }

  // Must sync to ensure completeness
  ceed_magma_queue_sync(data->queue);

  // Restore vectors (no input array was taken for CEED_EVAL_WEIGHT)
  if (e_mode != CEED_EVAL_WEIGHT) {
    CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
  }
  CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
  return CEED_ERROR_SUCCESS;
}

// Apply non-tensor basis, overwriting v
static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
                                         CeedVector v) {
  CeedCallBackend(CeedBasisApplyNonTensorCore_Magma(basis, false, num_elem, t_mode, e_mode, u, v));
  return CEED_ERROR_SUCCESS;
}

// Apply non-tensor basis, summing into v
static int CeedBasisApplyAddNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u,
                                            CeedVector v) {
  CeedCallBackend(CeedBasisApplyNonTensorCore_Magma(basis, true, num_elem, t_mode, e_mode, u, v));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy tensor basis
//------------------------------------------------------------------------------
// Unload the compiled module and free the device copies of the 1D basis data
static int CeedBasisDestroy_Magma(CeedBasis basis) {
  Ceed             ceed;
  CeedBasis_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
#ifdef CEED_MAGMA_USE_HIP
  CeedCallHip(ceed, hipModuleUnload(impl->module));
#else
  CeedCallCuda(ceed, cuModuleUnload(impl->module));
#endif
  CeedCallBackend(magma_free(impl->d_interp_1d));
  CeedCallBackend(magma_free(impl->d_grad_1d));
  // q_weight_1d is only allocated when quadrature weights were provided
  if (impl->d_q_weight_1d) CeedCallBackend(magma_free(impl->d_q_weight_1d));
  CeedCallBackend(CeedFree(&impl));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy non-tensor basis
//------------------------------------------------------------------------------
// Unload every lazily-compiled kernel instance module and free device basis data
static int CeedBasisDestroyNonTensor_Magma(CeedBasis basis) {
  Ceed                      ceed;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedBasisGetData(basis, &impl));
  // Modules are compiled on demand per N instance, so some entries may be NULL
  for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) {
    if (impl->module[in]) {
#ifdef CEED_MAGMA_USE_HIP
      CeedCallHip(ceed, hipModuleUnload(impl->module[in]));
#else
      CeedCallCuda(ceed, cuModuleUnload(impl->module[in]));
#endif
    }
  }
  CeedCallBackend(magma_free(impl->d_interp));
  CeedCallBackend(magma_free(impl->d_grad));
  CeedCallBackend(magma_free(impl->d_div));
  CeedCallBackend(magma_free(impl->d_curl));
  if (impl->d_q_weight) CeedCallBackend(magma_free(impl->d_q_weight));
  CeedCallBackend(CeedFree(&impl));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create tensor
//------------------------------------------------------------------------------
// Copy the 1D basis data to the GPU, JIT-compile the dim-specific
// interp/grad/weight kernels, and register the backend Basis functions
int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d,
                                  const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) {
  Ceed             ceed, ceed_delegate;
  Ceed_Magma      *data;
  char            *basis_kernel_source;
  const char      *interp_kernel_path, *grad_kernel_path, *weight_kernel_path;
  char           **file_paths     = NULL;
  CeedInt          num_file_paths = 0;
  CeedInt          num_comp;
  CeedBasis_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedCalloc(1, &impl));

  // Copy basis data to GPU
  if (q_weight_1d) {
    CeedCallBackend(magma_malloc((void **)&impl->d_q_weight_1d, Q_1d * sizeof(q_weight_1d[0])));
    magma_setvector(Q_1d, sizeof(q_weight_1d[0]), q_weight_1d, 1, impl->d_q_weight_1d, 1, data->queue);
  }
  CeedCallBackend(magma_malloc((void **)&impl->d_interp_1d, Q_1d * P_1d * sizeof(interp_1d[0])));
  magma_setvector(Q_1d * P_1d, sizeof(interp_1d[0]), interp_1d, 1, impl->d_interp_1d, 1, data->queue);
  CeedCallBackend(magma_malloc((void **)&impl->d_grad_1d, Q_1d * P_1d * sizeof(grad_1d[0])));
  magma_setvector(Q_1d * P_1d, sizeof(grad_1d[0]), grad_1d, 1, impl->d_grad_1d, 1, data->queue);

  // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
  CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));

  // Compile kernels
  CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp));
  {
    // Build "<base>-<dim>d.h"; +6 covers "-Xd.h" plus the NUL for single-digit dim
    char   *interp_kernel_name_base = "ceed/jit-source/magma/magma-basis-interp";
    CeedInt interp_kernel_name_len  = strlen(interp_kernel_name_base) + 6;
    char    interp_kernel_name[interp_kernel_name_len];

    snprintf(interp_kernel_name, interp_kernel_name_len, "%s-%" CeedInt_FMT "d.h", interp_kernel_name_base, dim);
    CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_kernel_name, &interp_kernel_path));
  }
  CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
  CeedCallBackend(CeedLoadSourceAndInitializeBuffer(ceed, interp_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
  {
    // Build "<base>-<dim>d.h"; +6 covers "-Xd.h" plus the NUL for single-digit dim
    char   *grad_kernel_name_base = "ceed/jit-source/magma/magma-basis-grad";
    CeedInt grad_kernel_name_len  = strlen(grad_kernel_name_base) + 6;
    char    grad_kernel_name[grad_kernel_name_len];

    snprintf(grad_kernel_name, grad_kernel_name_len, "%s-%" CeedInt_FMT "d.h", grad_kernel_name_base, dim);
    CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_kernel_name, &grad_kernel_path));
  }
  CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
  {
    // Build "<base>-<dim>d.h"; +6 covers "-Xd.h" plus the NUL for single-digit dim
    char   *weight_kernel_name_base = "ceed/jit-source/magma/magma-basis-weight";
    CeedInt weight_kernel_name_len  = strlen(weight_kernel_name_base) + 6;
    char    weight_kernel_name[weight_kernel_name_len];

    snprintf(weight_kernel_name, weight_kernel_name_len, "%s-%" CeedInt_FMT "d.h", weight_kernel_name_base, dim);
    CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_kernel_name, &weight_kernel_path));
  }
  CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &num_file_paths, &file_paths, &basis_kernel_source));
  CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
  CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module, 5, "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_P",
                                   P_1d, "BASIS_Q", Q_1d, "BASIS_MAX_P_Q", CeedIntMax(P_1d, Q_1d)));
  // Look up the dim-specific kernel entry points in the compiled module
  switch (dim) {
    case 1:
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->Interp));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->InterpTranspose));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_1d_kernel", &impl->InterpTransposeAdd));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->Grad));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->GradTranspose));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_1d_kernel", &impl->GradTransposeAdd));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->Weight));
      break;
    case 2:
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->Interp));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->InterpTranspose));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_2d_kernel", &impl->InterpTransposeAdd));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->Grad));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->GradTranspose));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_2d_kernel", &impl->GradTransposeAdd));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->Weight));
      break;
    case 3:
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->Interp));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->InterpTranspose));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpta_3d_kernel", &impl->InterpTransposeAdd));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->Grad));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->GradTranspose));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradta_3d_kernel", &impl->GradTransposeAdd));
      CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->Weight));
      break;
  }
  CeedCallBackend(CeedFree(&interp_kernel_path));
  CeedCallBackend(CeedFree(&grad_kernel_path));
  CeedCallBackend(CeedFree(&weight_kernel_path));
  CeedCallBackend(CeedFree(&basis_kernel_source));
  for (CeedInt i = 0; i < num_file_paths; i++) CeedCall(CeedFree(&file_paths[i]));
  CeedCall(CeedFree(&file_paths));

  CeedCallBackend(CeedBasisSetData(basis, impl));

  // Register backend functions
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAdd_Magma));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAtPoints", CeedBasisApplyAtPoints_Magma));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create non-tensor H^1
//------------------------------------------------------------------------------
// Copy basis data to the GPU; interp/deriv kernels are compiled lazily on first
// apply, so only the weight kernel is compiled eagerly when the custom-kernel
// path will never run (P or Q above the custom-kernel limits)
int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad,
                            const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
  Ceed                      ceed;
  Ceed_Magma               *data;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedCalloc(1, &impl));

  // Copy basis data to GPU
  if (q_weight) {
    CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
    magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
  }
  if (interp) {
    CeedInt q_comp_interp;

    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
    CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0])));
    magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue);
  }
  if (grad) {
    CeedInt q_comp_grad;

    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_GRAD, &q_comp_grad));
    CeedCallBackend(magma_malloc((void **)&impl->d_grad, num_qpts * num_nodes * q_comp_grad * sizeof(grad[0])));
    magma_setvector(num_qpts * num_nodes * q_comp_grad, sizeof(grad[0]), grad, 1, impl->d_grad, 1, data->queue);
  }

  // Compile the weight kernel if it won't be compiled later on
  if (num_nodes > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P || num_qpts > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) {
    Ceed        ceed_delegate;
    char       *basis_kernel_source;
    const char *weight_kernel_path;

    // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data
    CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate));

    // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply)
    CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path));
    CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n");
    CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source));
    CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n");
    CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[0], 1, "BASIS_Q", num_qpts));
    CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[0], "magma_weight_nontensor", &impl->Weight));
    CeedCallBackend(CeedFree(&weight_kernel_path));
    CeedCallBackend(CeedFree(&basis_kernel_source));
  }

  CeedCallBackend(CeedBasisSetData(basis, impl));

  // Register backend functions
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma));
  CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma));
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create non-tensor H(div)
//------------------------------------------------------------------------------
int CeedBasisCreateHdiv_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp,
                              const CeedScalar *div, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) {
  Ceed                      ceed;
  Ceed_Magma               *data;
  CeedBasisNonTensor_Magma *impl;

  CeedCallBackend(CeedBasisGetCeed(basis, &ceed));
  CeedCallBackend(CeedGetData(ceed, &data));
  CeedCallBackend(CeedCalloc(1, &impl));

  // Copy basis data to GPU
  if (q_weight) {
    CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0])));
    magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue);
  }
  if (interp) {
    CeedInt q_comp_interp;

    CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp));
    CeedCallBackend(magma_malloc((void **)&impl->d_interp,
num_qpts * num_nodes * q_comp_interp * sizeof(interp[0]))); 717 magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue); 718 } 719 if (div) { 720 CeedInt q_comp_div; 721 722 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_DIV, &q_comp_div)); 723 CeedCallBackend(magma_malloc((void **)&impl->d_div, num_qpts * num_nodes * q_comp_div * sizeof(div[0]))); 724 magma_setvector(num_qpts * num_nodes * q_comp_div, sizeof(div[0]), div, 1, impl->d_div, 1, data->queue); 725 } 726 727 // Compile the weight kernel if it won't be compiled later on 728 if (num_nodes > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P || num_qpts > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) { 729 Ceed ceed_delegate; 730 char *basis_kernel_source; 731 const char *weight_kernel_path; 732 733 // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data 734 CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate)); 735 736 // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply) 737 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path)); 738 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n"); 739 CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source)); 740 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! 
-----\n"); 741 CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[0], 1, "BASIS_Q", num_qpts)); 742 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[0], "magma_weight_nontensor", &impl->Weight)); 743 CeedCallBackend(CeedFree(&weight_kernel_path)); 744 CeedCallBackend(CeedFree(&basis_kernel_source)); 745 } 746 747 CeedCallBackend(CeedBasisSetData(basis, impl)); 748 749 // Register backend functions 750 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma)); 751 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma)); 752 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma)); 753 return CEED_ERROR_SUCCESS; 754 } 755 756 //------------------------------------------------------------------------------ 757 // Create non-tensor H(curl) 758 //------------------------------------------------------------------------------ 759 int CeedBasisCreateHcurl_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, 760 const CeedScalar *curl, const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) { 761 Ceed ceed; 762 Ceed_Magma *data; 763 CeedBasisNonTensor_Magma *impl; 764 765 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 766 CeedCallBackend(CeedGetData(ceed, &data)); 767 CeedCallBackend(CeedCalloc(1, &impl)); 768 769 // Copy basis data to GPU 770 if (q_weight) { 771 CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0]))); 772 magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue); 773 } 774 if (interp) { 775 CeedInt q_comp_interp; 776 777 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_INTERP, &q_comp_interp)); 778 CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * q_comp_interp * sizeof(interp[0]))); 779 
magma_setvector(num_qpts * num_nodes * q_comp_interp, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue); 780 } 781 if (curl) { 782 CeedInt q_comp_curl; 783 784 CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, CEED_EVAL_CURL, &q_comp_curl)); 785 CeedCallBackend(magma_malloc((void **)&impl->d_curl, num_qpts * num_nodes * q_comp_curl * sizeof(curl[0]))); 786 magma_setvector(num_qpts * num_nodes * q_comp_curl, sizeof(curl[0]), curl, 1, impl->d_curl, 1, data->queue); 787 } 788 789 // Compile the weight kernel if it won't be compiled later on 790 if (num_nodes > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P || num_qpts > MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) { 791 Ceed ceed_delegate; 792 char *basis_kernel_source; 793 const char *weight_kernel_path; 794 795 // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data 796 CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate)); 797 798 // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply) 799 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path)); 800 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n"); 801 CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source)); 802 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! 
-----\n"); 803 CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module[0], 1, "BASIS_Q", num_qpts)); 804 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module[0], "magma_weight_nontensor", &impl->Weight)); 805 CeedCallBackend(CeedFree(&weight_kernel_path)); 806 CeedCallBackend(CeedFree(&basis_kernel_source)); 807 } 808 809 CeedCallBackend(CeedBasisSetData(basis, impl)); 810 811 // Register backend functions 812 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma)); 813 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "ApplyAdd", CeedBasisApplyAddNonTensor_Magma)); 814 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma)); 815 return CEED_ERROR_SUCCESS; 816 } 817 818 //------------------------------------------------------------------------------ 819