1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3 // All Rights reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative. 
#include <ceed.h>
#include <ceed-backend.h>
#include <hip/hip_runtime.h>
#include <stddef.h>
#include "ceed-hip-shared.h"
#include "../hip/ceed-hip.h"
#include "../hip/ceed-hip-compile.h"

//------------------------------------------------------------------------------
// Shared mem kernels
//
// The string below is compiled at run time (see ceed-hip-compile.h).  The
// macros it references -- P1D (1D node count), Q1D (1D quadrature point
// count), T1D (padded leading extent used for shared-memory slices and
// per-thread arrays; presumably max(P1D, Q1D) -- TODO confirm), BASIS_NCOMP,
// BASIS_DIM, and the *_BLKSIZE launch bounds -- are supplied by the runtime
// compiler, not defined in this file.  Comments inside QUOTE(...) are removed
// by the host preprocessor before stringification, so they exist only for
// readers of this source.
//------------------------------------------------------------------------------
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

//------------------------------------------------------------------------------
// Sum input into output: r_V[i] += r_U[i] for the P1D leading entries
//------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_V[i] += r_U[i];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
// Loads the P1D nodal values of (elem, comp) into this tidz's row of the
// shared 'slice' and zero-pads entries P1D..Q1D-1.
// NOTE(review): every thread in the tidz row stores the same values (a benign
// same-value write race) -- confirm this redundancy is intentional.
//------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
// Each thread stores one nodal value of (elem, comp); threads with
// tidx >= P1D do nothing.
//------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D)
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
// Loads the Q1D quadrature values of (elem, comp, dim) into this tidz's row
// of the shared 'slice' and zero-pads entries Q1D..P1D-1.
//------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                              dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
// Each thread stores one quadrature value of (elem, comp, dim); threads with
// tidx >= Q1D do nothing.
//------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D)
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 1D tensor contraction
// V(tidx) = sum_i B(i, tidx) * slice(i); the U argument is unused -- the
// input comes from 'slice'.  B is stored column-major with leading dim P1D.
// NOTE(review): V is computed for every tidx up to the block extent; when
// tidx >= Q1D the B index may run past the P1D*Q1D table -- confirm callers'
// write guards make this harmless.
//------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D transpose tensor contraction
// V(tidx) = sum_i B(tidx, i) * slice(i), contracting over the Q1D quadrature
// index; the U argument is unused -- the input comes from 'slice'.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
// One element per z-layer of the block; grid-stride over elements in steps of
// gridDim.x*blockDim.z.  transpose==0: nodes -> quadrature via c_B;
// transpose!=0: quadrature -> nodes via c_B transposed.  r_t is passed only
// to fill the (unused) U parameter of the contraction helpers.
//------------------------------------------------------------------------------
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
// Same element/thread layout as interp1d, but contracts with the gradient
// matrix c_G (c_B is received for interface symmetry and unused in 1D).
//------------------------------------------------------------------------------
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
// (continuation of grad1d: per-component contraction with the gradient matrix)
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D Quadrature weights
// Thread tidx holds weight qweight1d[tidx]; elements are strided over
// blockDim.y layers, gridDim.x blocks.
//------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
// One nodal value per (tidx, tidy) thread of (elem, comp); out-of-range
// threads receive 0.0 so the padded contractions stay well-defined.
//------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
      d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
// One nodal value per in-range (tidx, tidy) thread of (elem, comp).
//------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D)
    d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
// One quadrature value per (tidx, tidy) thread of (elem, comp, dim);
// out-of-range threads receive 0.0.
//------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
      d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
