// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
#include <ceed/ceed.h>
#include <ceed/backend.h>
#include <hip/hip_runtime.h>
#include <stddef.h>
#include "ceed-hip-shared.h"
#include "../hip/ceed-hip.h"
#include "../hip/ceed-hip-compile.h"

//------------------------------------------------------------------------------
// Shared mem kernels
//
// The string below is HIP device source that is compiled at runtime.
// P1D, Q1D, T1D, BASIS_NCOMP, BASIS_DIM, and the *_BLKSIZE macros are not
// defined in this translation unit; presumably they are injected as compile
// definitions by the JIT path (see ceed-hip-compile.h) -- confirm there.
// Note: comments inside QUOTE() are removed by stringification, so they exist
// only for readers of this file, not in the JIT-compiled source.
//------------------------------------------------------------------------------
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

//------------------------------------------------------------------------------
// Sum input into output
//------------------------------------------------------------------------------
// Accumulate the first P1D entries of the register array r_U into r_V.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_V[i] += r_U[i];
}

//------------------------------------------------------------------------------
// Load matrices for basis actions
//------------------------------------------------------------------------------
// Copy the P1D*Q1D basis matrix d_B (global memory) into B, striding the copy
// over all threads of the block. Callers __syncthreads() before reading B.
inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
  CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
  for (CeedInt i = tid; i < P1D*Q1D; i += blockDim.x*blockDim.y*blockDim.z)
    B[i] = d_B[i];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load the P1D nodal values of (elem, comp) into this tidz's row of slice and
// zero-pad entries [P1D, Q1D). Every thread sharing a tidz writes the same
// values (tidx and tidy are unused here) -- redundant but consistent stores.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Thread tidx writes its single nodal value for (elem, comp) when tidx < P1D.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D)
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load the Q1D quadrature values of (elem, comp, dim) into this tidz's row of
// slice and zero-pad entries [Q1D, P1D) (relevant when P1D > Q1D).
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                              dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Thread tidx writes its quadrature value for (elem, comp, dim) when tidx < Q1D.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D)
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 1D tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[i + tidx*P1D] * slice[i + tidz*T1D].
// Note: the U argument is unused; the input comes from the shared slice.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D transpose tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[tidx + i*P1D] * slice[i + tidz*T1D] (B applied transposed).
// Note: the U argument is unused; the input comes from the shared slice.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply the 1D interp matrix (or its transpose) for all elements; elements
// are distributed over blockIdx.x and threadIdx.z. r_t is passed as the
// contractions' (unused) U argument, so it is deliberately never initialized.
// slice is dynamic shared memory, T1D entries per tidz layer.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
//------------------------------------------------------------------------------
// Apply the 1D gradient matrix s_G (or its transpose); same element
// distribution as interp1d. r_U is only used as the contractions' unused U
// argument. s_B is accepted for signature parity but not used in 1D.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D Quadrature weights
//------------------------------------------------------------------------------
// Each thread writes the weight of its quadrature point for a strided set of
// elements. Assumes the launch provides Q1D threads in x (tid indexes
// qweight1d directly) -- confirm against the host launch configuration.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load one nodal value (tidx, tidy) of (elem, comp) into the register U;
// threads outside the P1D x P1D node grid receive 0.0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
      d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Threads inside the P1D x P1D node grid write their nodal value.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D)
    d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load one quadrature value (tidx, tidy) of (elem, comp, dim) into U;
// threads outside the Q1D x Q1D grid receive 0.0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
      d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Threads inside the Q1D x Q1D grid write their quadrature value.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 2D tensor contraction x
