1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3 // All Rights reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative. 
#include <ceed.h>
#include <ceed-backend.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stddef.h>
#include "ceed-cuda-shared.h"
#include "../cuda/ceed-cuda.h"

//------------------------------------------------------------------------------
// Shared mem kernels
//------------------------------------------------------------------------------
// NOTE(review): the string below is compiled at runtime; the macros P1D, Q1D,
// T1D, and BASIS_NCOMP it uses are presumably injected by the backend's kernel
// build path — confirm against the JIT compilation code.
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

//------------------------------------------------------------------------------
// Sum input into output
//------------------------------------------------------------------------------
// Element-wise accumulate the first P1D entries of r_U into r_V.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_V[i] += r_U[i];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load the P1D nodal values of (elem, comp) into this thread's tidz slab of
// shared memory, zero-padding entries [P1D, Q1D).  The store index does not
// involve tidx/tidy, so every thread of the slab writes identical values to
// the same locations (a benign redundant write).
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store one nodal value per thread for (elem, comp); only threads with
// tidx < P1D participate.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D)
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load the Q1D quadrature values of (elem, comp, dim) into this thread's tidz
// slab of shared memory, zero-padding entries [Q1D, P1D).  As in readDofs1d,
// the index does not involve tidx/tidy, so the writes are redundant but benign.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                              dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store one quadrature value per thread for (elem, comp, dim); only threads
// with tidx < Q1D participate.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D)
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 1D tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[tidx, i] * slice[i]; B is stored row-major with row stride P1D.
// The U argument is unused — the input comes from the shared-memory slice.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D transpose tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B^T[tidx, i] * slice[i] (B indexed column-wise).  The U argument
// is unused — the input comes from the shared-memory slice.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
//------------------------------------------------------------------------------
// Grid-stride over elements in z; one component at a time.  r_t is only ever
// passed as the unused U argument of the contraction helpers.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
//------------------------------------------------------------------------------
// Same structure as interp1d but contracting with the gradient matrix c_G.
// r_U is only ever passed as the unused U argument of the contraction helpers.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D Quadrature weights
//------------------------------------------------------------------------------
// One weight per x-thread; elements are grid-strided over blocks and y-threads.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load this thread's nodal value of (elem, comp); out-of-range threads get 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
      d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store this thread's nodal value of (elem, comp); only in-range threads write.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D)
    d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load this thread's quadrature value of (elem, comp, dim); out-of-range
// threads get 0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
      d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store this thread's quadrature value of (elem, comp, dim); only in-range
// threads write.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 2D tensor contraction x
//------------------------------------------------------------------------------
// Stage U in shared memory, then contract along x.  Both __syncthreads()
// calls are outside the divergent guard, so all threads reach them.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D tensor contraction y
//------------------------------------------------------------------------------
// Stage U in shared memory, then contract along y.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
//------------------------------------------------------------------------------
// Stage U in shared memory, then contract along y with B transposed.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
//------------------------------------------------------------------------------
// Stage U in shared memory, then contract along x with B transposed.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
//------------------------------------------------------------------------------
// Each z-slab handles one (element, component) pair: tidz encodes both the
// element within the block and the component.
// Fix: removed a duplicate `const int comp` declaration inside the loop that
// shadowed the identical function-scope declaration.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
//------------------------------------------------------------------------------
// Forward: apply (G ⊗ B) and (B ⊗ G), writing one output per dim.
// Transpose: read both dim slices, apply the transposed contractions, and sum.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D quadrature weights
//------------------------------------------------------------------------------
// Tensor-product weight per (x, y) thread; elements grid-strided over z.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load a z-column of P1D nodal values into registers, zero-padding [P1D, Q1D);
// out-of-range (x, y) threads get all zeros.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
             d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                 comp*P1D*P1D*P1D*nelem] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store a z-column of P1D nodal values; only in-range (x, y) threads write.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
          comp*P1D*P1D*P1D*nelem] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load a z-column of Q1D quadrature values into registers, zero-padding
// [Q1D, P1D); out-of-range (x, y) threads get all zeros.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
             d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                 comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store a z-column of Q1D quadrature values; only in-range (x, y) threads write.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx < Q1D && tidy < Q1D) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract x
