xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision 2288fb5222bbca88523f94a06377dc76b8b46264)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 
10 //------------------------------------------------------------------------------
11 // Diagonal assembly kernels
12 //------------------------------------------------------------------------------
typedef enum {
  /// No basis evaluation is performed: either there is no data, or the data
  /// already lives at the quadrature points
  CEED_EVAL_NONE   = 0,
  /// Nodal values are interpolated to the quadrature points
  CEED_EVAL_INTERP = 1 << 0,
  /// Gradients at the quadrature points, computed from nodal-basis input
  CEED_EVAL_GRAD   = 1 << 1,
  /// Divergence at the quadrature points, computed from nodal-basis input
  CEED_EVAL_DIV    = 1 << 2,
  /// Curl at the quadrature points, computed from nodal-basis input
  CEED_EVAL_CURL   = 1 << 3,
  /// Quadrature weights on the reference element; takes no input
  CEED_EVAL_WEIGHT = 1 << 4,
} CeedEvalMode;
28 
29 //------------------------------------------------------------------------------
30 // Get Basis Emode Pointer
31 //------------------------------------------------------------------------------
// Select the basis matrix that corresponds to an evaluation mode.
//
//   basisptr  [out] receives the selected matrix pointer; it is written only
//                   for the NONE, INTERP and GRAD modes
//   emode     [in]  evaluation mode to look up
//   identity  [in]  matrix returned for CEED_EVAL_NONE
//   interp    [in]  matrix returned for CEED_EVAL_INTERP
//   grad      [in]  matrix returned for CEED_EVAL_GRAD
extern "C" __device__ void CeedOperatorGetBasisPointer_Hip(const CeedScalar **basisptr,
    CeedEvalMode emode, const CeedScalar *identity, const CeedScalar *interp,
    const CeedScalar *grad) {
  if (emode == CEED_EVAL_NONE) {
    *basisptr = identity;
  } else if (emode == CEED_EVAL_INTERP) {
    *basisptr = interp;
  } else if (emode == CEED_EVAL_GRAD) {
    *basisptr = grad;
  }
  // CEED_EVAL_WEIGHT, CEED_EVAL_DIV and CEED_EVAL_CURL have no matrix here, so
  // *basisptr keeps whatever value the caller stored in it.  The original note
  // for these modes reads "Caught by QF Assembly".
}
51 
52 //------------------------------------------------------------------------------
53 // Core code for diagonal assembly
54 //------------------------------------------------------------------------------
// Helper for diagonalCore: weighted quadrature sum for a single node.
//
// Returns sum_q bt[q*NNODES+node] * qfvalues[q] * b[q*NNODES+node], leaving out
// every quadrature point whose QFunction value has magnitude <= qfvaluebound.
__device__ inline CeedScalar diagonalCoreQuadratureSum(
    const CeedScalar *bt, const CeedScalar *b,
    const CeedScalar *__restrict__ qfvalues,
    const CeedScalar qfvaluebound, const int node) {
  CeedScalar evalue = 0.;
  for (CeedInt q = 0; q < NQPTS; q++) {
    const CeedScalar qfvalue = qfvalues[q];
    // fabs (not abs) so the floating-point overload is always the one chosen
    if (fabs(qfvalue) > qfvaluebound)
      evalue += bt[q*NNODES+node] * qfvalue * b[q*NNODES+node];
  }
  return evalue;
}

