1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3 // All Rights reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative. 
#include <ceed/ceed.h>
#include <ceed/backend.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stddef.h>
#include "ceed-cuda-shared.h"
#include "../cuda/ceed-cuda.h"

//------------------------------------------------------------------------------
// Shared mem kernels
//------------------------------------------------------------------------------
// The device code below is captured as a source string (via the QUOTE macro)
// and compiled at runtime. P1D (nodes per dim), Q1D (quad points per dim),
// T1D (thread-block extent per dim), BASIS_NCOMP, and BASIS_DIM are macros
// supplied at that runtime compilation step — presumably by the backend's
// kernel-build path; they are not defined in this file (verify against the
// backend's JIT setup).
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

//------------------------------------------------------------------------------
// Sum input into output
//------------------------------------------------------------------------------
// Accumulate the first P1D entries of r_U into r_V element-wise.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_V[i] += r_U[i];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load the P1D nodal values of component 'comp' of element 'elem' into this
// tidz's row of the shared 'slice' buffer, zero-padding entries P1D..Q1D-1.
// NOTE(review): tidx/tidy are unused — every thread sharing a tidz writes the
// same slice entries (redundant but identical stores).
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store one nodal value per thread; threads with tidx >= P1D are masked out.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D)
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load the Q1D quadrature values of (elem, comp, dim) into this tidz's row of
// the shared 'slice' buffer, zero-padding entries Q1D..P1D-1.
// NOTE(review): as in readDofs1d, tidx/tidy are unused — threads sharing a
// tidz perform redundant identical stores.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                              dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store one quadrature value per thread; threads with tidx >= Q1D are masked.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D)
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 1D tensor contraction
//------------------------------------------------------------------------------
// V(tidx) = sum_i B(i, tidx) * slice(i) over P1D nodes; the 'U' argument is
// unused (the input comes from 'slice'). No barrier here: each thread reads
// only its own tidz row, which this thread block populated earlier.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D transpose tensor contraction
//------------------------------------------------------------------------------
// Transpose of ContractX1d: V(tidx) = sum_i B(tidx, i) * slice(i) over Q1D
// quadrature points; 'U' is unused.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply the 1D interpolation matrix c_B (or its transpose) to every component
// of every element. Elements are distributed over (blockIdx.x, threadIdx.z)
// in a grid-stride fashion; components are processed sequentially per thread.
// 'r_t' is passed for the unused U parameter of the contraction helpers.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
//------------------------------------------------------------------------------
// Apply the 1D gradient matrix c_G (or its transpose); in 1D there is only
// one derivative direction, so dim is always 0. c_B is accepted for a
// signature parallel to the 2D/3D variants but is not used here.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D Quadrature weights
//------------------------------------------------------------------------------
// Broadcast the 1D quadrature weight of each point to every element.
// Elements are strided over (blockIdx.x, threadIdx.y).
// NOTE(review): qweight1d[tid] is read unguarded — assumes the launch uses
// blockDim.x == Q1D; verify against the host-side launch configuration.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load one nodal value per thread into register U; out-of-range threads get 0
// so padded lanes contribute nothing to later contractions.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
      d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store one nodal value per thread; padded lanes are masked out.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D)
    d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load one quadrature value per thread for (elem, comp, dim); padded lanes
// receive 0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
      d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store one quadrature value per thread for (elem, comp, dim); padded lanes
// are masked out.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 2D tensor contraction x
//------------------------------------------------------------------------------
// Stage each thread's U in shared memory, then contract along x:
// V(tidx,tidy) = sum_i B(i,tidx) * slice(i,tidy). The barriers must be hit by
// all threads, so the guard wraps only the accumulation loop.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D tensor contraction y
//------------------------------------------------------------------------------
// Same staging pattern as ContractX2d, contracting along y:
// V(tidx,tidy) = sum_i B(i,tidy) * slice(tidx,i).
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
//------------------------------------------------------------------------------
// Transpose contraction along y: V(tidx,tidy) = sum_i B(tidy,i) * slice(tidx,i)
// over Q1D quadrature points, producing P1D nodal values.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
//------------------------------------------------------------------------------
// Transpose contraction along x: V(tidx,tidy) = sum_i B(tidx,i) * slice(i,tidy)
// over Q1D quadrature points.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply c_B along x then y (or the transposes in reverse order). The z thread
// index packs (element, component): comp = tidz % BASIS_NCOMP, and
// blockDim.z/BASIS_NCOMP elements are handled per block in grid-stride style.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // NOTE(review): this shadows the identical 'comp' above — redundant.
    const int comp = tidz%BASIS_NCOMP;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