//------------------------------------------------------------------------------
// Stage U into the block's shared slice (one T1D*T1D layer per tidz), then
// contract along x. Barriers bracket the shared-memory exchange; all threads
// reach them because the conditional only guards the accumulation.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D tensor contraction y
//------------------------------------------------------------------------------
// Same staging pattern as ContractX2d, contracting along y instead.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
286 // 2D transpose tensor contraction y 287 //------------------------------------------------------------------------------ 288 inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx, 289 const int tidy, const int tidz, 290 const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 291 slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 292 __syncthreads(); 293 V = 0.0; 294 if (tidy < P1D) 295 for (int i = 0; i < Q1D; ++i) 296 V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 297 __syncthreads(); 298 } 299 300 //------------------------------------------------------------------------------ 301 // 2D transpose tensor contraction x 302 //------------------------------------------------------------------------------ 303 inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx, 304 const int tidy, const int tidz, 305 const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 306 slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 307 __syncthreads(); 308 V = 0.0; 309 if (tidx < P1D) 310 for (int i = 0; i < Q1D; ++i) 311 V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 312 __syncthreads(); 313 } 314 315 //------------------------------------------------------------------------------ 316 // 2D interpolate to quadrature points 317 //------------------------------------------------------------------------------ 318 inline __device__ void interp2d(const CeedInt nelem, const int transpose, 319 const CeedScalar *s_B, 320 const CeedScalar *__restrict__ d_U, 321 CeedScalar *__restrict__ d_V, 322 CeedScalar *slice) { 323 CeedScalar r_V; 324 CeedScalar r_t; 325 326 const int tidx = threadIdx.x; 327 const int tidy = threadIdx.y; 328 const int tidz = threadIdx.z; 329 const int blockElem = tidz/BASIS_NCOMP; 330 const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 331 const int comp = tidz%BASIS_NCOMP; 332 333 for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 334 elem += 
gridDim.x*elemsPerBlock) { 335 const int comp = tidz%BASIS_NCOMP; 336 r_V = 0.0; 337 r_t = 0.0; 338 if (!transpose) { 339 readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V); 340 ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 341 ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 342 writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 343 } else { 344 readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 345 ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 346 ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 347 writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 348 } 349 } 350 } 351 352 //------------------------------------------------------------------------------ 353 // 2D derivatives at quadrature points 354 //------------------------------------------------------------------------------ 355 inline __device__ void grad2d(const CeedInt nelem, const int transpose, 356 const CeedScalar *s_B, const CeedScalar *s_G, 357 const CeedScalar *__restrict__ d_U, 358 CeedScalar *__restrict__ d_V, CeedScalar *slice) { 359 CeedScalar r_U; 360 CeedScalar r_V; 361 CeedScalar r_t; 362 363 const int tidx = threadIdx.x; 364 const int tidy = threadIdx.y; 365 const int tidz = threadIdx.z; 366 const int blockElem = tidz/BASIS_NCOMP; 367 const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 368 const int comp = tidz%BASIS_NCOMP; 369 int dim; 370 371 for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 372 elem += gridDim.x*elemsPerBlock) { 373 if (!transpose) { 374 readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U); 375 ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 376 ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 377 dim = 0; 378 writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 379 ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 380 ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 381 dim = 1; 382 writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 383 } else { 384 dim = 0; 385 
readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 386 ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 387 ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 388 dim = 1; 389 readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 390 ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 391 ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 392 r_V += r_U; 393 writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 394 } 395 } 396 } 397 398 //------------------------------------------------------------------------------ 399 // 2D quadrature weights 400 //------------------------------------------------------------------------------ 401 __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d, 402 CeedScalar *w) { 403 const int i = threadIdx.x; 404 const int j = threadIdx.y; 405 const CeedScalar weight = qweight1d[i]*qweight1d[j]; 406 for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 407 elem += gridDim.x*blockDim.z) { 408 const int ind = elem*Q1D*Q1D + i + j*Q1D; 409 w[ind] = weight; 410 } 411 } 412 413 //------------------------------------------------------------------------------ 414 // 3D 415 //------------------------------------------------------------------------------ 416 417 //------------------------------------------------------------------------------ 418 // Read DoFs 419 //------------------------------------------------------------------------------ 420 inline __device__ void readDofs3d(const int elem, const int tidx, 421 const int tidy, const int comp, 422 const int nelem, const CeedScalar *d_U, 423 CeedScalar *r_U) { 424 for (int i = 0; i < P1D; i++) 425 r_U[i] = (tidx < P1D && tidy < P1D) ? 
426 d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 427 comp*P1D*P1D*P1D*nelem] : 0.0; 428 for (int i = P1D; i < Q1D; i++) 429 r_U[i] = 0.0; 430 } 431 432 //------------------------------------------------------------------------------ 433 // Write DoFs 434 //------------------------------------------------------------------------------ 435 inline __device__ void writeDofs3d(const int elem, const int tidx, 436 const int tidy, const int comp, 437 const int nelem, const CeedScalar *r_V, 438 CeedScalar *d_V) { 439 if (tidx < P1D && tidy < P1D) { 440 for (int i = 0; i < P1D; i++) 441 d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 442 comp*P1D*P1D*P1D*nelem] = r_V[i]; 443 } 444 } 445 446 //------------------------------------------------------------------------------ 447 // Read quadrature point data 448 //------------------------------------------------------------------------------ 449 inline __device__ void readQuads3d(const int elem, const int tidx, 450 const int tidy, const int comp, 451 const int dim, const int nelem, 452 const CeedScalar *d_U, CeedScalar *r_U) { 453 for (int i = 0; i < Q1D; i++) 454 r_U[i] = (tidx < Q1D && tidy < Q1D) ? 
455 d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + 456 comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0; 457 for (int i = Q1D; i < P1D; i++) 458 r_U[i] = 0.0; 459 } 460 461 //------------------------------------------------------------------------------ 462 // Write quadrature point data 463 //------------------------------------------------------------------------------ 464 inline __device__ void writeQuads3d(const int elem, const int tidx, 465 const int tidy, const int comp, 466 const int dim, const int nelem, 467 const CeedScalar *r_V, CeedScalar *d_V) { 468 if (tidx < Q1D && tidy < Q1D) { 469 for (int i = 0; i < Q1D; i++) 470 d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem + 471 dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i]; 472 } 473 } 474 475 //------------------------------------------------------------------------------ 476 // 3D tensor contract x 477 //------------------------------------------------------------------------------ 478 inline __device__ void ContractX3d(CeedScalar *slice, const int tidx, 479 const int tidy, const int tidz, 480 const CeedScalar *U, 481 const CeedScalar *B, 482 CeedScalar *V) { 483 for (int k = 0; k < P1D; ++k) { 484 slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 485 __syncthreads(); 486 V[k] = 0.0; 487 if (tidx < Q1D && tidy < P1D) 488 for (int i = 0; i < P1D; ++i) 489 V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 490 __syncthreads(); 491 } 492 } 493 494 //------------------------------------------------------------------------------ 495 // 3D tensor contract y 496 //------------------------------------------------------------------------------ 497 inline __device__ void ContractY3d(CeedScalar *slice, const int tidx, 498 const int tidy, const int tidz, 499 const CeedScalar *U, 500 const CeedScalar *B, 501 CeedScalar *V) { 502 for (int k = 0; k < P1D; ++k) { 503 slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 504 __syncthreads(); 505 V[k] = 
0.0; 506 if (tidx < Q1D && tidy < Q1D) 507 for (int i = 0; i < P1D; ++i) 508 V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 509 __syncthreads(); 510 } 511 } 512 513 //------------------------------------------------------------------------------ 514 // 3D tensor contract z 515 //------------------------------------------------------------------------------ 516 inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx, 517 const int tidy, const int tidz, 518 const CeedScalar *U, 519 const CeedScalar *B, 520 CeedScalar *V) { 521 for (int k = 0; k < Q1D; ++k) { 522 V[k] = 0.0; 523 if (tidx < Q1D && tidy < Q1D) 524 for (int i = 0; i < P1D; ++i) 525 V[k] += B[i + k*P1D] * U[i]; // Contract z direction 526 } 527 for (int k = Q1D; k < P1D; ++k) 528 V[k] = 0.0; 529 } 530 531 //------------------------------------------------------------------------------ 532 // 3D transpose tensor contract z 533 //------------------------------------------------------------------------------ 534 inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx, 535 const int tidy, const int tidz, 536 const CeedScalar *U, 537 const CeedScalar *B, 538 CeedScalar *V) { 539 for (int k = 0; k < P1D; ++k) { 540 V[k] = 0.0; 541 if (tidx < Q1D && tidy < Q1D) 542 for (int i = 0; i < Q1D; ++i) 543 V[k] += B[k + i*P1D] * U[i]; // Contract z direction 544 } 545 for (int k = P1D; k < Q1D; ++k) 546 V[k] = 0.0; 547 } 548 549 //------------------------------------------------------------------------------ 550 // 3D transpose tensor contract y 551 //------------------------------------------------------------------------------ 552 inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx, 553 const int tidy, const int tidz, 554 const CeedScalar *U, 555 const CeedScalar *B, 556 CeedScalar *V) { 557 for (int k = 0; k < P1D; ++k) { 558 slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 559 __syncthreads(); 560 V[k] = 0.0; 561 if 
(tidx < Q1D && tidy < P1D) 562 for (int i = 0; i < Q1D; ++i) 563 V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 564 __syncthreads(); 565 } 566 } 567 568 //------------------------------------------------------------------------------ 569 // 3D transpose tensor contract x 570 //------------------------------------------------------------------------------ 571 inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx, 572 const int tidy, const int tidz, 573 const CeedScalar *U, 574 const CeedScalar *B, 575 CeedScalar *V) { 576 for (int k = 0; k < P1D; ++k) { 577 slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 578 __syncthreads(); 579 V[k] = 0.0; 580 if (tidx < P1D && tidy < P1D) 581 for (int i = 0; i < Q1D; ++i) 582 V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 583 __syncthreads(); 584 } 585 } 586 587 //------------------------------------------------------------------------------ 588 // 3D interpolate to quadrature points 589 //------------------------------------------------------------------------------ 590 inline __device__ void interp3d(const CeedInt nelem, const int transpose, 591 const CeedScalar *s_B, 592 const CeedScalar *__restrict__ d_U, 593 CeedScalar *__restrict__ d_V, 594 CeedScalar *slice) { 595 CeedScalar r_V[T1D]; 596 CeedScalar r_t[T1D]; 597 598 const int tidx = threadIdx.x; 599 const int tidy = threadIdx.y; 600 const int tidz = threadIdx.z; 601 const int blockElem = tidz/BASIS_NCOMP; 602 const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 603 const int comp = tidz%BASIS_NCOMP; 604 605 for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 606 elem += gridDim.x*elemsPerBlock) { 607 for (int i = 0; i < T1D; ++i) { 608 r_V[i] = 0.0; 609 r_t[i] = 0.0; 610 } 611 if (!transpose) { 612 readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V); 613 ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 614 ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, 
r_V); 615 ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 616 writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V); 617 } else { 618 readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 619 ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 620 ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 621 ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 622 writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V); 623 } 624 } 625 } 626 627 //------------------------------------------------------------------------------ 628 // 3D derivatives at quadrature points 629 //------------------------------------------------------------------------------ 630 inline __device__ void grad3d(const CeedInt nelem, const int transpose, 631 const CeedScalar *s_B, const CeedScalar *s_G, 632 const CeedScalar *__restrict__ d_U, 633 CeedScalar *__restrict__ d_V, 634 CeedScalar *slice) { 635 // Use P1D for one of these 636 CeedScalar r_U[T1D]; 637 CeedScalar r_V[T1D]; 638 CeedScalar r_t[T1D]; 639 640 const int tidx = threadIdx.x; 641 const int tidy = threadIdx.y; 642 const int tidz = threadIdx.z; 643 const int blockElem = tidz/BASIS_NCOMP; 644 const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 645 const int comp = tidz%BASIS_NCOMP; 646 int dim; 647 648 for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 649 elem += gridDim.x*elemsPerBlock) { 650 for (int i = 0; i < T1D; ++i) { 651 r_U[i] = 0.0; 652 r_V[i] = 0.0; 653 r_t[i] = 0.0; 654 } 655 if (!transpose) { 656 readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U); 657 ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 658 ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 659 ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 660 dim = 0; 661 writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 662 ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V); 663 ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t); 664 ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 665 dim 
= 1; 666 writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 667 ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V); 668 ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 669 ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 670 dim = 2; 671 writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 672 } else { 673 dim = 0; 674 readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 675 ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 676 ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 677 ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 678 dim = 1; 679 readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 680 ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 681 ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U); 682 ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 683 add(r_V, r_t); 684 dim = 2; 685 readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 686 ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 687 ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 688 ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 689 add(r_V, r_t); 690 writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V); 691 } 692 } 693 } 694 695 //------------------------------------------------------------------------------ 696 // 3D quadrature weights 697 //------------------------------------------------------------------------------ 698 __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d, 699 CeedScalar *w) { 700 const int i = threadIdx.x; 701 const int j = threadIdx.y; 702 const int k = threadIdx.z; 703 const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k]; 704 for (int e = blockIdx.x; e < nelem; e += gridDim.x) { 705 const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D; 706 w[ind] = weight; 707 } 708 } 709 710 //------------------------------------------------------------------------------ 711 // Basis kernels 712 
//------------------------------------------------------------------------------ 713 714 //------------------------------------------------------------------------------ 715 // Interp kernel by dim 716 //------------------------------------------------------------------------------ 717 extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp( 718 const CeedInt nelem, const int transpose, 719 CeedScalar *d_interp1d, 720 const CeedScalar *__restrict__ d_U, 721 CeedScalar *__restrict__ d_V) { 722 723 HIP_DYNAMIC_SHARED( CeedScalar, slice) 724 // load interp1d into shared memory 725 __shared__ CeedScalar s_B[P1D*Q1D]; 726 loadMatrix(d_interp1d, s_B); 727 __syncthreads(); 728 729 if (BASIS_DIM == 1) { 730 interp1d(nelem, transpose, s_B, d_U, d_V, slice); 731 } else if (BASIS_DIM == 2) { 732 interp2d(nelem, transpose, s_B, d_U, d_V, slice); 733 } else if (BASIS_DIM == 3) { 734 interp3d(nelem, transpose, s_B, d_U, d_V, slice); 735 } 736 } 737 738 //------------------------------------------------------------------------------ 739 // Grad kernel by dim 740 //------------------------------------------------------------------------------ 741 extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem, 742 const int transpose, 743 CeedScalar *d_interp1d, CeedScalar *d_grad1d, 744 const CeedScalar *__restrict__ d_U, 745 CeedScalar *__restrict__ d_V) { 746 HIP_DYNAMIC_SHARED( CeedScalar, slice) 747 // load interp1d and grad1d into shared memory 748 __shared__ CeedScalar s_B[P1D*Q1D]; 749 loadMatrix(d_interp1d, s_B); 750 __shared__ CeedScalar s_G[P1D*Q1D]; 751 loadMatrix(d_grad1d, s_G); 752 __syncthreads(); 753 754 if (BASIS_DIM == 1) { 755 grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 756 } else if (BASIS_DIM == 2) { 757 grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 758 } else if (BASIS_DIM == 3) { 759 grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 760 } 761 } 762 763 
//------------------------------------------------------------------------------ 764 // Weight kernels by dim 765 //------------------------------------------------------------------------------ 766 extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem, 767 const CeedScalar *__restrict__ qweight1d, 768 CeedScalar *__restrict__ v) { 769 if (BASIS_DIM == 1) { 770 weight1d(nelem, qweight1d, v); 771 } else if (BASIS_DIM == 2) { 772 weight2d(nelem, qweight1d, v); 773 } else if (BASIS_DIM == 3) { 774 weight3d(nelem, qweight1d, v); 775 } 776 } 777 778 ); 779 // *INDENT-ON* 780 781 //------------------------------------------------------------------------------ 782 // Compute a block size based on required minimum threads 783 //------------------------------------------------------------------------------ 784 static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) { 785 CeedInt maxSize = 1024; // Max total threads per block 786 CeedInt currentSize = 64; // Start with one group 787 788 while(currentSize < maxSize) { 789 if (currentSize > required) 790 break; 791 else 792 currentSize = currentSize * 2; 793 } 794 return currentSize; 795 } 796 797 //------------------------------------------------------------------------------ 798 // Compute required thread block sizes for basis kernels given P, Q, dim, and 799 // ncomp 800 //------------------------------------------------------------------------------ 801 static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d, 802 const CeedInt Q1d, 803 const CeedInt ncomp, CeedInt *blksizes) { 804 805 // Note that this will use the same block sizes for all dimensions when compiling, 806 // but as each basis object is defined for a particular dimension, we will never 807 // call any kernels except the ones for the dimension for which we have computed the 808 // block sizes. 
809 const CeedInt thread1d = CeedIntMax(P1d, Q1d); 810 switch (dim) { 811 case 1: { 812 // Interp kernels: 813 blksizes[0] = 256; 814 815 // Grad kernels: 816 blksizes[1] = 256; 817 818 // Weight kernels: 819 blksizes[2] = 256; 820 821 } break; 822 case 2: { 823 // Interp kernels: 824 CeedInt required = thread1d * thread1d * ncomp; 825 blksizes[0] = ComputeBlockSizeFromRequirement(required); 826 827 // Grad kernels: currently use same required minimum threads 828 blksizes[1] = ComputeBlockSizeFromRequirement(required); 829 830 // Weight kernels: 831 required = CeedIntMax(64, Q1d * Q1d); 832 blksizes[2] = ComputeBlockSizeFromRequirement(required); 833 834 } break; 835 case 3: { 836 // Interp kernels: 837 CeedInt required = thread1d * thread1d * ncomp; 838 blksizes[0] = ComputeBlockSizeFromRequirement(required); 839 840 // Grad kernels: currently use same required minimum threads 841 blksizes[1] = ComputeBlockSizeFromRequirement(required); 842 843 // Weight kernels: 844 required = Q1d * Q1d * Q1d; 845 blksizes[2] = ComputeBlockSizeFromRequirement(required); 846 } 847 } 848 849 return CEED_ERROR_SUCCESS; 850 } 851 852 //------------------------------------------------------------------------------ 853 // Apply basis 854 //------------------------------------------------------------------------------ 855 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 856 CeedTransposeMode tmode, 857 CeedEvalMode emode, CeedVector u, 858 CeedVector v) { 859 int ierr; 860 Ceed ceed; 861 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 862 Ceed_Hip *ceed_Hip; 863 CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr); 864 CeedBasis_Hip_shared *data; 865 CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 866 const CeedInt transpose = tmode == CEED_TRANSPOSE; 867 CeedInt dim, ncomp; 868 ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 869 ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 870 871 // Read vectors 872 
const CeedScalar *d_u; 873 CeedScalar *d_v; 874 if (emode != CEED_EVAL_WEIGHT) { 875 ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr); 876 } 877 ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 878 879 // Clear v for transpose mode 880 if (tmode == CEED_TRANSPOSE) { 881 CeedInt length; 882 ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr); 883 ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr); 884 } 885 886 // Apply basis operation 887 switch (emode) { 888 case CEED_EVAL_INTERP: { 889 CeedInt P1d, Q1d; 890 CeedInt blksize = data->blksizes[0]; 891 ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 892 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 893 CeedInt thread1d = CeedIntMax(Q1d, P1d); 894 void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 895 &d_u, &d_v 896 }; 897 if (dim == 1) { 898 CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 899 elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 900 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 901 ? 1 : 0 ); 902 CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 903 ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 904 elemsPerBlock, sharedMem, 905 interpargs); CeedChkBackend(ierr); 906 } else if (dim == 2) { 907 // Check if required threads is small enough to do multiple elems 908 const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 909 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 910 ? 
1 : 0 ); 911 CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 912 ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 913 ncomp*elemsPerBlock, sharedMem, 914 interpargs); CeedChkBackend(ierr); 915 } else if (dim == 3) { 916 CeedInt elemsPerBlock = 1; 917 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 918 ? 1 : 0 ); 919 CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 920 ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 921 ncomp*elemsPerBlock, sharedMem, 922 interpargs); CeedChkBackend(ierr); 923 } 924 } break; 925 case CEED_EVAL_GRAD: { 926 CeedInt P1d, Q1d; 927 CeedInt blksize = data->blksizes[1]; 928 ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 929 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 930 CeedInt thread1d = CeedIntMax(Q1d, P1d); 931 void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 932 &data->d_grad1d, &d_u, &d_v 933 }; 934 if (dim == 1) { 935 CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 936 elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 937 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 938 ? 1 : 0 ); 939 CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 940 ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 941 elemsPerBlock, sharedMem, gradargs); 942 CeedChkBackend(ierr); 943 } else if (dim == 2) { 944 // Check if required threads is small enough to do multiple elems 945 const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 946 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 947 ? 
1 : 0 ); 948 CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 949 ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 950 ncomp*elemsPerBlock, sharedMem, 951 gradargs); CeedChkBackend(ierr); 952 } else if (dim == 3) { 953 CeedInt elemsPerBlock = 1; 954 CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 955 ? 1 : 0 ); 956 CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 957 ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 958 ncomp*elemsPerBlock, sharedMem, 959 gradargs); CeedChkBackend(ierr); 960 } 961 } break; 962 case CEED_EVAL_WEIGHT: { 963 CeedInt Q1d; 964 CeedInt blksize = data->blksizes[2]; 965 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 966 void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 967 if (dim == 1) { 968 const CeedInt optElems = blksize/Q1d; 969 const CeedInt elemsPerBlock = optElems>0?optElems:1; 970 const CeedInt gridsize = nelem/elemsPerBlock + ( ( 971 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 972 ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 973 elemsPerBlock, 1, weightargs); 974 CeedChkBackend(ierr); 975 } else if (dim == 2) { 976 const CeedInt optElems = blksize/(Q1d*Q1d); 977 const CeedInt elemsPerBlock = optElems>0?optElems:1; 978 const CeedInt gridsize = nelem/elemsPerBlock + ( ( 979 nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 980 ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 981 elemsPerBlock, weightargs); 982 CeedChkBackend(ierr); 983 } else if (dim == 3) { 984 const CeedInt gridsize = nelem; 985 ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 986 weightargs); 987 CeedChkBackend(ierr); 988 } 989 } break; 990 // LCOV_EXCL_START 991 // Evaluate the divergence to/from the quadrature points 992 case CEED_EVAL_DIV: 993 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 994 // Evaluate the curl to/from the quadrature points 995 case CEED_EVAL_CURL: 996 return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 997 // Take no action, BasisApply should not have been called 998 case CEED_EVAL_NONE: 999 return CeedError(ceed, CEED_ERROR_BACKEND, 1000 "CEED_EVAL_NONE does not make sense in this context"); 1001 // LCOV_EXCL_STOP 1002 } 1003 1004 // Restore vectors 1005 if (emode != CEED_EVAL_WEIGHT) { 1006 ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr); 1007 } 1008 ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr); 1009 return CEED_ERROR_SUCCESS; 1010 } 1011 1012 //------------------------------------------------------------------------------ 1013 // Destroy basis 1014 //------------------------------------------------------------------------------ 1015 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 1016 int ierr; 1017 Ceed ceed; 1018 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 1019 1020 CeedBasis_Hip_shared *data; 1021 ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 1022 1023 CeedChk_Hip(ceed, hipModuleUnload(data->module)); 1024 1025 ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 1026 ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 1027 ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 1028 ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 1029 1030 ierr = CeedFree(&data); 
CeedChkBackend(ierr); 1031 1032 return CEED_ERROR_SUCCESS; 1033 } 1034 1035 //------------------------------------------------------------------------------ 1036 // Create tensor basis 1037 //------------------------------------------------------------------------------ 1038 int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 1039 const CeedScalar *interp1d, 1040 const CeedScalar *grad1d, 1041 const CeedScalar *qref1d, 1042 const CeedScalar *qweight1d, 1043 CeedBasis basis) { 1044 int ierr; 1045 Ceed ceed; 1046 ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 1047 CeedBasis_Hip_shared *data; 1048 ierr = CeedCalloc(1, &data); CeedChkBackend(ierr); 1049 1050 // Copy basis data to GPU 1051 const CeedInt qBytes = Q1d * sizeof(CeedScalar); 1052 ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 1053 ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 1054 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 1055 1056 const CeedInt iBytes = qBytes * P1d; 1057 ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 1058 ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 1059 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 1060 1061 ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 1062 ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 1063 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 1064 1065 // Compute collocated gradient and copy to GPU 1066 data->d_collograd1d = NULL; 1067 if (dim == 3 && Q1d >= P1d) { 1068 CeedScalar *collograd1d; 1069 ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr); 1070 ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr); 1071 ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 1072 CeedChk_Hip(ceed, ierr); 1073 ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 1074 hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 1075 ierr = CeedFree(&collograd1d); 
CeedChkBackend(ierr); 1076 } 1077 1078 // Set number of threads per block for basis kernels 1079 CeedInt ncomp; 1080 ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 1081 ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes); 1082 CeedChkBackend(ierr); 1083 1084 // Compile basis kernels 1085 ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11, 1086 "Q1D", Q1d, 1087 "P1D", P1d, 1088 "T1D", CeedIntMax(Q1d, P1d), 1089 "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 1090 Q1d : P1d, dim), 1091 "BASIS_DIM", dim, 1092 "BASIS_NCOMP", ncomp, 1093 "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 1094 "BASIS_NQPT", CeedIntPow(Q1d, dim), 1095 "INTERP_BLKSIZE", data->blksizes[0], 1096 "GRAD_BLKSIZE", data->blksizes[1], 1097 "WEIGHT_BLKSIZE", data->blksizes[2] 1098 ); CeedChkBackend(ierr); 1099 ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 1100 CeedChkBackend(ierr); 1101 ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 1102 CeedChkBackend(ierr); 1103 ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 1104 CeedChkBackend(ierr); 1105 1106 ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr); 1107 1108 // Register backend functions 1109 ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 1110 CeedBasisApplyTensor_Hip_shared); 1111 CeedChkBackend(ierr); 1112 ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 1113 CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr); 1114 return CEED_ERROR_SUCCESS; 1115 } 1116 //------------------------------------------------------------------------------ 1117