xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h (revision a71fcd9fac4e7a8dfa69a197fd7b41b8f31fd6a3)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 
10 //------------------------------------------------------------------------------
11 // Diagonal assembly kernels
12 //------------------------------------------------------------------------------
13 
// Basis evaluation modes. CEED_EVAL_NONE is zero; every other mode is a
// distinct power of two.
typedef enum {
  CEED_EVAL_NONE   = 0,       ///< No evaluation: there is no data, or it is already at quadrature points
  CEED_EVAL_INTERP = 1 << 0,  ///< Interpolate from nodes to quadrature points
  CEED_EVAL_GRAD   = 1 << 1,  ///< Gradients at quadrature points from input in a nodal basis
  CEED_EVAL_DIV    = 1 << 2,  ///< Divergence at quadrature points from input in a nodal basis
  CEED_EVAL_CURL   = 1 << 3,  ///< Curl at quadrature points from input in a nodal basis
  CEED_EVAL_WEIGHT = 1 << 4,  ///< No input: quadrature weights on the reference element
} CeedEvalMode;
29 
30 //------------------------------------------------------------------------------
31 // Get Basis Emode Pointer
32 //------------------------------------------------------------------------------
// Select the basis matrix used for evaluation mode `emode` and store it in
// *basisptr:
//   CEED_EVAL_NONE   -> identity
//   CEED_EVAL_INTERP -> interp
//   CEED_EVAL_GRAD   -> grad
// Any other mode (WEIGHT, DIV, CURL) leaves *basisptr unchanged; the upstream
// note says these are caught by QF assembly.
extern "C" __device__ void CeedOperatorGetBasisPointer_Cuda(const CeedScalar **basisptr,
    CeedEvalMode emode, const CeedScalar *identity, const CeedScalar *interp,
    const CeedScalar *grad) {
  if (emode == CEED_EVAL_NONE) {
    *basisptr = identity;
  } else if (emode == CEED_EVAL_INTERP) {
    *basisptr = interp;
  } else if (emode == CEED_EVAL_GRAD) {
    *basisptr = grad;
  }
  // CEED_EVAL_WEIGHT / CEED_EVAL_DIV / CEED_EVAL_CURL: nothing to select
  // (caught by QF Assembly).
}
52 
53 //------------------------------------------------------------------------------
54 // Core code for diagonal assembly
55 //------------------------------------------------------------------------------
// Sum over quadrature points of bt[q][node] * qf[q] * b[q][node] for a single
// node. QF entries with |value| <= bound are skipped. `qf` points at the NQPTS
// consecutive QF values for one (mode pair, component pair, element).
static __device__ CeedScalar diagonalCoreQuadSum(const CeedScalar *bt,
    const CeedScalar *b, const CeedScalar *__restrict__ qf, const int node,
    const CeedScalar bound) {
  CeedScalar sum = 0.;
  for (CeedInt q = 0; q < NQPTS; q++) {
    const CeedScalar qfvalue = qf[q];
    if (fabs(qfvalue) > bound)
      sum += bt[q*NNODES+node] * qfvalue * b[q*NNODES+node];
  }
  return sum;
}