//------------------------------------------------------------------------------
// Forward: d/dx uses (c_G x, c_B y), d/dy uses (c_B x, c_G y), written to
// dim 0 and 1 respectively. Transpose: both directions are pulled back and
// summed into a single nodal result (r_V += r_U before the final write).
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D quadrature weights
//------------------------------------------------------------------------------
// Tensor-product weight qweight1d[i]*qweight1d[j] per point, replicated over
// elements strided on (blockIdx.x, threadIdx.z).
// NOTE(review): unguarded qweight1d reads assume blockDim.x == blockDim.y ==
// Q1D; verify against the host-side launch configuration.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load a z-column of P1D nodal values into registers r_U (one column per
// (tidx,tidy) thread), zero-padding both masked threads and entries
// P1D..Q1D-1.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
             d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                 comp*P1D*P1D*P1D*nelem] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store a z-column of P1D nodal values from registers; masked threads write
// nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
          comp*P1D*P1D*P1D*nelem] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load a z-column of Q1D quadrature values for (elem, comp, dim) into
// registers, zero-padding masked threads and entries Q1D..P1D-1.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
             d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                 comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store a z-column of Q1D quadrature values for (elem, comp, dim); masked
// threads write nothing.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx < Q1D && tidy < Q1D) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract x
//------------------------------------------------------------------------------
// For each z-plane k: stage U[k] in shared memory, barrier, then contract
// along x: V[k](tidx,tidy) = sum_i B(i,tidx) * slice(i,tidy). Barriers are
// outside the guard so every thread reaches them.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract y
//------------------------------------------------------------------------------
// Per z-plane contraction along y: V[k](tidx,tidy) = sum_i B(i,tidy) *
// slice(tidx,i).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract z
//------------------------------------------------------------------------------
// Contract along z entirely in registers (the z-column lives in U[], so no
// shared-memory staging or barrier is needed); zero-pad entries Q1D..P1D-1.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Transpose of ContractZ3d, register-only: V[k] = sum_i B(k,i) * U[i] over
// Q1D quadrature points; zero-pad entries P1D..Q1D-1.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// Per z-plane transpose contraction along y: V[k](tidx,tidy) = sum_i
// B(tidy,i) * slice(tidx,i) over Q1D points.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// Per z-plane transpose contraction along x: V[k](tidx,tidy) = sum_i
// B(tidx,i) * slice(i,tidy) over Q1D points.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply c_B along x, y, z in turn (or the transposes in reverse order).
// As in 2D, tidz packs (element, component); each thread holds a z-column of
// up to T1D values in registers, cleared before each element.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// Forward: each derivative direction applies c_G on its own axis and c_B on
// the other two, written to dim 0/1/2. Transpose: the three directions are
// pulled back and accumulated into one nodal result via add().
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
//------------------------------------------------------------------------------
// Tensor-product weight qweight1d[i]*qweight1d[j]*qweight1d[k] per point,
// one element per block in grid-stride style.
// NOTE(review): unguarded qweight1d reads assume blockDim == (Q1D,Q1D,Q1D);
// verify against the host-side launch configuration.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// Basis kernels
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp kernel by dim
//------------------------------------------------------------------------------
// Dispatches to the dim-specific interpolation routine. BASIS_DIM is a
// compile-time define, so the branch is resolved at JIT compile time.
// Dynamic shared memory (slice) is the contraction scratchpad; its size is
// set by the host launch.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}

//------------------------------------------------------------------------------
// Grad kernel by dim
//------------------------------------------------------------------------------
// Dispatches to the dim-specific gradient routine; c_B is the interpolation
// matrix, c_G the derivative matrix.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}

//------------------------------------------------------------------------------
// Weight kernels by dim
//------------------------------------------------------------------------------
// Fills v with the tensor-product quadrature weights; needs no shared memory.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}

);
// *INDENT-ON*

//------------------------------------------------------------------------------
// Device initialization
//------------------------------------------------------------------------------
// Copy B (and G) matrices to constant memory; implemented in the cuda backend.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);