//------------------------------------------------------------------------------
// For each z-plane k: stage U[k] in shared memory and contract along x.
// The barriers sit inside a loop with a uniform trip count (P1D) and outside
// the divergent guard, so all threads reach each one.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract y
//------------------------------------------------------------------------------
// For each z-plane k: stage U[k] in shared memory and contract along y.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract z
//------------------------------------------------------------------------------
// Contract along z entirely in registers (the z-column lives in U[]), so no
// shared memory or barrier is needed.  Tail entries [Q1D, P1D) are zeroed.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Register-only transposed z-contraction; tail entries [P1D, Q1D) are zeroed.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// For each z-plane k: stage U[k] in shared memory and contract along y with
// B transposed.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// For each z-plane k: stage U[k] in shared memory and contract along x with
// B transposed.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// Each z-slab handles one (element, component) pair; the z-column of values is
// kept in registers (r_V/r_t) and ping-ponged through the contractions.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// Forward: apply (G ⊗ B ⊗ B), (B ⊗ G ⊗ B), (B ⊗ B ⊗ G), one output per dim.
// Transpose: accumulate the three transposed contractions via add().
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
//------------------------------------------------------------------------------
// Tensor-product weight per (x, y, z) thread; one element per block,
// grid-strided over blocks.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}


//------------------------------------------------------------------------------
// Basis kernels
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp kernel by dim
//------------------------------------------------------------------------------
// Entry point for interpolation.  BASIS_DIM is a compile-time macro supplied
// when this source string is JIT-compiled, so only one branch survives.
// slice is the dynamically-sized shared-memory scratch buffer; its byte size
// is chosen by the host launcher.
// NOTE(review): slice is declared double while the host sizes it in
// sizeof(CeedScalar) -- consistent only if CeedScalar is double; confirm.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}

//------------------------------------------------------------------------------
// Grad kernel by dim
//------------------------------------------------------------------------------
// Same dispatch pattern as interp; c_B is the 1D interpolation matrix and
// c_G the 1D gradient matrix.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}

//------------------------------------------------------------------------------
// Weight kernels by dim
//------------------------------------------------------------------------------
// Fills v with tensor-product quadrature weights; no shared memory used.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
weight3d(nelem, qweight1d, v); 753 } 754 } 755 756 ); 757 // *INDENT-ON* 758 759 //------------------------------------------------------------------------------ 760 // Device initalization 761 //------------------------------------------------------------------------------ 762 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d, 763 CeedScalar **c_B); 764 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d, 765 CeedInt Q1d, CeedScalar **c_B_ptr, 766 CeedScalar **c_G_ptr); 767 768 //------------------------------------------------------------------------------ 769 // Apply basis 770 //------------------------------------------------------------------------------ 771 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem, 772 CeedTransposeMode tmode, 773 CeedEvalMode emode, CeedVector u, 774 CeedVector v) { 775 int ierr; 776 Ceed ceed; 777 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 778 Ceed_Cuda_shared *ceed_Cuda; 779 CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr); 780 CeedBasis_Cuda_shared *data; 781 CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 782 const CeedInt transpose = tmode == CEED_TRANSPOSE; 783 CeedInt dim, ncomp; 784 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 785 ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 786 787 // Read vectors 788 const CeedScalar *d_u; 789 CeedScalar *d_v; 790 if (emode != CEED_EVAL_WEIGHT) { 791 ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr); 792 } 793 ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 794 795 // Clear v for transpose mode 796 if (tmode == CEED_TRANSPOSE) { 797 CeedInt length; 798 ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr); 799 ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr); 800 } 801 802 // Apply basis operation 803 switch (emode) { 804 case CEED_EVAL_INTERP: { 805 CeedInt P1d, Q1d; 
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    // Threads per tensor direction must cover both DoFs and quad points
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
    CeedChkBackend(ierr);
    void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                          &d_u, &d_v
                         };
    if (dim == 1) {
      // 1D: x = thread1d, z = elements per block; grid = ceil(nelem/elemsPerBlock)
      CeedInt elemsPerBlock = 32;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                     ? 1 : 0 );
      // Shared scratch: one slice entry per thread in the block
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
                                        elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    } else if (dim == 2) {
      // Tuned elements-per-block by thread1d (index 0 unused); divided by
      // ncomp because components share the block's z dimension
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      // elemsPerBlock must be at least 1
      CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                     ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      // 3D: one element per block; z carries the components
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                     ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_GRAD: {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    // Gradient needs both interp (c_B) and grad (c_G) 1D matrices on device
    ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
                                  Q1d, &data->c_B, &data->c_G);
    CeedChkBackend(ierr);
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                        &data->c_G, &d_u, &d_v
                       };
    // Launch configuration mirrors CEED_EVAL_INTERP above
    if (dim == 1) {
      CeedInt elemsPerBlock = 32;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                     ? 1 : 0 );
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
                                        elemsPerBlock, sharedMem, gradargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      // elemsPerBlock must be at least 1
      CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                     ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        gradargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                     ?