// One quadrature value per in-range (tidx, tidy) thread of (elem, comp, dim).
//------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 2D tensor contraction x
// Stages U in shared memory, barriers, then each (tidx < Q1D) thread computes
// V = sum_i B(i, tidx) * slice(i, tidy).  Both __syncthreads() are outside
// the guard so the whole block reaches them; the trailing barrier protects
// 'slice' before the next staging.
//------------------------------------------------------------------------------
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D tensor contraction y
// Same staging protocol as ContractX2d, contracting along the y index.
//------------------------------------------------------------------------------
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
// Contracts over the Q1D quadrature index using B transposed.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
// Contracts over the Q1D quadrature index using B transposed.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
// (continuation of ContractTransposeX2d: the guarded contraction itself)
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
// Thread layout: (tidx, tidy) spans the tensor point grid; tidz packs
// BASIS_NCOMP components per element, elemsPerBlock elements per block.
// transpose==0: nodes -> quadrature via c_B; transpose!=0: the reverse.
// NOTE(review): near the nelem boundary different tidz groups can take
// different trip counts through this loop while the Contract* helpers call
// __syncthreads() -- confirm the launch configuration keeps trip counts
// uniform per block.
//------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // (fix: removed a duplicate 'const int comp = tidz%BASIS_NCOMP;' that
    // shadowed the identical declaration above the loop)
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
// Forward: d/dx = (G x B) U written to dim 0, d/dy = (B x G) U to dim 1.
// Transpose: accumulates B^T/G^T contractions of both dims into one nodal
// field (r_V += r_U before the final write).
//------------------------------------------------------------------------------
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D quadrature weights
// Thread (i, j) holds the tensor-product weight; elements strided over
// blockDim.z layers and gridDim.x blocks.
//------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z)
// (continuation of weight2d: store the tensor-product weight for this elem)
  {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
// Thread (tidx, tidy) loads its column of P1D nodal values (z index i) into
// registers r_U and zero-pads entries P1D..Q1D-1; out-of-range threads load
// zeros.
//------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
             d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                 comp*P1D*P1D*P1D*nelem] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
// Thread (tidx, tidy) stores its register column of P1D nodal values.
//------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
          comp*P1D*P1D*P1D*nelem] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// Read quadrature point data
// Thread (tidx, tidy) loads its column of Q1D quadrature values of
// (elem, comp, dim) and zero-pads entries Q1D..P1D-1.
//------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
             d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                 comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
// Thread (tidx, tidy) stores its register column of Q1D quadrature values.
//------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx < Q1D && tidy < Q1D) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract x
// Plane-by-plane (z index k): stage U[k] in shared memory, barrier, contract
// along x with B, barrier before reusing 'slice'.  Both __syncthreads() are
// outside the guard so all threads of the block reach them.
//------------------------------------------------------------------------------
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract y
// Same plane-by-plane staging as ContractX3d, contracting along y.
//------------------------------------------------------------------------------
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract z
// The z column lives in registers, so no shared-memory staging (and no
// barriers) are needed; pads entries Q1D..P1D-1 with zeros.
//------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
// Register-only like ContractZ3d, contracting over the Q1D quadrature index
// with B transposed; pads entries P1D..Q1D-1 with zeros.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
// Plane-by-plane staging; contracts over the Q1D index with B transposed.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
// Plane-by-plane staging; contracts over the Q1D index with B transposed.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
// tidz packs BASIS_NCOMP components per element, elemsPerBlock elements per
// block; (tidx, tidy) spans the xy plane and z columns sit in registers.
// NOTE(review): near the nelem boundary different tidz groups can take
// different trip counts while the Contract*3d helpers call __syncthreads() --
// confirm the launch configuration keeps trip counts uniform per block.
//------------------------------------------------------------------------------
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear the full padded register columns before each element.
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
// Forward: applies G along one direction and B along the other two for each
// of dims 0..2.  Transpose: accumulates the three transposed contractions
// into r_V via add() before the single nodal write.
//------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
// Thread (i, j, k) holds the tensor-product weight; one element per block
// iteration, strided over gridDim.x.
//------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}


