1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <ceed/jit-tools.h> 11 #include <string.h> 12 13 #ifdef CEED_MAGMA_USE_HIP 14 #include "../hip/ceed-hip-common.h" 15 #include "../hip/ceed-hip-compile.h" 16 #else 17 #include "../cuda/ceed-cuda-common.h" 18 #include "../cuda/ceed-cuda-compile.h" 19 #endif 20 #include "ceed-magma-common.h" 21 #include "ceed-magma.h" 22 23 #include "ceed-magma-gemm-nontensor.h" 24 #include "ceed-magma-gemm-selector.h" 25 26 //------------------------------------------------------------------------------ 27 // Basis apply - tensor 28 //------------------------------------------------------------------------------ 29 static int CeedBasisApply_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, CeedVector v) { 30 Ceed ceed; 31 Ceed_Magma *data; 32 CeedInt dim, num_comp, num_nodes, P_1d, Q_1d, P, Q; 33 const CeedScalar *d_u; 34 CeedScalar *d_v; 35 CeedBasis_Magma *impl; 36 37 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 38 CeedCallBackend(CeedGetData(ceed, &data)); 39 CeedCallBackend(CeedBasisGetData(basis, &impl)); 40 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 41 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 42 CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); 43 CeedCallBackend(CeedBasisGetNumNodes1D(basis, &P_1d)); 44 CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d)); 45 P = P_1d; 46 Q = Q_1d; 47 if (t_mode == CEED_TRANSPOSE) { 48 P = Q_1d; 49 Q = P_1d; 50 } 51 52 // Read vectors 53 if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 54 else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); 55 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 56 57 // Clear v for transpose operation 58 if (t_mode == CEED_TRANSPOSE) { 59 CeedSize length; 60 61 CeedCallBackend(CeedVectorGetLength(v, &length)); 62 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 63 magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue); 64 } else { 65 magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue); 66 } 67 ceed_magma_queue_sync(data->queue); 68 } 69 70 // Apply basis operation 71 switch (e_mode) { 72 case CEED_EVAL_INTERP: { 73 // Define element sizes for dofs/quad 74 CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim); 75 CeedInt elem_dofs_size = CeedIntPow(P_1d, dim); 76 77 // E-vector ordering -------------- Q-vector ordering 78 // component component 79 // elem elem 80 // node node 81 82 // --- Define strides for NOTRANSPOSE mode: --- 83 // Input (d_u) is E-vector, output (d_v) is Q-vector 84 85 // Element strides 86 CeedInt u_elem_stride = elem_dofs_size; 87 CeedInt v_elem_stride = elem_qpts_size; 88 // Component strides 89 CeedInt u_comp_stride = num_elem * elem_dofs_size; 90 CeedInt v_comp_stride = num_elem * elem_qpts_size; 91 if (t_mode == CEED_TRANSPOSE) { 92 // Input (d_u) is Q-vector, output (d_v) is E-vector 93 // Element strides 94 v_elem_stride = elem_dofs_size; 95 u_elem_stride = elem_qpts_size; 96 // Component strides 97 v_comp_stride = num_elem * elem_dofs_size; 98 u_comp_stride = num_elem * elem_qpts_size; 99 } 100 CeedInt num_threads = 1; 101 CeedInt num_t_col = 1; 102 CeedInt shared_mem = 0; 103 CeedInt max_P_Q = CeedIntMax(P, Q); 104 105 switch (dim) { 106 case 1: 107 num_threads = max_P_Q; 108 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D); 109 shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q)); 110 shared_mem += sizeof(CeedScalar) * (P * Q); 111 break; 112 case 2: 113 num_threads = max_P_Q; 114 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D); 115 shared_mem += P * Q * sizeof(CeedScalar); // for sT 116 // for reforming rU we need P x P, and for the intermediate output we need P x Q 117 shared_mem += num_t_col * (P * max_P_Q * sizeof(CeedScalar)); 118 break; 119 case 3: 120 num_threads = max_P_Q * max_P_Q; 121 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D); 122 shared_mem += sizeof(CeedScalar) * (P * Q); // for sT 123 // rU needs P^2 x P, the intermediate output needs max(P^2 x Q, P x Q^2) 124 shared_mem += sizeof(CeedScalar) * num_t_col * (CeedIntMax(P * P * max_P_Q, P * Q * Q)); 125 break; 126 } 127 CeedInt grid = CeedDivUpInt(num_elem, num_t_col); 128 void *args[] = {&impl->d_interp_1d, &d_u, &u_elem_stride, &u_comp_stride, &d_v, &v_elem_stride, &v_comp_stride, &num_elem}; 129 130 if (t_mode == CEED_TRANSPOSE) { 131 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->InterpTranspose, grid, num_threads, num_t_col, 1, shared_mem, args)); 132 } else { 133 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp, grid, num_threads, num_t_col, 1, shared_mem, args)); 134 } 135 } break; 136 case CEED_EVAL_GRAD: { 137 // Define element sizes for dofs/quad 138 CeedInt elem_qpts_size = CeedIntPow(Q_1d, dim); 139 CeedInt elem_dofs_size = CeedIntPow(P_1d, dim); 140 141 // In CEED_NOTRANSPOSE mode: 142 // d_u is (P^dim x nc), column-major layout (nc = num_comp) 143 // d_v is (Q^dim x nc x dim), column-major layout (nc = num_comp) 144 // In CEED_TRANSPOSE mode, the sizes of d_u and d_v are switched. 145 146 // E-vector ordering -------------- Q-vector ordering 147 // dim 148 // component component 149 // elem elem 150 // node node 151 152 // --- Define strides for NOTRANSPOSE mode: --- 153 // Input (d_u) is E-vector, output (d_v) is Q-vector 154 155 // Element strides 156 CeedInt u_elem_stride = elem_dofs_size; 157 CeedInt v_elem_stride = elem_qpts_size; 158 // Component strides 159 CeedInt u_comp_stride = num_elem * elem_dofs_size; 160 CeedInt v_comp_stride = num_elem * elem_qpts_size; 161 // Dimension strides 162 CeedInt u_dim_stride = 0; 163 CeedInt v_dim_stride = num_elem * elem_qpts_size * num_comp; 164 if (t_mode == CEED_TRANSPOSE) { 165 // Input (d_u) is Q-vector, output (d_v) is E-vector 166 // Element strides 167 v_elem_stride = elem_dofs_size; 168 u_elem_stride = elem_qpts_size; 169 // Component strides 170 v_comp_stride = num_elem * elem_dofs_size; 171 u_comp_stride = num_elem * elem_qpts_size; 172 // Dimension strides 173 v_dim_stride = 0; 174 u_dim_stride = num_elem * elem_qpts_size * num_comp; 175 } 176 CeedInt num_threads = 1; 177 CeedInt num_t_col = 1; 178 CeedInt shared_mem = 0; 179 CeedInt max_P_Q = CeedIntMax(P, Q); 180 181 switch (dim) { 182 case 1: 183 num_threads = max_P_Q; 184 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D); 185 shared_mem += sizeof(CeedScalar) * num_t_col * (num_comp * (1 * P + 1 * Q)); 186 shared_mem += sizeof(CeedScalar) * (P * Q); 187 break; 188 case 2: 189 num_threads = max_P_Q; 190 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D); 191 shared_mem += sizeof(CeedScalar) * 2 * P * Q; // for sTinterp and sTgrad 192 // for reforming rU we need P x P, and for the intermediate output we need P x Q 193 shared_mem += sizeof(CeedScalar) * num_t_col * (P * max_P_Q); 194 break; 195 case 3: 196 num_threads = max_P_Q * max_P_Q; 197 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D); 198 shared_mem += sizeof(CeedScalar) * 2 * P * Q; // for sTinterp and sTgrad 199 // rU needs P^2 x P, the intermediate outputs need (P^2 x Q + P x Q^2) 200 shared_mem += sizeof(CeedScalar) * num_t_col * CeedIntMax(P * P * P, (P * P * Q) + (P * Q * Q)); 201 break; 202 } 203 CeedInt grid = CeedDivUpInt(num_elem, num_t_col); 204 void *args[] = {&impl->d_interp_1d, &impl->d_grad_1d, &d_u, &u_elem_stride, &u_comp_stride, &u_dim_stride, &d_v, 205 &v_elem_stride, &v_comp_stride, &v_dim_stride, &num_elem}; 206 207 if (t_mode == CEED_TRANSPOSE) { 208 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->GradTranspose, grid, num_threads, num_t_col, 1, shared_mem, args)); 209 } else { 210 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad, grid, num_threads, num_t_col, 1, shared_mem, args)); 211 } 212 } break; 213 case CEED_EVAL_WEIGHT: { 214 CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE"); 215 CeedInt elem_dofs_size = CeedIntPow(Q, dim); 216 CeedInt num_threads = 1; 217 CeedInt num_t_col = 1; 218 CeedInt shared_mem = 0; 219 220 switch (dim) { 221 case 1: 222 num_threads = Q; 223 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_1D); 224 shared_mem += sizeof(CeedScalar) * Q; // for d_q_weight_1d 225 shared_mem += sizeof(CeedScalar) * num_t_col * Q; // for output 226 break; 227 case 2: 228 num_threads = Q; 229 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_2D); 230 shared_mem += sizeof(CeedScalar) * Q; // for d_q_weight_1d 231 break; 232 case 3: 233 num_threads = Q * Q; 234 num_t_col = MAGMA_BASIS_NTCOL(num_threads, MAGMA_MAXTHREADS_3D); 235 shared_mem += sizeof(CeedScalar) * Q; // for d_q_weight_1d 236 break; 237 } 238 CeedInt grid = CeedDivUpInt(num_elem, num_t_col); 239 void *args[] = {&impl->d_q_weight_1d, &d_v, &elem_dofs_size, &num_elem}; 240 241 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, num_threads, num_t_col, 1, shared_mem, args)); 242 } break; 243 // LCOV_EXCL_START 244 case CEED_EVAL_DIV: 245 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 246 case CEED_EVAL_CURL: 247 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 248 case CEED_EVAL_NONE: 249 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_NONE does not make sense in this context"); 250 // LCOV_EXCL_STOP 251 } 252 253 // Must sync to ensure completeness 254 ceed_magma_queue_sync(data->queue); 255 256 // Restore vectors 257 if (e_mode != CEED_EVAL_WEIGHT) { 258 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 259 } 260 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 261 return CEED_ERROR_SUCCESS; 262 } 263 264 //------------------------------------------------------------------------------ 265 // Basis apply - non-tensor 266 //------------------------------------------------------------------------------ 267 static int CeedBasisApplyNonTensor_Magma(CeedBasis basis, CeedInt num_elem, CeedTransposeMode t_mode, CeedEvalMode e_mode, CeedVector u, 268 CeedVector v) { 269 Ceed ceed; 270 Ceed_Magma *data; 271 CeedInt dim, num_comp, num_nodes, num_qpts, P, Q, N; 272 const CeedScalar *d_u; 273 CeedScalar *d_v; 274 CeedBasisNonTensor_Magma *impl; 275 276 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 277 CeedCallBackend(CeedGetData(ceed, &data)); 278 CeedCallBackend(CeedBasisGetData(basis, &impl)); 279 CeedCallBackend(CeedBasisGetDimension(basis, &dim)); 280 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 281 CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); 282 CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts)); 283 P = num_nodes; 284 Q = num_qpts; 285 N = num_elem * num_comp; 286 287 // Read vectors 288 if (u != CEED_VECTOR_NONE) CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u)); 289 else CeedCheck(e_mode == CEED_EVAL_WEIGHT, ceed, CEED_ERROR_BACKEND, "An input vector is required for this CeedEvalMode"); 290 CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v)); 291 292 // Clear v for transpose operation 293 if (t_mode == CEED_TRANSPOSE) { 294 CeedSize length; 295 296 CeedCallBackend(CeedVectorGetLength(v, &length)); 297 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 298 magmablas_slaset(MagmaFull, length, 1, 0.0, 0.0, (float *)d_v, length, data->queue); 299 } else { 300 magmablas_dlaset(MagmaFull, length, 1, 0.0, 0.0, (double *)d_v, length, data->queue); 301 } 302 ceed_magma_queue_sync(data->queue); 303 } 304 305 // Apply basis operation 306 if (e_mode != CEED_EVAL_WEIGHT) { 307 if (P < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_P && Q < MAGMA_NONTENSOR_CUSTOM_KERNEL_MAX_Q) { 308 CeedInt n_array[MAGMA_NONTENSOR_KERNEL_INSTANCES] = {MAGMA_NONTENSOR_KERNEL_N_VALUES}; 309 CeedInt iN = 0, diff = abs(n_array[iN] - N), idiff; 310 CeedInt M = (t_mode == CEED_TRANSPOSE) ? P : Q, K = (t_mode == CEED_TRANSPOSE) ? Q : P; 311 312 for (CeedInt in = iN + 1; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) { 313 idiff = abs(n_array[in] - N); 314 if (idiff < diff) { 315 iN = in; 316 diff = idiff; 317 } 318 } 319 320 // Compile kernels for N as needed 321 if (!impl->NB_interp[iN]) { 322 Ceed ceed_delegate; 323 char *interp_kernel_path, *grad_kernel_path, *basis_kernel_source; 324 magma_int_t arch = magma_getdevice_arch(); 325 326 // Tuning parameters for NB 327 impl->NB_interp[iN] = nontensor_rtc_get_nb(arch, 'n', 1, P, Q, n_array[iN]); 328 impl->NB_interp_t[iN] = nontensor_rtc_get_nb(arch, 't', 1, P, Q, n_array[iN]); 329 impl->NB_grad[iN] = nontensor_rtc_get_nb(arch, 'n', dim, P, Q, n_array[iN]); 330 impl->NB_grad_t[iN] = nontensor_rtc_get_nb(arch, 't', dim, P, Q, n_array[iN]); 331 332 // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data 333 CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate)); 334 335 // Compile kernels 336 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-interp-nontensor.h", &interp_kernel_path)); 337 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n"); 338 CeedCallBackend(CeedLoadSourceToBuffer(ceed, interp_kernel_path, &basis_kernel_source)); 339 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-grad-nontensor.h", &grad_kernel_path)); 340 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &basis_kernel_source)); 341 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n"); 342 CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_interp[iN], 7, "BASIS_DIM", dim, "BASIS_P", P, "BASIS_Q", 343 Q, "BASIS_NB_INTERP_N", impl->NB_interp[iN], "BASIS_NB_INTERP_T", impl->NB_interp_t[iN], "BASIS_NB_GRAD_N", 344 impl->NB_grad[iN], "BASIS_NB_GRAD_T", impl->NB_grad_t[iN])); 345 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_interp_nontensor_n", &impl->Interp[iN])); 346 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_interp_nontensor_t", &impl->InterpTranspose[iN])); 347 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_grad_nontensor_n", &impl->Grad[iN])); 348 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_interp[iN], "magma_grad_nontensor_t", &impl->GradTranspose[iN])); 349 CeedCallBackend(CeedFree(&interp_kernel_path)); 350 CeedCallBackend(CeedFree(&grad_kernel_path)); 351 CeedCallBackend(CeedFree(&basis_kernel_source)); 352 } 353 354 // Apply basis operation 355 CeedInt num_t_col = MAGMA_BASIS_NTCOL(M, MAGMA_MAXTHREADS_1D); 356 if (e_mode == CEED_EVAL_INTERP) { 357 CeedInt NB = (t_mode == CEED_TRANSPOSE) ? impl->NB_interp_t[iN] : impl->NB_interp[iN]; 358 CeedInt grid = CeedDivUpInt(N, NB * num_t_col); 359 CeedInt shared_mem_A = (t_mode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar); 360 CeedInt shared_mem_B = num_t_col * K * NB * sizeof(CeedScalar); 361 CeedInt shared_mem = (t_mode == CEED_TRANSPOSE) ? (shared_mem_A + shared_mem_B) : CeedIntMax(shared_mem_A, shared_mem_B); 362 void *args[] = {&N, &impl->d_interp, &P, &d_u, &K, &d_v, &M}; 363 364 if (t_mode == CEED_TRANSPOSE) { 365 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->InterpTranspose[iN], grid, M, num_t_col, 1, shared_mem, args)); 366 } else { 367 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Interp[iN], grid, M, num_t_col, 1, shared_mem, args)); 368 } 369 } else if (e_mode == CEED_EVAL_GRAD) { 370 CeedInt NB = (t_mode == CEED_TRANSPOSE) ? impl->NB_grad_t[iN] : impl->NB_grad[iN]; 371 CeedInt grid = CeedDivUpInt(N, NB * num_t_col); 372 CeedInt shared_mem = num_t_col * K * NB * sizeof(CeedScalar) + ((t_mode == CEED_TRANSPOSE) ? 0 : K * M * sizeof(CeedScalar)); 373 void *args[] = {&N, &impl->d_grad, &P, &d_u, &K, &d_v, &M}; 374 375 if (t_mode == CEED_TRANSPOSE) { 376 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->GradTranspose[iN], grid, M, num_t_col, 1, shared_mem, args)); 377 } else { 378 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Grad[iN], grid, M, num_t_col, 1, shared_mem, args)); 379 } 380 } else { 381 // LCOV_EXCL_START 382 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV, CEED_EVAL_CURL not supported"); 383 // LCOV_EXCL_STOP 384 } 385 } else { 386 if (e_mode == CEED_EVAL_INTERP) { 387 if (t_mode == CEED_TRANSPOSE) { 388 magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, impl->d_interp, P, d_u, Q, 0.0, d_v, P, data->queue); 389 } else { 390 magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, impl->d_interp, P, d_u, P, 0.0, d_v, Q, data->queue); 391 } 392 } else if (e_mode == CEED_EVAL_GRAD) { 393 if (t_mode == CEED_TRANSPOSE) { 394 for (int d = 0; d < dim; d++) { 395 const CeedScalar beta = (d > 0) ? 1.0 : 0.0; 396 magma_gemm_nontensor(MagmaNoTrans, MagmaNoTrans, P, N, Q, 1.0, impl->d_grad + d * P * Q, P, d_u + d * N * Q, Q, beta, d_v, P, 397 data->queue); 398 } 399 } else { 400 for (int d = 0; d < dim; d++) { 401 magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, N, P, 1.0, impl->d_grad + d * P * Q, P, d_u, P, 0.0, d_v + d * N * Q, Q, data->queue); 402 } 403 } 404 } else { 405 // LCOV_EXCL_START 406 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV, CEED_EVAL_CURL not supported"); 407 // LCOV_EXCL_STOP 408 } 409 } 410 } else { 411 CeedCheck(t_mode != CEED_TRANSPOSE, ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT incompatible with CEED_TRANSPOSE"); 412 CeedInt num_t_col = MAGMA_BASIS_NTCOL(Q, MAGMA_MAXTHREADS_1D); 413 CeedInt grid = CeedDivUpInt(num_elem, num_t_col); 414 CeedInt shared_mem = Q * sizeof(CeedScalar) + num_t_col * Q * sizeof(CeedScalar); 415 void *args[] = {&num_elem, &impl->d_q_weight, &d_v, &Q}; 416 417 CeedCallBackend(CeedRunKernelDimSharedMagma(ceed, impl->Weight, grid, Q, num_t_col, 1, shared_mem, args)); 418 } 419 420 // Must sync to ensure completeness 421 ceed_magma_queue_sync(data->queue); 422 423 // Restore vectors 424 if (e_mode != CEED_EVAL_WEIGHT) { 425 CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u)); 426 } 427 CeedCallBackend(CeedVectorRestoreArray(v, &d_v)); 428 return CEED_ERROR_SUCCESS; 429 } 430 431 //------------------------------------------------------------------------------ 432 // Destroy tensor basis 433 //------------------------------------------------------------------------------ 434 static int CeedBasisDestroy_Magma(CeedBasis basis) { 435 Ceed ceed; 436 CeedBasis_Magma *impl; 437 438 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 439 CeedCallBackend(CeedBasisGetData(basis, &impl)); 440 #ifdef CEED_MAGMA_USE_HIP 441 CeedCallHip(ceed, hipModuleUnload(impl->module)); 442 #else 443 CeedCallCuda(ceed, cuModuleUnload(impl->module)); 444 #endif 445 CeedCallBackend(magma_free(impl->d_interp_1d)); 446 CeedCallBackend(magma_free(impl->d_grad_1d)); 447 CeedCallBackend(magma_free(impl->d_q_weight_1d)); 448 CeedCallBackend(CeedFree(&impl)); 449 return CEED_ERROR_SUCCESS; 450 } 451 452 //------------------------------------------------------------------------------ 453 // Destroy non-tensor basis 454 //------------------------------------------------------------------------------ 455 static int CeedBasisDestroyNonTensor_Magma(CeedBasis basis) { 456 Ceed ceed; 457 CeedBasisNonTensor_Magma *impl; 458 459 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 460 CeedCallBackend(CeedBasisGetData(basis, &impl)); 461 #ifdef CEED_MAGMA_USE_HIP 462 CeedCallHip(ceed, hipModuleUnload(impl->module_weight)); 463 #else 464 CeedCallCuda(ceed, cuModuleUnload(impl->module_weight)); 465 #endif 466 for (CeedInt in = 0; in < MAGMA_NONTENSOR_KERNEL_INSTANCES; in++) { 467 if (impl->module_interp[in]) { 468 #ifdef CEED_MAGMA_USE_HIP 469 CeedCallHip(ceed, hipModuleUnload(impl->module_interp[in])); 470 #else 471 CeedCallCuda(ceed, cuModuleUnload(impl->module_interp[in])); 472 #endif 473 } 474 } 475 CeedCallBackend(magma_free(impl->d_interp)); 476 CeedCallBackend(magma_free(impl->d_grad)); 477 CeedCallBackend(magma_free(impl->d_q_weight)); 478 CeedCallBackend(CeedFree(&impl)); 479 return CEED_ERROR_SUCCESS; 480 } 481 482 //------------------------------------------------------------------------------ 483 // Create tensor 484 //------------------------------------------------------------------------------ 485 int CeedBasisCreateTensorH1_Magma(CeedInt dim, CeedInt P_1d, CeedInt Q_1d, const CeedScalar *interp_1d, const CeedScalar *grad_1d, 486 const CeedScalar *q_ref_1d, const CeedScalar *q_weight_1d, CeedBasis basis) { 487 Ceed ceed, ceed_delegate; 488 Ceed_Magma *data; 489 char *interp_kernel_path, *grad_kernel_path, *weight_kernel_path, *basis_kernel_source; 490 CeedInt num_comp; 491 CeedBasis_Magma *impl; 492 493 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 494 CeedCallBackend(CeedGetData(ceed, &data)); 495 CeedCallBackend(CeedCalloc(1, &impl)); 496 497 // Copy basis data to GPU 498 CeedCallBackend(magma_malloc((void **)&impl->d_q_weight_1d, Q_1d * sizeof(q_weight_1d[0]))); 499 magma_setvector(Q_1d, sizeof(q_weight_1d[0]), q_weight_1d, 1, impl->d_q_weight_1d, 1, data->queue); 500 CeedCallBackend(magma_malloc((void **)&impl->d_interp_1d, Q_1d * P_1d * sizeof(interp_1d[0]))); 501 magma_setvector(Q_1d * P_1d, sizeof(interp_1d[0]), interp_1d, 1, impl->d_interp_1d, 1, data->queue); 502 CeedCallBackend(magma_malloc((void **)&impl->d_grad_1d, Q_1d * P_1d * sizeof(grad_1d[0]))); 503 magma_setvector(Q_1d * P_1d, sizeof(grad_1d[0]), grad_1d, 1, impl->d_grad_1d, 1, data->queue); 504 505 // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data 506 CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate)); 507 508 // Compile kernels 509 CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); 510 { 511 char *interp_kernel_name_base = "ceed/jit-source/magma/magma-basis-interp"; 512 CeedInt interp_kernel_name_len = strlen(interp_kernel_name_base) + 6; 513 char interp_kernel_name[interp_kernel_name_len]; 514 515 snprintf(interp_kernel_name, interp_kernel_name_len, "%s-%" CeedInt_FMT "d.h", interp_kernel_name_base, dim); 516 CeedCallBackend(CeedGetJitAbsolutePath(ceed, interp_kernel_name, &interp_kernel_path)); 517 } 518 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n"); 519 CeedCallBackend(CeedLoadSourceToBuffer(ceed, interp_kernel_path, &basis_kernel_source)); 520 { 521 char *grad_kernel_name_base = "ceed/jit-source/magma/magma-basis-grad"; 522 CeedInt grad_kernel_name_len = strlen(grad_kernel_name_base) + 6; 523 char grad_kernel_name[grad_kernel_name_len]; 524 525 snprintf(grad_kernel_name, grad_kernel_name_len, "%s-%" CeedInt_FMT "d.h", grad_kernel_name_base, dim); 526 CeedCallBackend(CeedGetJitAbsolutePath(ceed, grad_kernel_name, &grad_kernel_path)); 527 } 528 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, grad_kernel_path, &basis_kernel_source)); 529 { 530 char *weight_kernel_name_base = "ceed/jit-source/magma/magma-basis-weight"; 531 CeedInt weight_kernel_name_len = strlen(weight_kernel_name_base) + 6; 532 char weight_kernel_name[weight_kernel_name_len]; 533 534 snprintf(weight_kernel_name, weight_kernel_name_len, "%s-%" CeedInt_FMT "d.h", weight_kernel_name_base, dim); 535 CeedCallBackend(CeedGetJitAbsolutePath(ceed, weight_kernel_name, &weight_kernel_path)); 536 } 537 CeedCallBackend(CeedLoadSourceToInitializedBuffer(ceed, weight_kernel_path, &basis_kernel_source)); 538 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n"); 539 CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module, 5, "BASIS_DIM", dim, "BASIS_NUM_COMP", num_comp, "BASIS_P", 540 P_1d, "BASIS_Q", Q_1d, "BASIS_MAX_P_Q", CeedIntMax(P_1d, Q_1d))); 541 switch (dim) { 542 case 1: 543 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_1d_kernel", &impl->Interp)); 544 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_1d_kernel", &impl->InterpTranspose)); 545 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_1d_kernel", &impl->Grad)); 546 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_1d_kernel", &impl->GradTranspose)); 547 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_1d_kernel", &impl->Weight)); 548 break; 549 case 2: 550 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_2d_kernel", &impl->Interp)); 551 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_2d_kernel", &impl->InterpTranspose)); 552 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_2d_kernel", &impl->Grad)); 553 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_2d_kernel", &impl->GradTranspose)); 554 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_2d_kernel", &impl->Weight)); 555 break; 556 case 3: 557 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpn_3d_kernel", &impl->Interp)); 558 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_interpt_3d_kernel", &impl->InterpTranspose)); 559 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradn_3d_kernel", &impl->Grad)); 560 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_gradt_3d_kernel", &impl->GradTranspose)); 561 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module, "magma_weight_3d_kernel", &impl->Weight)); 562 break; 563 } 564 CeedCallBackend(CeedFree(&interp_kernel_path)); 565 CeedCallBackend(CeedFree(&grad_kernel_path)); 566 CeedCallBackend(CeedFree(&weight_kernel_path)); 567 CeedCallBackend(CeedFree(&basis_kernel_source)); 568 569 CeedCallBackend(CeedBasisSetData(basis, impl)); 570 571 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApply_Magma)); 572 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroy_Magma)); 573 return CEED_ERROR_SUCCESS; 574 } 575 576 //------------------------------------------------------------------------------ 577 // Create non-tensor H^1 578 //------------------------------------------------------------------------------ 579 int CeedBasisCreateH1_Magma(CeedElemTopology topo, CeedInt dim, CeedInt num_nodes, CeedInt num_qpts, const CeedScalar *interp, const CeedScalar *grad, 580 const CeedScalar *q_ref, const CeedScalar *q_weight, CeedBasis basis) { 581 Ceed ceed, ceed_delegate; 582 Ceed_Magma *data; 583 char *weight_kernel_path, *basis_kernel_source; 584 CeedBasisNonTensor_Magma *impl; 585 586 CeedCallBackend(CeedBasisGetCeed(basis, &ceed)); 587 CeedCallBackend(CeedGetData(ceed, &data)); 588 CeedCallBackend(CeedCalloc(1, &impl)); 589 590 // Copy basis data to GPU 591 CeedCallBackend(magma_malloc((void **)&impl->d_q_weight, num_qpts * sizeof(q_weight[0]))); 592 magma_setvector(num_qpts, sizeof(q_weight[0]), q_weight, 1, impl->d_q_weight, 1, data->queue); 593 CeedCallBackend(magma_malloc((void **)&impl->d_interp, num_qpts * num_nodes * sizeof(interp[0]))); 594 magma_setvector(num_qpts * num_nodes, sizeof(interp[0]), interp, 1, impl->d_interp, 1, data->queue); 595 CeedCallBackend(magma_malloc((void **)&impl->d_grad, num_qpts * num_nodes * dim * sizeof(grad[0]))); 596 magma_setvector(num_qpts * num_nodes * dim, sizeof(grad[0]), grad, 1, impl->d_grad, 1, data->queue); 597 598 // The RTC compilation code expects a Ceed with the common Ceed_Cuda or Ceed_Hip data 599 CeedCallBackend(CeedGetDelegate(ceed, &ceed_delegate)); 600 601 // Compile weight kernel (the remainder of kernel compilation happens at first call to CeedBasisApply) 602 CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/magma/magma-basis-weight-nontensor.h", &weight_kernel_path)); 603 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source -----\n"); 604 CeedCallBackend(CeedLoadSourceToBuffer(ceed, weight_kernel_path, &basis_kernel_source)); 605 CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Basis Kernel Source Complete! -----\n"); 606 CeedCallBackend(CeedCompileMagma(ceed_delegate, basis_kernel_source, &impl->module_weight, 1, "BASIS_Q", num_qpts)); 607 CeedCallBackend(CeedGetKernelMagma(ceed, impl->module_weight, "magma_weight_nontensor", &impl->Weight)); 608 CeedCallBackend(CeedFree(&weight_kernel_path)); 609 CeedCallBackend(CeedFree(&basis_kernel_source)); 610 611 CeedCallBackend(CeedBasisSetData(basis, impl)); 612 613 // Register backend functions 614 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Apply", CeedBasisApplyNonTensor_Magma)); 615 CeedCallBackend(CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", CeedBasisDestroyNonTensor_Magma)); 616 return CEED_ERROR_SUCCESS; 617 } 618 619 //------------------------------------------------------------------------------ 620