1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3 // All Rights reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative. 
#include <ceed.h>
#include <ceed-backend.h>
#include <hip/hip_runtime.h>
#include <stddef.h>
#include "ceed-hip-shared.h"
#include "../hip/ceed-hip.h"
#include "../hip/ceed-hip-compile.h"

//------------------------------------------------------------------------------
// Shared mem kernels
//
// The device code below is kept as a string (via the QUOTE macro) and compiled
// at runtime.  P1D, Q1D, T1D, BASIS_NCOMP and the *_BLKSIZE launch bounds are
// not defined in this file; presumably they are supplied as defines by the
// runtime compilation path (see ceed-hip-compile.h) -- TODO confirm.
// Comments inside QUOTE(...) are removed by the host preprocessor before
// stringification, so they never reach the runtime compiler.
//------------------------------------------------------------------------------
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

//------------------------------------------------------------------------------
// Sum input into output
// Accumulates the P1D per-thread values of r_U into r_V.
//------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_V[i] += r_U[i];
}

//------------------------------------------------------------------------------
// Load matrices for basis actions
// All threads of the block cooperate in a block-stride loop to copy the
// P1D x Q1D matrix d_B into B.  Callers must __syncthreads() before reading B.
//------------------------------------------------------------------------------
inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
  // Flat thread index within the (up to 3D) thread block
  CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
  for (CeedInt i = tid; i < P1D*Q1D; i += blockDim.x*blockDim.y*blockDim.z)
    B[i] = d_B[i];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
// Copies the P1D dofs of (elem, comp) into this tidz's row of slice and
// zero-pads entries P1D..Q1D-1 so contractions can run over the wider extent.
//------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  // Zero padding when Q1D > P1D
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
// One thread per dof: thread tidx writes its value back to (elem, comp).
//------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D)
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
// Copies the Q1D quad values of (elem, comp, dim) into this tidz's row of
// slice, zero-padding entries Q1D..P1D-1 when P1D > Q1D.
//------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                              dim*BASIS_NCOMP*nelem*Q1D];
  // Zero padding when P1D > Q1D
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
// One thread per quad point writes its value to (elem, comp, dim).
//------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D)
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 1D tensor contraction
// V(tidx) = sum_i B(tidx, i) * slice(i); input comes from slice.
// NOTE(review): the U argument is unused here -- input is read from slice.
//------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D transpose tensor contraction
// V(tidx) = sum_i B(i, tidx) * slice(i); input comes from slice.
// NOTE(review): the U argument is unused here -- input is read from slice.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
// Forward (!transpose): dofs -> quads via s_B; transpose: quads -> dofs.
// Elements are distributed across blockDim.z and the grid (grid-stride in
// units of blockDim.z); components are looped over serially.
// r_t is only passed as the unused U argument of the contractions.
//------------------------------------------------------------------------------
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
// Forward: dofs -> d/dx at quads via s_G; transpose: quads -> dofs via s_G^T.
// Same element/component distribution as interp1d.
//------------------------------------------------------------------------------
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D Quadrature weights
// Thread tidx holds weight qweight1d[tidx]; elements are distributed over
// blockDim.y and the grid.
//------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
// Thread (tidx, tidy) loads its dof of (elem, comp); out-of-range threads
// (tidx or tidy >= P1D) get 0.0 so contractions stay well defined.
//------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
       d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
// Thread (tidx, tidy) writes its dof of (elem, comp), if in range.
//------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D)
    d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
// Thread (tidx, tidy) loads its quad value of (elem, comp, dim); threads
// outside Q1D x Q1D get 0.0.
//------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
       d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
           dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
// Thread (tidx, tidy) writes its quad value of (elem, comp, dim), if in range.
//------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 2D tensor contraction x
// Stage U through slice, then V(tidx, tidy) = sum_i B(tidx, i) * U(i, tidy).
// Both __syncthreads() are outside the guard so all threads reach them.
//------------------------------------------------------------------------------
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D tensor contraction y
// Stage U through slice, then V(tidx, tidy) = sum_i B(tidy, i) * U(tidx, i).
//------------------------------------------------------------------------------
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
// Stage U through slice, then V(tidx, tidy) = sum_i B(i, tidy)^T * U(tidx, i).
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
// Stage U through slice, then V(tidx, tidy) = sum_i B(i, tidx)^T * U(i, tidy).
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
// blockDim.z carries elemsPerBlock * BASIS_NCOMP lanes: tidz/BASIS_NCOMP picks
// the element slot, tidz%BASIS_NCOMP picks the component.  Assumes blockDim.z
// is a multiple of BASIS_NCOMP -- presumably guaranteed by the launch setup;
// TODO confirm.
//------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // NOTE(review): this redeclaration shadows the outer comp with the same
    // value; harmless but redundant.
    const int comp = tidz%BASIS_NCOMP;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
