xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision 9bd0a4de615e3a1434ac7c89598f4ee8661d99f4)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 
10 //------------------------------------------------------------------------------
11 // Get Basis Emode Pointer
12 //------------------------------------------------------------------------------
13 extern "C" __device__ void CeedOperatorGetBasisPointer_Hip(const CeedScalar **basisptr, CeedEvalMode emode, const CeedScalar *identity,
14                                                            const CeedScalar *interp, const CeedScalar *grad) {
15   switch (emode) {
16     case CEED_EVAL_NONE:
17       *basisptr = identity;
18       break;
19     case CEED_EVAL_INTERP:
20       *basisptr = interp;
21       break;
22     case CEED_EVAL_GRAD:
23       *basisptr = grad;
24       break;
25     case CEED_EVAL_WEIGHT:
26     case CEED_EVAL_DIV:
27     case CEED_EVAL_CURL:
28       break;  // Caught by QF Assembly
29   }
30 }
31 
32 //------------------------------------------------------------------------------
33 // Core code for diagonal assembly
34 //------------------------------------------------------------------------------
35 __device__ void diagonalCore(const CeedInt nelem, const bool pointBlock, const CeedScalar *identity, const CeedScalar *interpin,
36                              const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
37                              const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
38   const int tid = threadIdx.x;  // running with P threads, tid is evec node
39   if (tid >= NNODES) return;
40 
41   // Compute the diagonal of B^T D B
42   // Each element
43   for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < nelem; e += gridDim.x * blockDim.z) {
44     CeedInt dout = -1;
45     // Each basis eval mode pair
46     for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
47       const CeedScalar *bt = NULL;
48       if (emodeout[eout] == CEED_EVAL_GRAD) dout += 1;
49       CeedOperatorGetBasisPointer_Hip(&bt, emodeout[eout], identity, interpout, &gradout[dout * NQPTS * NNODES]);
50       CeedInt din = -1;
51       for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
52         const CeedScalar *b = NULL;
53         if (emodein[ein] == CEED_EVAL_GRAD) din += 1;
54         CeedOperatorGetBasisPointer_Hip(&b, emodein[ein], identity, interpin, &gradin[din * NQPTS * NNODES]);
55         // Each component
56         for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
57           // Each qpoint/node pair
58           if (pointBlock) {
59             // Point Block Diagonal
60             for (CeedInt compIn = 0; compIn < NCOMP; compIn++) {
61               CeedScalar evalue = 0.;
62               for (CeedInt q = 0; q < NQPTS; q++) {
63                 const CeedScalar qfvalue =
64                     assembledqfarray[((((ein * NCOMP + compIn) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q];
65                 evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
66               }
67               elemdiagarray[((compOut * NCOMP + compIn) * nelem + e) * NNODES + tid] += evalue;
68             }
69           } else {
70             // Diagonal Only
71             CeedScalar evalue = 0.;
72             for (CeedInt q = 0; q < NQPTS; q++) {
73               const CeedScalar qfvalue =
74                   assembledqfarray[((((ein * NCOMP + compOut) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q];
75               evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
76             }
77             elemdiagarray[(compOut * nelem + e) * NNODES + tid] += evalue;
78           }
79         }
80       }
81     }
82   }
83 }
84 
85 //------------------------------------------------------------------------------
86 // Linear diagonal
87 //------------------------------------------------------------------------------
88 extern "C" __global__ void linearDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin, const CeedScalar *gradin,
89                                           const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
90                                           const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray,
91                                           CeedScalar *__restrict__ elemdiagarray) {
92   diagonalCore(nelem, false, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
93 }
94 
95 //------------------------------------------------------------------------------
96 // Linear point block diagonal
97 //------------------------------------------------------------------------------
98 extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin,
99                                                     const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout,
100                                                     const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
101                                                     const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
102   diagonalCore(nelem, true, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
103 }
104 
105 //------------------------------------------------------------------------------
106