//------------------------------------------------------------------------------
// Basis kernels
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp kernel by dim
//
// Dispatches to the 1D/2D/3D shared-memory interpolation kernel selected at
// JIT-compile time by the BASIS_DIM macro.  Dynamic shared memory (3rd launch
// parameter) provides the per-block "slice" scratch buffer.
//------------------------------------------------------------------------------
extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
    const CeedInt nelem, const int transpose,
    const CeedScalar *c_B,
    const CeedScalar *__restrict__ d_U,
    CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}

//------------------------------------------------------------------------------
// Grad kernel by dim
//
// Same dispatch pattern as interp, with the additional 1D gradient matrix c_G.
//------------------------------------------------------------------------------
extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(
    const CeedInt nelem, const int transpose,
    const CeedScalar *c_B, const CeedScalar *c_G,
    const CeedScalar *__restrict__ d_U,
    CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}

//------------------------------------------------------------------------------
// Weight kernels by dim
//
// Writes the tensor-product quadrature weights; no shared memory required.
//------------------------------------------------------------------------------
extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(
    const CeedInt nelem,
    const CeedScalar *__restrict__ qweight1d,
    CeedScalar *__restrict__ v) {
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}

);
// *INDENT-ON*

//------------------------------------------------------------------------------
// Compute a block size based on required minimum threads
//
// Returns a power-of-two multiple of 64 (one wavefront group) strictly larger
// than 'required', capped at the 1024 threads-per-block hardware maximum.
//------------------------------------------------------------------------------
static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
  CeedInt maxSize = 1024;    // Max total threads per block
  CeedInt currentSize = 64;  // Start with one group

  while (currentSize < maxSize) {
    if (currentSize > required)
      break;
    else
      currentSize = currentSize * 2;
  }
  return currentSize;
}

//------------------------------------------------------------------------------
// Number of thread blocks needed to cover nelem elements when each block
// handles elemsPerBlock of them (ceiling division; elemsPerBlock > 0)
//------------------------------------------------------------------------------
static CeedInt ComputeGridSize(const CeedInt nelem,
                               const CeedInt elemsPerBlock) {
  return nelem / elemsPerBlock +
         ((nelem / elemsPerBlock * elemsPerBlock < nelem) ? 1 : 0);
}

//------------------------------------------------------------------------------
// Compute required thread block sizes for basis kernels given P, Q, dim, and
// ncomp
//
// Fills blksizes[0..2] with the launch-bound block sizes for the interp, grad,
// and weight kernels, respectively.  Returns 0 (CEED success).
//------------------------------------------------------------------------------
static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
                                        const CeedInt Q1d,
                                        const CeedInt ncomp,
                                        CeedInt *blksizes) {

  // Note that this will use the same block sizes for all dimensions when
  // compiling, but as each basis object is defined for a particular dimension,
  // we will never call any kernels except the ones for the dimension for which
  // we have computed the block sizes.
  const CeedInt thread1d = CeedIntMax(P1d, Q1d);
  switch (dim) {
  case 1: {
    // Interp kernels:
    blksizes[0] = 256;

    // Grad kernels:
    blksizes[1] = 256;

    // Weight kernels:
    blksizes[2] = 256;

  } break;
  case 2: {
    // Interp kernels: one thread per (node, node, component) triple
    CeedInt required = thread1d * thread1d * ncomp;
    blksizes[0] = ComputeBlockSizeFromRequirement(required);

    // Grad kernels: currently use same required minimum threads
    blksizes[1] = ComputeBlockSizeFromRequirement(required);

    // Weight kernels:
    required = CeedIntMax(64, Q1d * Q1d);
    blksizes[2] = ComputeBlockSizeFromRequirement(required);

  } break;
  case 3: {
    // Interp kernels: one thread per (node, node, component) triple
    CeedInt required = thread1d * thread1d * ncomp;
    blksizes[0] = ComputeBlockSizeFromRequirement(required);

    // Grad kernels: currently use same required minimum threads
    blksizes[1] = ComputeBlockSizeFromRequirement(required);

    // Weight kernels:
    required = Q1d * Q1d * Q1d;
    blksizes[2] = ComputeBlockSizeFromRequirement(required);
  }
  }

  return 0;
}

