xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision 2b730f8b5a9c809740a0b3b302db43a719c636b1)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 
10 //------------------------------------------------------------------------------
11 // Diagonal assembly kernels
12 //------------------------------------------------------------------------------
// Basis evaluation modes, redeclared here for JiT compilation of this source.
// Apart from CEED_EVAL_NONE (0), every mode is a distinct power of two. The
// kernels below compare these values against emode arrays passed in as kernel
// arguments, so they presumably must match the host-side definition — verify.
typedef enum {
  /// Perform no evaluation (either because there is no data or it is already at
  /// quadrature points)
  CEED_EVAL_NONE = 0,
  /// Interpolate from nodes to quadrature points
  CEED_EVAL_INTERP = 1 << 0,
  /// Evaluate gradients at quadrature points from input in a nodal basis
  CEED_EVAL_GRAD = 1 << 1,
  /// Evaluate divergence at quadrature points from input in a nodal basis
  CEED_EVAL_DIV = 1 << 2,
  /// Evaluate curl at quadrature points from input in a nodal basis
  CEED_EVAL_CURL = 1 << 3,
  /// Using no input, evaluate quadrature weights on the reference element
  CEED_EVAL_WEIGHT = 1 << 4
} CeedEvalMode;
28 
//------------------------------------------------------------------------------
// Get Basis Emode Pointer
//
// Selects the basis matrix that corresponds to emode and stores it in *basisptr:
//   CEED_EVAL_NONE   -> identity
//   CEED_EVAL_INTERP -> interp
//   CEED_EVAL_GRAD   -> grad
// For any other mode (WEIGHT, DIV, CURL) *basisptr is left untouched; those
// modes are caught earlier, by QFunction assembly.
//------------------------------------------------------------------------------
extern "C" __device__ void CeedOperatorGetBasisPointer_Hip(const CeedScalar **basisptr, CeedEvalMode emode, const CeedScalar *identity,
                                                           const CeedScalar *interp, const CeedScalar *grad) {
  if (emode == CEED_EVAL_NONE) {
    *basisptr = identity;
  } else if (emode == CEED_EVAL_INTERP) {
    *basisptr = interp;
  } else if (emode == CEED_EVAL_GRAD) {
    *basisptr = grad;
  }
  // CEED_EVAL_WEIGHT / CEED_EVAL_DIV / CEED_EVAL_CURL: nothing to select (caught by QF Assembly)
}
50 
//------------------------------------------------------------------------------
// Core code for diagonal assembly
//
// Adds into elemdiagarray, for every element, the diagonal of B_out^T D B_in
// summed over all (input, output) basis eval mode pairs. When pointBlock is
// true, the full NCOMP x NCOMP component block at each node is accumulated;
// otherwise only the compIn == compOut entries are used.
//
// Launch layout (as read from the indexing below):
//   - threadIdx.x is the element-vector node; threads with threadIdx.x >= NNODES
//     return immediately (no barriers follow, so the early return is safe).
//   - The first element is blockIdx.x * blockDim.z + threadIdx.z, advanced by
//     gridDim.x * blockDim.z until nelem is reached.
// NNODES, NQPTS, NCOMP, NUMEMODEIN and NUMEMODEOUT are not defined in this file;
// presumably they are supplied as compile-time definitions at JiT time.
//
// Array layouts (row-major, as indexed below):
//   identity / interp* / each grad slice : [q][node], NQPTS x NNODES
//   gradin / gradout                     : one NQPTS x NNODES slice per
//                                          CEED_EVAL_GRAD entry, in emode order
//   assembledqfarray                     : [ein][compIn][eout][compOut][e][q]
//   elemdiagarray (diagonal only)        : [compOut][e][node]
//   elemdiagarray (point block)          : [compOut][compIn][e][node]
// Results are accumulated with +=, so elemdiagarray is presumably zeroed (or
// holds a partial sum) before launch — not visible here.
//
// Flat offsets are computed in 64-bit (long long) arithmetic because products
// such as NUMEMODEIN * NCOMP * NUMEMODEOUT * NCOMP * nelem * NQPTS can exceed the
// range of a 32-bit CeedInt on large meshes.
//------------------------------------------------------------------------------
__device__ void diagonalCore(const CeedInt nelem, const bool pointBlock, const CeedScalar *identity, const CeedScalar *interpin,
                             const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
                             const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
  const int tid = threadIdx.x;  // running with P threads, tid is evec node
  if (tid >= NNODES) return;

  // Compute the diagonal of B^T D B
  // Each element
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < nelem; e += gridDim.x * blockDim.z) {
    CeedInt dout = -1;
    // Each basis eval mode pair
    for (CeedInt eout = 0; eout < NUMEMODEOUT; eout++) {
      const CeedScalar *bt = NULL;
      if (emodeout[eout] == CEED_EVAL_GRAD) dout += 1;
      // Only offset into gradout once a GRAD mode has been seen: while dout == -1 the
      // offset pointer would lie before the start of the array (undefined pointer
      // arithmetic), even though it is never dereferenced for non-GRAD modes.
      const CeedScalar *gradoutslice = dout < 0 ? gradout : &gradout[dout * NQPTS * NNODES];
      CeedOperatorGetBasisPointer_Hip(&bt, emodeout[eout], identity, interpout, gradoutslice);
      CeedInt din = -1;
      for (CeedInt ein = 0; ein < NUMEMODEIN; ein++) {
        const CeedScalar *b = NULL;
        if (emodein[ein] == CEED_EVAL_GRAD) din += 1;
        // Same guard as for gradout above
        const CeedScalar *gradinslice = din < 0 ? gradin : &gradin[din * NQPTS * NNODES];
        CeedOperatorGetBasisPointer_Hip(&b, emodein[ein], identity, interpin, gradinslice);
        // Each component
        for (CeedInt compOut = 0; compOut < NCOMP; compOut++) {
          // Each qpoint/node pair
          if (pointBlock) {
            // Point Block Diagonal
            for (CeedInt compIn = 0; compIn < NCOMP; compIn++) {
              // Start of the NQPTS-long assembled QFunction slice for (ein, compIn, eout, compOut, e)
              const long long qfoffset = (((((long long)ein * NCOMP + compIn) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS;
              CeedScalar evalue = 0.;
              for (CeedInt q = 0; q < NQPTS; q++) {
                const CeedScalar qfvalue = assembledqfarray[qfoffset + q];
                evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
              }
              elemdiagarray[(((long long)compOut * NCOMP + compIn) * nelem + e) * NNODES + tid] += evalue;
            }
          } else {
            // Diagonal Only: use only the compIn == compOut QFunction entries
            const long long qfoffset = (((((long long)ein * NCOMP + compOut) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS;
            CeedScalar evalue = 0.;
            for (CeedInt q = 0; q < NQPTS; q++) {
              const CeedScalar qfvalue = assembledqfarray[qfoffset + q];
              evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
            }
            elemdiagarray[((long long)compOut * nelem + e) * NNODES + tid] += evalue;
          }
        }
      }
    }
  }
}
103 
//------------------------------------------------------------------------------
// Linear diagonal
//
// Kernel entry point for diagonal-only assembly. It forwards every argument to
// diagonalCore with pointBlock = false, so for each node only the
// compIn == compOut contributions are added into elemdiagarray.
//------------------------------------------------------------------------------
extern "C" __global__ void linearDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin, const CeedScalar *gradin,
                                          const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
                                          const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray,
                                          CeedScalar *__restrict__ elemdiagarray) {
  const bool pointBlock = false;  // diagonal entries only
  diagonalCore(nelem, pointBlock, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray,
               elemdiagarray);
}
113 
//------------------------------------------------------------------------------
// Linear point block diagonal
//
// Kernel entry point for point-block diagonal assembly. It forwards every
// argument to diagonalCore with pointBlock = true, so for each node the full
// NCOMP x NCOMP block of (compOut, compIn) contributions is added into
// elemdiagarray.
//------------------------------------------------------------------------------
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin,
                                                    const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout,
                                                    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
                                                    const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
  const bool pointBlock = true;  // full component block at each node
  diagonalCore(nelem, pointBlock, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray,
               elemdiagarray);
}
123 
124 //------------------------------------------------------------------------------
125