// Core of (point block) diagonal assembly: for each element, accumulate the
// diagonal of B_out^T D B_in into elemdiagarray (with +=).
//
// Thread layout used below: threadIdx.x is the element node (evec node);
// threadIdx.z selects an element within the block; elements advance by
// gridDim.x*blockDim.z. The macros NUMEMODEIN, NUMEMODEOUT, NCOMP, NQPTS and
// NNODES are not defined in this chunk -- presumably supplied at JIT time.
//
//   nelem             number of elements
//   maxnorm           QF values with |value| <= maxnorm*1e-12 are skipped
//   pointBlock        true: NCOMP x NCOMP block per node; false: compIn == compOut only
//   identity, interpin/out, gradin/out
//                     basis matrices indexed [q*NNODES + node]; gradin/gradout
//                     hold one NQPTS*NNODES slab per CEED_EVAL_GRAD occurrence
//                     in emodein/emodeout
//   emodein/emodeout  NUMEMODEIN / NUMEMODEOUT evaluation modes
//   assembledqfarray  layout [ein][compIn][eout][compOut][e][q]
//   elemdiagarray     layout [compOut][compIn][e][node] when pointBlock,
//                     otherwise [comp][e][node]
__device__ void diagonalCore(const CeedInt nelem,
    const CeedScalar maxnorm, const bool pointBlock,
    const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const int tid = threadIdx.x; // running with P threads, tid is evec node
  // Surplus x-threads would index past the node dimension. No barriers follow,
  // so an early return is safe.
  if (tid >= NNODES) return;
  const CeedScalar qfvaluebound = maxnorm*1e-12;

  // Size of one per-component slab in each flattened array. These are 64-bit
  // because the QF array holds NUMEMODEIN*NCOMP*NUMEMODEOUT*NCOMP*nelem*NQPTS
  // entries, which can overflow a 32-bit CeedInt on large meshes.
  const long long qfSlab   = (long long)nelem * NQPTS;
  const long long diagSlab = (long long)nelem * NNODES;

  // Each element
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    const long long qfElem   = (long long)e * NQPTS;
    const long long diagElem = (long long)e * NNODES + tid;

    // Each basis eval mode pair. dout/din count GRAD occurrences so that the
    // matching NQPTS*NNODES slab of gradout/gradin is selected.
    CeedInt dout = -1;
    for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
      const CeedEvalMode modeOut = emodeout[eout];
      if (modeOut == CEED_EVAL_GRAD) dout += 1;
      // Offset gradout only when it is used. This avoids forming a pointer
      // before the array start while dout is still -1.
      const CeedScalar *gradOutSlab =
        (modeOut == CEED_EVAL_GRAD) ? gradout + dout*NQPTS*NNODES : gradout;
      const CeedScalar *bt = NULL;
      CeedOperatorGetBasisPointer_Cuda(&bt, modeOut, identity, interpout,
                                       gradOutSlab);

      CeedInt din = -1;
      for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
        const CeedEvalMode modeIn = emodein[ein];
        if (modeIn == CEED_EVAL_GRAD) din += 1;
        const CeedScalar *gradInSlab =
          (modeIn == CEED_EVAL_GRAD) ? gradin + din*NQPTS*NNODES : gradin;
        const CeedScalar *b = NULL;
        CeedOperatorGetBasisPointer_Cuda(&b, modeIn, identity, interpin,
                                         gradInSlab);

        // Each output component
        for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
          if (pointBlock) {
            // Point Block Diagonal: every (compOut, compIn) pair at this node
            for (CeedInt compIn = 0; compIn < NCOMP; compIn++) {
              const CeedInt qfComp =
                ((ein*NCOMP+compIn)*NUMEMODEOUT+eout)*NCOMP+compOut;
              const CeedScalar *qf = assembledqfarray + qfComp*qfSlab + qfElem;
              elemdiagarray[(compOut*NCOMP+compIn)*diagSlab + diagElem] +=
                diagonalCoreQuadSum(bt, b, qf, tid, qfvaluebound);
            }
          } else {
            // Diagonal Only: compIn == compOut
            const CeedInt qfComp =
              ((ein*NCOMP+compOut)*NUMEMODEOUT+eout)*NCOMP+compOut;
            const CeedScalar *qf = assembledqfarray + qfComp*qfSlab + qfElem;
            elemdiagarray[compOut*diagSlab + diagElem] +=
              diagonalCoreQuadSum(bt, b, qf, tid, qfvaluebound);
          }
        }
      }
    }
  }
}
119 
120 //------------------------------------------------------------------------------
121 // Linear diagonal
122 //------------------------------------------------------------------------------
// Kernel entry point for diagonal-only assembly. Thin wrapper that forwards
// every argument to diagonalCore with pointBlock = false, so only entries with
// matching input/output component are accumulated into elemdiagarray.
extern "C" __global__ void linearDiagonal(const CeedInt nelem,
    const CeedScalar maxnorm, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const bool assemblePointBlock = false; // diagonal entries only
  diagonalCore(nelem, maxnorm, assemblePointBlock, identity,
               interpin, gradin,
               interpout, gradout,
               emodein, emodeout,
               assembledqfarray, elemdiagarray);
}
133 
134 //------------------------------------------------------------------------------
135 // Linear point block diagonal
136 //------------------------------------------------------------------------------
// Kernel entry point for point-block diagonal assembly. Thin wrapper that
// forwards every argument to diagonalCore with pointBlock = true, so every
// (output component, input component) pair at each node is accumulated into
// elemdiagarray.
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem,
    const CeedScalar maxnorm, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const bool assemblePointBlock = true; // full NCOMP x NCOMP block per node
  diagonalCore(nelem, maxnorm, assemblePointBlock, identity,
               interpin, gradin,
               interpout, gradout,
               emodein, emodeout,
               assembledqfarray, elemdiagarray);
}
147 
148 //------------------------------------------------------------------------------
149