// Forward: for each dim, contract with s_G along that direction and s_B along
// the other, writing one quad field per dim.  Transpose: accumulate both dim
// contributions into a single dof field.
//------------------------------------------------------------------------------
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      // Sum the contributions of both derivative directions
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D quadrature weights
// Thread (i, j) holds the tensor-product weight; elements are distributed
// over blockDim.z and the grid.
//------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
// Thread (tidx, tidy) loads a z-column of P1D dofs into r_U, zero-padding
// out-of-range threads and entries P1D..Q1D-1.
//------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
              d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                  comp*P1D*P1D*P1D*nelem] : 0.0;
  // Zero padding when Q1D > P1D
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
// Thread (tidx, tidy) writes its z-column of P1D dofs, if in range.
//------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
          comp*P1D*P1D*P1D*nelem] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// Read quadrature point data
// Thread (tidx, tidy) loads a z-column of Q1D quad values into r_U,
// zero-padding out-of-range threads and entries Q1D..P1D-1.
//------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
              d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                  comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  // Zero padding when P1D > Q1D
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
// Thread (tidx, tidy) writes its z-column of Q1D quad values, if in range.
//------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx < Q1D && tidy < Q1D) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract x
// For each z-plane k, stage U[k] through slice and contract along x.
// Barriers are outside the guard so all threads reach them.
//------------------------------------------------------------------------------
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract y
// For each z-plane k, stage U[k] through slice and contract along y.
//------------------------------------------------------------------------------
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract z
// Purely register-based: contracts the per-thread z-column U into V.
// NOTE(review): slice and tidz are unused -- the signature matches the other
// contractions for uniform call sites.
//------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  // Zero padding when P1D > Q1D
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
// Register-based transpose contraction of the per-thread z-column.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  // Zero padding when Q1D > P1D
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
// For each z-plane k, stage U[k] through slice and contract with B^T along y.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
// For each z-plane k, stage U[k] through slice and contract with B^T along x.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
// Each thread holds a z-column in registers (r_V, r_t of size T1D); x and y
// contractions exchange data via slice, z is register-only.  Element and
// component are encoded in tidz as in interp2d.
//------------------------------------------------------------------------------
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear the register columns before each element
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
// Forward: three passes, each applying s_G along one direction and s_B along
// the others, writing one quad field per dim.  Transpose: the three dim
// contributions are accumulated into r_V via add() and written once.
//------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear the register columns before each element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
// Thread (i, j, k) holds the tensor-product weight; one element per block
// with a grid-stride loop over elements.
//------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}


//------------------------------------------------------------------------------
// Basis kernels
//------------------------------------------------------------------------------ 715 716 //------------------------------------------------------------------------------ 717 // Interp kernel by dim 718 //------------------------------------------------------------------------------ 719 extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp( 720 const CeedInt nelem, const int transpose, 721 CeedScalar *d_interp1d, 722 const CeedScalar *__restrict__ d_U, 723 CeedScalar *__restrict__ d_V) { 724 725 HIP_DYNAMIC_SHARED( double, slice) 726 // load interp1d into shared memory 727 __shared__ double s_B[P1D*Q1D]; 728 loadMatrix(d_interp1d, s_B); 729 __syncthreads(); 730 731 if (BASIS_DIM == 1) { 732 interp1d(nelem, transpose, s_B, d_U, d_V, slice); 733 } else if (BASIS_DIM == 2) { 734 interp2d(nelem, transpose, s_B, d_U, d_V, slice); 735 } else if (BASIS_DIM == 3) { 736 interp3d(nelem, transpose, s_B, d_U, d_V, slice); 737 } 738 } 739 740 //------------------------------------------------------------------------------ 741 // Grad kernel by dim 742 //------------------------------------------------------------------------------ 743 extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem, 744 const int transpose, 745 CeedScalar *d_interp1d, CeedScalar *d_grad1d, 746 const CeedScalar *__restrict__ d_U, 747 CeedScalar *__restrict__ d_V) { 748 HIP_DYNAMIC_SHARED( double, slice) 749 // load interp1d and grad1d into shared memory 750 __shared__ double s_B[P1D*Q1D]; 751 loadMatrix(d_interp1d, s_B); 752 __shared__ double s_G[P1D*Q1D]; 753 loadMatrix(d_grad1d, s_G); 754 __syncthreads(); 755 756 if (BASIS_DIM == 1) { 757 grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 758 } else if (BASIS_DIM == 2) { 759 grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 760 } else if (BASIS_DIM == 3) { 761 grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 762 } 763 } 764 765 
//------------------------------------------------------------------------------
// Weight kernels by dim
//------------------------------------------------------------------------------
// Entry point for quadrature weights; dispatches on compile-time BASIS_DIM.
extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
    const CeedScalar *__restrict__ qweight1d,
    CeedScalar *__restrict__ v) {
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}

);
// *INDENT-ON*