//------------------------------------------------------------------------------
// Apply basis
//
// Launches the JIT-compiled interp/grad/weight kernel for this basis on nelem
// elements.  For CEED_TRANSPOSE the output vector is zeroed first because the
// kernels accumulate into it.  CEED_EVAL_DIV/CURL/NONE are unsupported here.
//------------------------------------------------------------------------------
int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
                                    CeedTransposeMode tmode,
                                    CeedEvalMode emode, CeedVector u,
                                    CeedVector v) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  Ceed_Hip_shared *ceed_Hip;
  // Fix: assign the return code so CeedChk checks this call, not a stale ierr
  ierr = CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
  CeedBasis_Hip_shared *data;
  // Fix: same missing "ierr =" pattern as above
  ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
  const CeedInt transpose = tmode == CEED_TRANSPOSE;
  CeedInt dim, ncomp;
  ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);

  // Read vectors
  const CeedScalar *d_u;
  CeedScalar *d_v;
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
  }
  ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);

  // Clear v for transpose mode; the kernels accumulate into the output
  if (tmode == CEED_TRANSPOSE) {
    CeedInt length;
    ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
    // Fix: hipMemset returns a HIP error code; check it in the HIP domain
    ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar));
    CeedChk_Hip(ceed, ierr);
  }

  // Apply basis operation
  switch (emode) {
  case CEED_EVAL_INTERP: {
    CeedInt P1d, Q1d;
    CeedInt blksize = data->blksizes[0];
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
    CeedChk(ierr);
    void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                          &d_u, &d_v
                         };
    if (dim == 1) {
      // Pack as many elements per block as fit in 256 threads
      CeedInt elemsPerBlock = 64*thread1d > 256 ? 256/thread1d : 64;
      elemsPerBlock = elemsPerBlock > 0 ? elemsPerBlock : 1;
      CeedInt grid = ComputeGridSize(nelem, elemsPerBlock);
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
                                       elemsPerBlock, sharedMem,
                                       interpargs); CeedChk(ierr);
    } else if (dim == 2) {
      // Check if required threads is small enough to do multiple elems
      const CeedInt elemsPerBlock =
        CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
      CeedInt grid = ComputeGridSize(nelem, elemsPerBlock);
      CeedInt sharedMem =
        ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d,
                                       thread1d, ncomp*elemsPerBlock,
                                       sharedMem, interpargs); CeedChk(ierr);
    } else if (dim == 3) {
      CeedInt elemsPerBlock = 1;
      CeedInt grid = ComputeGridSize(nelem, elemsPerBlock);
      CeedInt sharedMem =
        ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d,
                                       thread1d, ncomp*elemsPerBlock,
                                       sharedMem, interpargs); CeedChk(ierr);
    }
  } break;
  case CEED_EVAL_GRAD: {
    CeedInt P1d, Q1d;
    CeedInt blksize = data->blksizes[1];
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
                                 Q1d, &data->c_B, &data->c_G);
    CeedChk(ierr);
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                        &data->c_G, &d_u, &d_v
                       };
    if (dim == 1) {
      // Pack as many elements per block as fit in 256 threads
      CeedInt elemsPerBlock = 64*thread1d > 256 ? 256/thread1d : 64;
      elemsPerBlock = elemsPerBlock > 0 ? elemsPerBlock : 1;
      CeedInt grid = ComputeGridSize(nelem, elemsPerBlock);
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
                                       elemsPerBlock, sharedMem, gradargs);
      CeedChk(ierr);
    } else if (dim == 2) {
      // Check if required threads is small enough to do multiple elems
      const CeedInt elemsPerBlock =
        CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
      CeedInt grid = ComputeGridSize(nelem, elemsPerBlock);
      CeedInt sharedMem =
        ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d,
                                       thread1d, ncomp*elemsPerBlock,
                                       sharedMem, gradargs); CeedChk(ierr);
    } else if (dim == 3) {
      CeedInt elemsPerBlock = 1;
      CeedInt grid = ComputeGridSize(nelem, elemsPerBlock);
      CeedInt sharedMem =
        ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d,
                                       thread1d, ncomp*elemsPerBlock,
                                       sharedMem, gradargs); CeedChk(ierr);
    }
  } break;
  case CEED_EVAL_WEIGHT: {
    CeedInt Q1d;
    CeedInt blksize = data->blksizes[2];
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
    if (dim == 1) {
      const CeedInt optElems = blksize/Q1d;
      const CeedInt elemsPerBlock = optElems > 0 ? optElems : 1;
      const CeedInt gridsize = ComputeGridSize(nelem, elemsPerBlock);
      ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
                                 elemsPerBlock, 1, weightargs);
      CeedChk(ierr);
    } else if (dim == 2) {
      const CeedInt optElems = blksize/(Q1d*Q1d);
      const CeedInt elemsPerBlock = optElems > 0 ? optElems : 1;
      const CeedInt gridsize = ComputeGridSize(nelem, elemsPerBlock);
      ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
                                 elemsPerBlock, weightargs);
      CeedChk(ierr);
    } else if (dim == 3) {
      // weight3d grid-strides over elements, so one block per element is fine
      const CeedInt gridsize = nelem;
      ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
                                 weightargs);
      CeedChk(ierr);
    }
  } break;
  // LCOV_EXCL_START
  // Evaluate the divergence to/from the quadrature points
  case CEED_EVAL_DIV:
    return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
  // Evaluate the curl to/from the quadrature points
  case CEED_EVAL_CURL:
    return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
  // Take no action, BasisApply should not have been called
  case CEED_EVAL_NONE:
    return CeedError(ceed, 1,
                     "CEED_EVAL_NONE does not make sense in this context");
  // LCOV_EXCL_STOP
  }

  // Restore vectors
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
  return 0;
}