//------------------------------------------------------------------------------
// Apply basis
//------------------------------------------------------------------------------
// Applies the basis action selected by emode (interp, grad, or quadrature
// weights) to vector u, writing into v. For tmode == CEED_TRANSPOSE the
// output is zeroed first and the transposed action is applied. Launch
// geometry per dim mirrors the thread layouts assumed by the JIT kernels
// (thread1d = max(P1d, Q1d) threads per tensor direction).
int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
                                     CeedTransposeMode tmode,
                                     CeedEvalMode emode, CeedVector u,
                                     CeedVector v) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  Ceed_Cuda_shared *ceed_Cuda;
  // Fixed: return status was previously dropped, so CeedChkBackend tested a
  // stale ierr value.
  ierr = CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
  CeedBasis_Cuda_shared *data;
  // Fixed: same stale-ierr pattern as above.
  ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
  const CeedInt transpose = tmode == CEED_TRANSPOSE;
  CeedInt dim, ncomp;
  ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);

  // Read vectors
  const CeedScalar *d_u;
  CeedScalar *d_v;
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
  }
  ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);

  // Clear v for transpose mode; the transpose kernels accumulate into d_V
  if (tmode == CEED_TRANSPOSE) {
    CeedInt length;
    ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
    // Fixed: cudaMemset returns a cudaError_t, so check it with CeedChk_Cu
    // (as every other cuda* call in this file does), not CeedChkBackend.
    ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar));
    CeedChk_Cu(ceed, ierr);
  }

  // Apply basis operation
  switch (emode) {
  case CEED_EVAL_INTERP: {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
    CeedChkBackend(ierr);
    void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                          &d_u, &d_v
                         };
    if (dim == 1) {
      CeedInt elemsPerBlock = 32;
      // Ceiling division: one extra block for the tail of elements
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
                                        elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    } else if (dim == 2) {
      // Check if required threads is small enough to do multiple elems
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      // elemsPerBlock must be at least 1
      CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_GRAD: {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
                                  Q1d, &data->c_B, &data->c_G);
    CeedChkBackend(ierr);
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                        &data->c_G, &d_u, &d_v
                       };
    if (dim == 1) {
      CeedInt elemsPerBlock = 32;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
                                        elemsPerBlock, sharedMem, gradargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      // elemsPerBlock must be at least 1
      CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        gradargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        gradargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_WEIGHT: {
    CeedInt Q1d;
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
    if (dim == 1) {
      // Fixed: 32/Q1d is 0 when Q1d > 32, which made the grid computation
      // divide by zero; clamp to at least one element per block (same guard
      // the 2D branch uses).
      const CeedInt elemsPerBlock = CeedIntMax(32/Q1d, 1);
      const CeedInt gridsize = nelem/elemsPerBlock + ( (
                                 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
                                  elemsPerBlock, 1, weightargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems = 32/(Q1d*Q1d);
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (
                                 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
                                  elemsPerBlock, weightargs);
      CeedChkBackend(ierr);
    } else if (dim == 3) {
      // One element per block; weight3d expects a Q1d^3 thread block
      const CeedInt gridsize = nelem;
      ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
                                  weightargs);
      CeedChkBackend(ierr);
    }
  } break;
  // LCOV_EXCL_START
  // Evaluate the divergence to/from the quadrature points
  case CEED_EVAL_DIV:
    return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
  // Evaluate the curl to/from the quadrature points
  case CEED_EVAL_CURL:
    return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
  // Take no action, BasisApply should not have been called
  case CEED_EVAL_NONE:
    return CeedError(ceed, CEED_ERROR_BACKEND,
                     "CEED_EVAL_NONE does not make sense in this context");
  // LCOV_EXCL_STOP
  }

  // Restore vectors
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy basis
//------------------------------------------------------------------------------
// Unloads the JIT module and releases all device allocations made in
// CeedBasisCreateTensorH1_Cuda_shared. cudaFree(NULL) is a no-op, so
// d_collograd1d being unset (dim != 3) is safe.
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);

  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);

  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);

  ierr = CeedFree(&data); CeedChkBackend(ierr);

  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create tensor basis
//------------------------------------------------------------------------------
// Uploads the 1D basis matrices to the device, JIT-compiles the shared-memory
// kernels with the basis sizes baked in as defines, and registers Apply and
// Destroy on the basis object. qref1d is unused by this backend.
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // interp1d and grad1d are both Q1d x P1d
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compute collocated gradient and copy to GPU (3D only, and only when the
  // quadrature space is at least as large as the nodal space)
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
    ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Cu(ceed, ierr);
    ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                      cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
  }

  // Compile basis kernels with sizes as compile-time defines; T1D is the
  // per-direction thread count used for register and slice sizing
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "T1D", CeedIntMax(Q1d, P1d),
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                             Q1d : P1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
                        ); CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChkBackend(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------