//------------------------------------------------------------------------------
// Compute a block size based on required minimum threads
//------------------------------------------------------------------------------
// Returns the smallest power-of-two multiple of 64 (one wavefront group) that
// is >= `required`, capped at 1024 (max threads per block).
static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
  CeedInt maxSize = 1024; // Max total threads per block
  CeedInt currentSize = 64; // Start with one group

  while (currentSize < maxSize) {
    // Fixed: was '>' which overshot by 2x whenever `required` landed exactly
    // on a power-of-two multiple of 64 (e.g. required == 256 returned 512)
    if (currentSize >= required)
      break;
    else
      currentSize = currentSize * 2;
  }
  return currentSize;
}

//------------------------------------------------------------------------------
// Compute required thread block sizes for basis kernels given P, Q, dim, and
// ncomp
//------------------------------------------------------------------------------
static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
                                        const CeedInt Q1d,
                                        const CeedInt ncomp, CeedInt *blksizes) {

  // Note that this will use the same block sizes for all dimensions when compiling,
  // but as each basis object is defined for a particular dimension, we will never
  // call any kernels except the ones for the dimension for which we have computed the
  // block sizes.
811 const CeedInt thread1d = CeedIntMax(P1d, Q1d); 812 switch (dim) { 813 case 1: { 814 // Interp kernels: 815 blksizes[0] = 256; 816 817 // Grad kernels: 818 blksizes[1] = 256; 819 820 // Weight kernels: 821 blksizes[2] = 256; 822 823 } break; 824 case 2: { 825 // Interp kernels: 826 CeedInt required = thread1d * thread1d * ncomp; 827 blksizes[0] = ComputeBlockSizeFromRequirement(required); 828 829 // Grad kernels: currently use same required minimum threads 830 blksizes[1] = ComputeBlockSizeFromRequirement(required); 831 832 // Weight kernels: 833 required = CeedIntMax(64, Q1d * Q1d); 834 blksizes[2] = ComputeBlockSizeFromRequirement(required); 835 836 } break; 837 case 3: { 838 // Interp kernels: 839 CeedInt required = thread1d * thread1d * ncomp; 840 blksizes[0] = ComputeBlockSizeFromRequirement(required); 841 842 // Grad kernels: currently use same required minimum threads 843 blksizes[1] = ComputeBlockSizeFromRequirement(required); 844 845 // Weight kernels: 846 required = Q1d * Q1d * Q1d; 847 blksizes[2] = ComputeBlockSizeFromRequirement(required); 848 } 849 } 850 851 return CEED_ERROR_SUCCESS; 852 } 853 854 //------------------------------------------------------------------------------ 855 // Apply basis 856 //------------------------------------------------------------------------------ 857 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 858 CeedTransposeMode tmode, 859 CeedEvalMode emode, CeedVector u, 860 CeedVector v) { 861 int ierr; 862 Ceed ceed; 863 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 864 Ceed_Hip_shared *ceed_Hip; 865 CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr); 866 CeedBasis_Hip_shared *data; 867 CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 868 const CeedInt transpose = tmode == CEED_TRANSPOSE; 869 CeedInt dim, ncomp; 870 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 871 ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 872 873 // Read 
vectors 874 const CeedScalar *d_u; 875 CeedScalar *d_v; 876 if (emode != CEED_EVAL_WEIGHT) { 877 ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr); 878 } 879 ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 880 881 // Clear v for transpose mode 882 if (tmode == CEED_TRANSPOSE) { 883 CeedInt length; 884 ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr); 885 ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr); 886 } 887 888 // Apply basis operation 889 switch (emode) { 890 case CEED_EVAL_INTERP: { 891 CeedInt P1d, Q1d; 892 CeedInt blksize = data->blksizes[0]; 893 ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 894 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 895 CeedInt thread1d = CeedIntMax(Q1d, P1d); 896 void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 897 &d_u, &d_v 898 }; 899 if (dim == 1) { 900 CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 901 elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 902 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 903 ? 1 : 0 ); 904 CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 905 ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 906 elemsPerBlock, sharedMem, 907 interpargs); CeedChkBackend(ierr); 908 } else if (dim == 2) { 909 // Check if required threads is small enough to do multiple elems 910 const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 911 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 912 ? 
 1 : 0 );
      // One thread1d x thread1d shared slice per component per element
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
                                       ncomp*elemsPerBlock, sharedMem,
                                       interpargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      // 3D: one element per block
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
                                       ncomp*elemsPerBlock, sharedMem,
                                       interpargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_GRAD: {
    CeedInt P1d, Q1d;
    CeedInt blksize = data->blksizes[1];
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    // grad kernel takes both interp1d and grad1d matrices
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
                        &data->d_grad1d, &d_u, &d_v
                       };
    if (dim == 1) {
      // Same packing heuristic as the interp dim==1 path
      CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
      elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
                                       elemsPerBlock, sharedMem, gradargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      // Check if required threads is small enough to do multiple elems
      const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) ?
 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
                                       ncomp*elemsPerBlock, sharedMem,
                                       gradargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      // 3D: one element per block, mirroring the interp dim==3 path
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
                                       ncomp*elemsPerBlock, sharedMem,
                                       gradargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_WEIGHT: {
    CeedInt Q1d;
    CeedInt blksize = data->blksizes[2];
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    // Weight needs no input vector; only the 1D weights and the output
    void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
    if (dim == 1) {
      // Fit as many elements per block as the computed block size allows
      const CeedInt optElems = blksize/Q1d;
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (
                                 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
                                 elemsPerBlock, 1, weightargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems = blksize/(Q1d*Q1d);
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (
                                 nelem/elemsPerBlock*elemsPerBlock<nelem)?
 1 : 0 );
      ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
                                 elemsPerBlock, weightargs);
      CeedChkBackend(ierr);
    } else if (dim == 3) {
      // One block per element; weight3d grid-strides over elements, so a
      // launch of exactly nelem blocks covers each element once.
      // NOTE(review): Q1d*Q1d*Q1d threads per block — exceeds typical device
      // limits for large Q1d; confirm upstream guards basis sizes.
      const CeedInt gridsize = nelem;
      ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
                                 weightargs);
      CeedChkBackend(ierr);
    }
  } break;
  // LCOV_EXCL_START
  // Evaluate the divergence to/from the quadrature points
  case CEED_EVAL_DIV:
    return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
  // Evaluate the curl to/from the quadrature points
  case CEED_EVAL_CURL:
    return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
  // Take no action, BasisApply should not have been called
  case CEED_EVAL_NONE:
    return CeedError(ceed, CEED_ERROR_BACKEND,
                     "CEED_EVAL_NONE does not make sense in this context");
  // LCOV_EXCL_STOP
  }

  // Restore vectors
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Destroy basis
//------------------------------------------------------------------------------
// Releases all device allocations and the runtime-compiled kernel module
// owned by the basis backend data, then frees the data struct itself.
static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);

  CeedBasis_Hip_shared *data;
  ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);

  // Unload the hipRTC-compiled module before freeing the device arrays
  CeedChk_Hip(ceed, hipModuleUnload(data->module));

  ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
  ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
  ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
  // d_collograd1d is NULL unless allocated in Create (dim 3 with Q1d >= P1d);
  // presumably hipFree(NULL) succeeds like cudaFree(NULL) — confirm
  ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);

  ierr = CeedFree(&data);
