xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h (revision caee03026e6576cbf7a399c2fc51bb918c77f451)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 
10 //------------------------------------------------------------------------------
11 // Diagonal assembly kernels
12 //------------------------------------------------------------------------------
13 
// Basis evaluation modes. Apart from CEED_EVAL_NONE, every enumerator is a
// distinct single bit.
typedef enum {
  /// No basis action: either there is no data, or the data already lives at
  /// the quadrature points
  CEED_EVAL_NONE = 0,
  /// Nodal values interpolated to the quadrature points
  CEED_EVAL_INTERP = 1 << 0,
  /// Gradients at the quadrature points, computed from nodal input
  CEED_EVAL_GRAD = 1 << 1,
  /// Divergence at the quadrature points, computed from nodal input
  CEED_EVAL_DIV = 1 << 2,
  /// Curl at the quadrature points, computed from nodal input
  CEED_EVAL_CURL = 1 << 3,
  /// Quadrature weights on the reference element; takes no input
  CEED_EVAL_WEIGHT = 1 << 4,
} CeedEvalMode;
29 
30 //------------------------------------------------------------------------------
31 // Get Basis Emode Pointer
32 //------------------------------------------------------------------------------
// Pick the basis matrix that corresponds to an evaluation mode.
//
//   basisptr  [out] receives `identity`, `interp` or `grad` for
//                   CEED_EVAL_NONE, CEED_EVAL_INTERP or CEED_EVAL_GRAD respectively
//   emode     [in]  evaluation mode to look up
//   identity  [in]  pointer stored for CEED_EVAL_NONE
//   interp    [in]  pointer stored for CEED_EVAL_INTERP
//   grad      [in]  pointer stored for CEED_EVAL_GRAD
//
// For every other mode `*basisptr` is not written, so the caller keeps whatever
// value it had before the call.
extern "C" __device__ void CeedOperatorGetBasisPointer_Cuda(const CeedScalar **basisptr, CeedEvalMode emode, const CeedScalar *identity,
                                                            const CeedScalar *interp, const CeedScalar *grad) {
  if (emode == CEED_EVAL_NONE) {
    *basisptr = identity;
  } else if (emode == CEED_EVAL_INTERP) {
    *basisptr = interp;
  } else if (emode == CEED_EVAL_GRAD) {
    *basisptr = grad;
  }
  // CEED_EVAL_WEIGHT, CEED_EVAL_DIV, CEED_EVAL_CURL: nothing to select here
  // (original note: these are caught by QF Assembly)
}
51 
52 //------------------------------------------------------------------------------
53 // Core code for diagonal assembly
54 //------------------------------------------------------------------------------
// Accumulate the diagonal (or point-block diagonal) of B^T D B for each element.
//
//   nelem             number of elements
//   pointBlock        true: accumulate every (compOut, compIn) pair per node;
//                     false: accumulate only the compOut == compIn entries
//   identity          basis used for CEED_EVAL_NONE
//   interpin/out      basis used for CEED_EVAL_INTERP on the input/output side
//   gradin/out        bases used for CEED_EVAL_GRAD; the k-th GRAD entry of the
//                     mode list reads the k-th chunk of NQPTS * NNODES values
//   emodein/out       evaluation mode lists of length NUMEMODEIN / NUMEMODEOUT
//   assembledqfarray  D, indexed as [ein][compIn][eout][compOut][e][q]
//   elemdiagarray     result, indexed as [compOut][compIn][e][node] when
//                     pointBlock is set and [compOut][e][node] otherwise
//
// Basis arrays are read as [q * NNODES + node]. Results are added (+=) into
// elemdiagarray, so the caller presumably zero-initializes it -- TODO confirm.
// NNODES, NQPTS, NCOMP, NUMEMODEIN and NUMEMODEOUT are not declared in this
// file; presumably they are supplied as compile-time definitions.
//
// Precondition: every mode in emodein/emodeout is CEED_EVAL_NONE, CEED_EVAL_INTERP
// or CEED_EVAL_GRAD. For any other mode the basis pointer stays NULL and is
// dereferenced below.
__device__ void diagonalCore(const CeedInt nelem, const bool pointBlock, const CeedScalar *identity, const CeedScalar *interpin,
                             const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
                             const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
  // Flattened array offsets multiply by nelem * NQPTS (or nelem * NNODES) and can
  // exceed the 32-bit range on large meshes, so they are computed in 64 bits.
  typedef long long Index;

  const int tid = threadIdx.x;  // running with P threads, tid is evec node
  if (tid >= NNODES) return;

  // Compute the diagonal of B^T D B
  // Each element: start at blockIdx.x * blockDim.z + threadIdx.z, step by gridDim.x * blockDim.z
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < nelem; e += gridDim.x * blockDim.z) {
    CeedInt dout = -1;  // index of the most recent GRAD mode on the output side
    // Each basis eval mode pair
    for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
      const CeedScalar *bt = NULL;
      if (emodeout[eout] == CEED_EVAL_GRAD) dout += 1;
      // Do not form a pointer in front of gradout while no GRAD mode has been seen
      const CeedScalar *gradoutcur = (dout >= 0) ? &gradout[dout * NQPTS * NNODES] : NULL;
      CeedOperatorGetBasisPointer_Cuda(&bt, emodeout[eout], identity, interpout, gradoutcur);
      CeedInt din = -1;  // index of the most recent GRAD mode on the input side
      for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
        const CeedScalar *b = NULL;
        if (emodein[ein] == CEED_EVAL_GRAD) din += 1;
        // Same guard as above for the input-side gradient basis
        const CeedScalar *gradincur = (din >= 0) ? &gradin[din * NQPTS * NNODES] : NULL;
        CeedOperatorGetBasisPointer_Cuda(&b, emodein[ein], identity, interpin, gradincur);
        // Each component
        for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
          // Each qpoint/node pair
          if (pointBlock) {
            // Point Block Diagonal
            for (CeedInt compIn = 0; compIn < NCOMP; compIn++) {
              // Offset of quadrature point 0 for this (ein, compIn, eout, compOut, e)
              const Index qfoffset = (((((Index)ein * NCOMP + compIn) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS;
              CeedScalar  evalue   = 0.;
              for (CeedInt q = 0; q < NQPTS; q++) {
                const CeedScalar qfvalue = assembledqfarray[qfoffset + q];
                evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
              }
              elemdiagarray[(((Index)compOut * NCOMP + compIn) * nelem + e) * NNODES + tid] += evalue;
            }
          } else {
            // Diagonal Only: input component equals output component
            const Index qfoffset = (((((Index)ein * NCOMP + compOut) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS;
            CeedScalar  evalue   = 0.;
            for (CeedInt q = 0; q < NQPTS; q++) {
              const CeedScalar qfvalue = assembledqfarray[qfoffset + q];
              evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
            }
            elemdiagarray[((Index)compOut * nelem + e) * NNODES + tid] += evalue;
          }
        }
      }
    }
  }
}
104 
105 //------------------------------------------------------------------------------
106 // Linear diagonal
107 //------------------------------------------------------------------------------
// Kernel entry point for the scalar diagonal: hands every argument to
// diagonalCore unchanged, with the point-block flag switched off.
extern "C" __global__ void linearDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin, const CeedScalar *gradin,
                                          const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
                                          const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray,
                                          CeedScalar *__restrict__ elemdiagarray) {
  const bool usePointBlock = false;  // diagonal entries only
  diagonalCore(nelem, usePointBlock, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
}
114 
115 //------------------------------------------------------------------------------
116 // Linear point block diagonal
117 //------------------------------------------------------------------------------
// Kernel entry point for the point-block diagonal: hands every argument to
// diagonalCore unchanged, with the point-block flag switched on.
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin,
                                                    const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout,
                                                    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
                                                    const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
  const bool usePointBlock = true;  // all (compOut, compIn) pairs per node
  diagonalCore(nelem, usePointBlock, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
}
124 
125 //------------------------------------------------------------------------------
126