//------------------------------------------------------------------------------
// Destroy basis
//
// Unloads the JIT module and frees all device buffers owned by the backend
// data, then frees the data struct itself.
//------------------------------------------------------------------------------
static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);

  CeedBasis_Hip_shared *data;
  ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);

  CeedChk_Hip(ceed, hipModuleUnload(data->module));

  ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
  ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
  ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
  // hipFree(NULL) is a no-op, so this is safe when no collocated gradient
  // was computed
  ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);

  ierr = CeedFree(&data); CeedChk(ierr);

  return 0;
}

//------------------------------------------------------------------------------
// Create tensor basis
//
// Copies the 1D basis operators to the device, optionally builds the
// collocated gradient (3D with Q1d >= P1d), picks per-kernel block sizes,
// JIT-compiles the shared-memory kernels, and registers Apply/Destroy.
//------------------------------------------------------------------------------
int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                       const CeedScalar *interp1d,
                                       const CeedScalar *grad1d,
                                       const CeedScalar *qref1d,
                                       const CeedScalar *qweight1d,
                                       CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  CeedBasis_Hip_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  const CeedInt iBytes = qBytes * P1d;  // Q1d x P1d operator matrices
  ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
    ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Hip(ceed, ierr);
    ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                     hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChk(ierr);
  }

  // Set number of threads per block for basis kernels
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
  CeedChk(ierr);

  // Compile basis kernels
  ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
                        "Q1D", Q1d,
                        "P1D", P1d,
                        "T1D", CeedIntMax(Q1d, P1d),
                        "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                                                           Q1d : P1d, dim),
                        "BASIS_DIM", dim,
                        "BASIS_NCOMP", ncomp,
                        "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                        "BASIS_NQPT", CeedIntPow(Q1d, dim),
                        "INTERP_BLKSIZE", data->blksizes[0],
                        "GRAD_BLKSIZE", data->blksizes[1],
                        "WEIGHT_BLKSIZE", data->blksizes[2]
                       ); CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChk(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Hip_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Hip_shared); CeedChk(ierr);
  return 0;
}
//------------------------------------------------------------------------------