CeedChkBackend(ierr);

  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Create tensor basis
//------------------------------------------------------------------------------
// Uploads the 1D basis matrices/weights to the device, computes kernel block
// sizes, compiles the shared-memory kernels with hipRTC, and registers the
// Apply/Destroy backend functions on the basis object.
// Note: qref1d is accepted for interface uniformity but not uploaded here.
int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                       const CeedScalar *interp1d,
                                       const CeedScalar *grad1d,
                                       const CeedScalar *qref1d,
                                       const CeedScalar *qweight1d,
                                       CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Hip_shared *data;
  ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // interp1d and grad1d are both P1d*Q1d entries
  const CeedInt iBytes = qBytes * P1d;
  ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // Only built for 3D bases with Q1d >= P1d; otherwise left NULL.
  // NOTE(review): no consumer of d_collograd1d is visible in this chunk —
  // confirm which kernels use it.
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    // Q1d x Q1d collocated gradient matrix
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
    ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Hip(ceed, ierr);
    ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                     hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
    // Host copy no longer needed once uploaded
    ierr = CeedFree(&collograd1d);
CeedChkBackend(ierr);
  }

  // Set number of threads per block for basis kernels
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
  ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
  CeedChkBackend(ierr);

  // Compile basis kernels: the 11 name/value pairs below become compile-time
  // macros (Q1D, P1D, launch bounds, ...) in the kernelsShared source string
  ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
                        "Q1D", Q1d,
                        "P1D", P1d,
                        "T1D", CeedIntMax(Q1d, P1d),
                        "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                            Q1d : P1d, dim),
                        "BASIS_DIM", dim,
                        "BASIS_NCOMP", ncomp,
                        "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                        "BASIS_NQPT", CeedIntPow(Q1d, dim),
                        "INTERP_BLKSIZE", data->blksizes[0],
                        "GRAD_BLKSIZE", data->blksizes[1],
                        "WEIGHT_BLKSIZE", data->blksizes[2]
                       ); CeedChkBackend(ierr);
  // Look up the three entry points in the freshly compiled module
  ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
  CeedChkBackend(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Hip_shared);
  CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------