// Accumulate the element-wise diagonal (or point-block diagonal) of B^T D B.
//
// Thread mapping: threadIdx.x selects the node; elements are visited starting
// at blockIdx.x*blockDim.z + threadIdx.z with a step of gridDim.x*blockDim.z.
//
//   nelem             number of elements
//   maxnorm           reference magnitude; QFunction values with magnitude
//                     <= maxnorm*1e-12 are skipped
//   pointBlock        true: accumulate every (compOut, compIn) pair;
//                     false: accumulate only compIn == compOut
//   identity          matrix used for CEED_EVAL_NONE
//   interpin/out      matrices used for CEED_EVAL_INTERP (input / output side)
//   gradin/out        matrices used for CEED_EVAL_GRAD; the k-th GRAD entry of
//                     the mode list uses the slice at offset k*NQPTS*NNODES
//   emodein/out       evaluation mode lists, NUMEMODEIN / NUMEMODEOUT entries
//   assembledqfarray  QFunction values, indexed
//                     [ein][compIn][eout][compOut][e][q]
//   elemdiagarray     output, accumulated with +=; indexed
//                     [compOut][compIn][e][node] when pointBlock is true and
//                     [compOut][e][node] otherwise
//
// All basis matrices are read as matrix[q*NNODES + node].
__device__ void diagonalCore(const CeedInt nelem,
    const CeedScalar maxnorm, const bool pointBlock,
    const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const int tid = threadIdx.x; // running with P threads, tid is evec node
  // Threads beyond the node count have nothing to do and would otherwise index
  // past the basis matrices and the output; no barrier follows, so returning
  // early is safe.
  if (tid >= NNODES) return;
  const CeedScalar qfvaluebound = maxnorm*1e-12;

  // Compute the diagonal of B^T D B
  // Each element
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    CeedInt dout = -1;
    // Each basis eval mode pair
    for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
      const CeedScalar *bt = NULL;
      const CeedScalar *gradoutslice = NULL;
      // Only form the gradient slice pointer for GRAD modes, so that no
      // pointer is ever computed from a negative slice counter.
      if (emodeout[eout] == CEED_EVAL_GRAD) {
        dout += 1;
        gradoutslice = &gradout[dout*NQPTS*NNODES];
      }
      CeedOperatorGetBasisPointer_Hip(&bt, emodeout[eout], identity, interpout,
                                      gradoutslice);
      CeedInt din = -1;
      for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
        const CeedScalar *b = NULL;
        const CeedScalar *gradinslice = NULL;
        if (emodein[ein] == CEED_EVAL_GRAD) {
          din += 1;
          gradinslice = &gradin[din*NQPTS*NNODES];
        }
        CeedOperatorGetBasisPointer_Hip(&b, emodein[ein], identity, interpin,
                                        gradinslice);
        // Each component
        for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
          // Point block: all input components; diagonal only: compIn == compOut
          const CeedInt compInBegin = pointBlock ? 0 : compOut;
          const CeedInt compInEnd   = pointBlock ? NCOMP : compOut + 1;
          for (CeedInt compIn = compInBegin; compIn < compInEnd; compIn++) {
            // Flat offsets are computed in 64-bit so that large element counts
            // cannot overflow a 32-bit CeedInt.
            const long long qfoffset =
              (((((long long)ein*NCOMP+compIn)*NUMEMODEOUT+eout)*
                NCOMP+compOut)*nelem+e)*NQPTS;
            const long long diagblock =
              pointBlock ? (long long)compOut*NCOMP+compIn : (long long)compOut;
            // Each qpoint/node pair
            elemdiagarray[(diagblock*nelem+e)*NNODES+tid] +=
              diagonalCoreQuadratureSum(bt, b, &assembledqfarray[qfoffset],
                                        qfvaluebound, tid);
          }
        }
      }
    }
  }
}
118 
119 //------------------------------------------------------------------------------
120 // Linear diagonal
121 //------------------------------------------------------------------------------
// Kernel entry point for the plain (non point-block) diagonal: forwards all
// arguments to diagonalCore with the point-block flag switched off.
// diagonalCore indexes nodes with threadIdx.x and elements with
// blockIdx.x*blockDim.z + threadIdx.z.
extern "C" __global__ void linearDiagonal(const CeedInt nelem,
    const CeedScalar maxnorm, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const bool usePointBlock = false;
  diagonalCore(nelem, maxnorm, usePointBlock,
               identity,
               interpin, gradin,
               interpout, gradout,
               emodein, emodeout,
               assembledqfarray, elemdiagarray);
}
132 
133 //------------------------------------------------------------------------------
134 // Linear point block diagonal
135 //------------------------------------------------------------------------------
// Kernel entry point for the point-block diagonal: forwards all arguments to
// diagonalCore with the point-block flag switched on, so every
// (output component, input component) pair is accumulated.
// diagonalCore indexes nodes with threadIdx.x and elements with
// blockIdx.x*blockDim.z + threadIdx.z.
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem,
    const CeedScalar maxnorm, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const bool usePointBlock = true;
  diagonalCore(nelem, maxnorm, usePointBlock,
               identity,
               interpin, gradin,
               interpout, gradout,
               emodein, emodeout,
               assembledqfarray, elemdiagarray);
}
146 
147 //------------------------------------------------------------------------------
148