xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h (revision fb455ff073519dc60531e3d0b72267e590b5c938)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 
10 //------------------------------------------------------------------------------
11 // Diagonal assembly kernels
12 //------------------------------------------------------------------------------
13 
// Basis evaluation modes used by the device kernels in this file.
// Every mode except CEED_EVAL_NONE is a distinct single-bit flag.
// NOTE(review): the emodein/emodeout arrays passed to the kernels are read as
// this type, so these values presumably must match the host-side CeedEvalMode
// exactly — confirm when editing.
typedef enum {
  /// No evaluation: there is no data, or the data already lives at quadrature points
  CEED_EVAL_NONE   = 0,
  /// Map nodal values to values at quadrature points
  CEED_EVAL_INTERP = 1 << 0,
  /// Gradients at quadrature points, computed from nodal-basis input
  CEED_EVAL_GRAD   = 1 << 1,
  /// Divergence at quadrature points, computed from nodal-basis input
  CEED_EVAL_DIV    = 1 << 2,
  /// Curl at quadrature points, computed from nodal-basis input
  CEED_EVAL_CURL   = 1 << 3,
  /// Quadrature weights on the reference element (takes no input)
  CEED_EVAL_WEIGHT = 1 << 4,
} CeedEvalMode;
29 
30 //------------------------------------------------------------------------------
31 // Get Basis Emode Pointer
32 //------------------------------------------------------------------------------
// Select the basis matrix that corresponds to an evaluation mode.
//
//   basisptr - out: set to identity (NONE), interp (INTERP) or grad (GRAD).
//              For WEIGHT, DIV, CURL or any other value it is set to NULL,
//              so the out-parameter is always written and callers never
//              depend on having pre-initialized it.
//   emode    - evaluation mode to look up
//   identity, interp, grad - candidate basis matrices; only the one selected
//              by emode is used (grad is not dereferenced here)
extern "C" __device__ void CeedOperatorGetBasisPointer_Cuda(const CeedScalar **basisptr,
    CeedEvalMode emode, const CeedScalar *identity, const CeedScalar *interp,
    const CeedScalar *grad) {
  switch (emode) {
  case CEED_EVAL_NONE:
    *basisptr = identity;
    break;
  case CEED_EVAL_INTERP:
    *basisptr = interp;
    break;
  case CEED_EVAL_GRAD:
    *basisptr = grad;
    break;
  case CEED_EVAL_WEIGHT:
  case CEED_EVAL_DIV:
  case CEED_EVAL_CURL:
  default:
    // No basis matrix for these modes. Per the original note, they are
    // expected to be caught during QFunction assembly before reaching here;
    // still write a defined value rather than leaving *basisptr untouched.
    *basisptr = NULL;
    break;
  }
}
52 
53 //------------------------------------------------------------------------------
54 // Core code for diagonal assembly
55 //------------------------------------------------------------------------------
// Accumulate the diagonal (or point-block diagonal) of the element operator
// B_out^T D B_in into elemdiagarray, summed over every (emodein, emodeout) pair.
//
// Launch layout read by this function:
//   - threadIdx.x is the element node; threads with threadIdx.x >= NNODES exit.
//   - Elements start at blockIdx.x*blockDim.z + threadIdx.z and advance by
//     gridDim.x*blockDim.z.
//   NOTE(review): output writes are keyed only on (node, element), so
//   blockDim.y is presumably 1 — confirm against the launching code.
//
// NNODES, NQPTS, NCOMP, NUMEMODEIN and NUMEMODEOUT are compile-time macros that
// must be defined before this source is compiled.
//
// Parameters:
//   nelem       - number of elements
//   pointBlock  - false: only compIn == compOut couplings, written to
//                   elemdiagarray[(compOut*nelem + e)*NNODES + node]
//                 true: all NCOMP x NCOMP couplings, written to
//                   elemdiagarray[((compOut*NCOMP + compIn)*nelem + e)*NNODES + node]
//   identity, interpin, interpout - basis matrices indexed [q*NNODES + node]
//   gradin, gradout - one NQPTS*NNODES matrix per CEED_EVAL_GRAD entry, stored
//                 back to back in the order the GRAD entries appear in the mode lists
//   emodein, emodeout - NUMEMODEIN / NUMEMODEOUT evaluation modes
//   assembledqfarray - QFunction values indexed
//     [((((ein*NCOMP + compIn)*NUMEMODEOUT + eout)*NCOMP + compOut)*nelem + e)*NQPTS + q]
//   elemdiagarray - output, accumulated with += (not overwritten); presumably
//                 zero-initialized by the caller — confirm.
__device__ void diagonalCore(const CeedInt nelem,
    const bool pointBlock, const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const int tid = threadIdx.x; // one thread per element node
  if (tid >= NNODES) return;

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    // Index of the current gradient matrix; -1 until a GRAD mode is seen
    CeedInt dout = -1;
    for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
      if (emodeout[eout] == CEED_EVAL_GRAD)
        dout += 1;
      // Offset into gradout only when dout >= 0: &gradout[-NQPTS*NNODES] would
      // be out-of-bounds pointer arithmetic (undefined behavior) even unused.
      const CeedScalar *bt = NULL;
      CeedOperatorGetBasisPointer_Cuda(&bt, emodeout[eout], identity, interpout,
          dout >= 0 ? &gradout[dout*NQPTS*NNODES] : gradout);
      CeedInt din = -1;
      for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
        if (emodein[ein] == CEED_EVAL_GRAD)
          din += 1;
        const CeedScalar *b = NULL;
        CeedOperatorGetBasisPointer_Cuda(&b, emodein[ein], identity, interpin,
            din >= 0 ? &gradin[din*NQPTS*NNODES] : gradin);
        for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
          // Diagonal-only couples compOut with itself; point block with all compIn
          const CeedInt compInBegin = pointBlock ? 0 : compOut;
          const CeedInt compInEnd   = pointBlock ? NCOMP : compOut + 1;
          for (CeedInt compIn = compInBegin; compIn < compInEnd; compIn++) {
            // Flattened offsets are computed in 64-bit: the product
            // NUMEMODEIN*NCOMP*NUMEMODEOUT*NCOMP*nelem*NQPTS can exceed the
            // range of a (presumably 32-bit) CeedInt for large meshes.
            const long long qfOffset =
              ((((long long)(ein*NCOMP+compIn)*NUMEMODEOUT+eout)*NCOMP+compOut)*
               nelem+e)*NQPTS;
            const CeedScalar *qf = &assembledqfarray[qfOffset];
            CeedScalar evalue = 0.;
            for (CeedInt q = 0; q < NQPTS; q++)
              evalue += bt[q*NNODES+tid] * qf[q] * b[q*NNODES+tid];
            const long long outComp =
              pointBlock ? (long long)compOut*NCOMP+compIn : (long long)compOut;
            elemdiagarray[(outComp*nelem+e)*NNODES+tid] += evalue;
          }
        }
      }
    }
  }
}
116 
117 //------------------------------------------------------------------------------
118 // Linear diagonal
119 //------------------------------------------------------------------------------
// Kernel entry point: per-node diagonal of the assembled element operator
// (only compIn == compOut couplings). Forwards every argument unchanged to
// diagonalCore with the point-block flag cleared; the launch layout is the
// one diagonalCore's thread/block indexing expects.
extern "C" __global__ void linearDiagonal(const CeedInt nelem,
    const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const bool isPointBlock = false;
  diagonalCore(nelem, isPointBlock, identity,
               interpin, gradin,
               interpout, gradout,
               emodein, emodeout,
               assembledqfarray, elemdiagarray);
}
130 
131 //------------------------------------------------------------------------------
132 // Linear point block diagonal
133 //------------------------------------------------------------------------------
// Kernel entry point: point-block diagonal of the assembled element operator
// (all NCOMP x NCOMP component couplings at each node). Forwards every
// argument unchanged to diagonalCore with the point-block flag set; the
// launch layout is the one diagonalCore's thread/block indexing expects.
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem,
    const CeedScalar *identity,
    const CeedScalar *interpin, const CeedScalar *gradin,
    const CeedScalar *interpout, const CeedScalar *gradout,
    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
    const CeedScalar *__restrict__ assembledqfarray,
    CeedScalar *__restrict__ elemdiagarray) {
  const bool isPointBlock = true;
  diagonalCore(nelem, isPointBlock, identity,
               interpin, gradin,
               interpout, gradout,
               emodein, emodeout,
               assembledqfarray, elemdiagarray);
}
144 
145 //------------------------------------------------------------------------------
146