1 : 0 ); 875 CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 876 ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d, 877 ncomp*elemsPerBlock, sharedMem, 878 gradargs); CeedChkBackend(ierr); 879 } 880 } break; 881 case CEED_EVAL_WEIGHT: { 882 CeedInt Q1d; 883 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 884 void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 885 if (dim == 1) { 886 const CeedInt elemsPerBlock = 32/Q1d; 887 const CeedInt gridsize = nelem/elemsPerBlock + ( ( 888 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 889 ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, 890 elemsPerBlock, 1, weightargs); 891 CeedChkBackend(ierr); 892 } else if (dim == 2) { 893 const CeedInt optElems = 32/(Q1d*Q1d); 894 const CeedInt elemsPerBlock = optElems>0?optElems:1; 895 const CeedInt gridsize = nelem/elemsPerBlock + ( ( 896 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 897 ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, 898 elemsPerBlock, weightargs); 899 CeedChkBackend(ierr); 900 } else if (dim == 3) { 901 const CeedInt gridsize = nelem; 902 ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 903 weightargs); 904 CeedChkBackend(ierr); 905 } 906 } break; 907 // LCOV_EXCL_START 908 // Evaluate the divergence to/from the quadrature points 909 case CEED_EVAL_DIV: 910 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 911 // Evaluate the curl to/from the quadrature points 912 case CEED_EVAL_CURL: 913 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 914 // Take no action, BasisApply should not have been called 915 case CEED_EVAL_NONE: 916 return CeedError(ceed, CEED_ERROR_BACKEND, 917 "CEED_EVAL_NONE does not make sense in this context"); 918 // LCOV_EXCL_STOP 919 } 920 921 // Restore vectors 922 if (emode != CEED_EVAL_WEIGHT) { 923 ierr = 
CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy basis
//------------------------------------------------------------------------------
// Release all backend resources: the JIT-compiled module and the device
// copies of the 1D quadrature weights and interp/grad matrices.
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);

  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);

  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  // d_collograd1d may still be NULL (only set for dim == 3 with Q1d >= P1d);
  // cudaFree(NULL) is a no-op, so this is safe
  ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);

  ierr = CeedFree(&data); CeedChkBackend(ierr);

  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create tensor basis
//------------------------------------------------------------------------------
// Upload the 1D basis data to the device, JIT-compile the shared-memory
// kernels with sizes baked in as macros, and register Apply/Destroy.
// Note: qref1d is not copied to the device; only weights and matrices are.
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // interp1d and grad1d are both P1d x Q1d
  const CeedInt iBytes = qBytes * P1d;
  ierr =
cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // Only built for 3D bases with Q1d >= P1d; otherwise left NULL
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    // Collocated gradient is Q1d x Q1d
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
    ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Cu(ceed, ierr);
    ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                      cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
  }

  // Compile basis kernels
  // Sizes are passed as 8 macro definitions so the JIT compiler can fully
  // unroll loops and size static arrays; T1D = max(P1D, Q1D) is the thread
  // count per tensor direction used by the shared-memory slices
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "T1D", CeedIntMax(Q1d, P1d),
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                                          Q1d : P1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
                        ); CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChkBackend(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------