xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision 07328cdeb7cc6fe1ac9596b69aea7fe9a5beddd3)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP operator diagonal assembly
10 #ifndef CEED_HIP_REF_OPERATOR_ASSEMBLE_DIAGONAL_H
11 #define CEED_HIP_REF_OPERATOR_ASSEMBLE_DIAGONAL_H
12 
13 #include <ceed.h>
14 
15 #if CEEDSIZE
16 typedef CeedSize IndexType;
17 #else
18 typedef CeedInt IndexType;
19 #endif
20 
21 //------------------------------------------------------------------------------
22 // Get Basis Emode Pointer
23 //------------------------------------------------------------------------------
24 extern "C" __device__ void CeedOperatorGetBasisPointer_Hip(const CeedScalar **basisptr, CeedEvalMode emode, const CeedScalar *identity,
25                                                            const CeedScalar *interp, const CeedScalar *grad) {
26   switch (emode) {
27     case CEED_EVAL_NONE:
28       *basisptr = identity;
29       break;
30     case CEED_EVAL_INTERP:
31       *basisptr = interp;
32       break;
33     case CEED_EVAL_GRAD:
34       *basisptr = grad;
35       break;
36     case CEED_EVAL_WEIGHT:
37     case CEED_EVAL_DIV:
38     case CEED_EVAL_CURL:
39       break;  // Caught by QF Assembly
40   }
41 }
42 
43 //------------------------------------------------------------------------------
44 // Core code for diagonal assembly
45 //------------------------------------------------------------------------------
46 __device__ void diagonalCore(const CeedInt nelem, const bool pointBlock, const CeedScalar *identity, const CeedScalar *interpin,
47                              const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
48                              const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
49   const int tid = threadIdx.x;  // running with P threads, tid is evec node
50   if (tid >= NNODES) return;
51 
52   // Compute the diagonal of B^T D B
53   // Each element
54   for (IndexType e = blockIdx.x * blockDim.z + threadIdx.z; e < nelem; e += gridDim.x * blockDim.z) {
55     IndexType dout = -1;
56     // Each basis eval mode pair
57     for (IndexType eout = 0; eout < NUMEMODEOUT; eout++) {
58       const CeedScalar *bt = NULL;
59       if (emodeout[eout] == CEED_EVAL_GRAD) dout += 1;
60       CeedOperatorGetBasisPointer_Hip(&bt, emodeout[eout], identity, interpout, &gradout[dout * NQPTS * NNODES]);
61       IndexType din = -1;
62       for (IndexType ein = 0; ein < NUMEMODEIN; ein++) {
63         const CeedScalar *b = NULL;
64         if (emodein[ein] == CEED_EVAL_GRAD) din += 1;
65         CeedOperatorGetBasisPointer_Hip(&b, emodein[ein], identity, interpin, &gradin[din * NQPTS * NNODES]);
66         // Each component
67         for (IndexType compOut = 0; compOut < NCOMP; compOut++) {
68           // Each qpoint/node pair
69           if (pointBlock) {
70             // Point Block Diagonal
71             for (IndexType compIn = 0; compIn < NCOMP; compIn++) {
72               CeedScalar evalue = 0.;
73               for (IndexType q = 0; q < NQPTS; q++) {
74                 const CeedScalar qfvalue =
75                     assembledqfarray[((((ein * NCOMP + compIn) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q];
76                 evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
77               }
78               elemdiagarray[((compOut * NCOMP + compIn) * nelem + e) * NNODES + tid] += evalue;
79             }
80           } else {
81             // Diagonal Only
82             CeedScalar evalue = 0.;
83             for (IndexType q = 0; q < NQPTS; q++) {
84               const CeedScalar qfvalue =
85                   assembledqfarray[((((ein * NCOMP + compOut) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q];
86               evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
87             }
88             elemdiagarray[(compOut * nelem + e) * NNODES + tid] += evalue;
89           }
90         }
91       }
92     }
93   }
94 }
95 
96 //------------------------------------------------------------------------------
97 // Linear diagonal
98 //------------------------------------------------------------------------------
99 extern "C" __global__ void linearDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin, const CeedScalar *gradin,
100                                           const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
101                                           const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray,
102                                           CeedScalar *__restrict__ elemdiagarray) {
103   diagonalCore(nelem, false, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
104 }
105 
106 //------------------------------------------------------------------------------
107 // Linear point block diagonal
108 //------------------------------------------------------------------------------
109 extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin,
110                                                     const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout,
111                                                     const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
112                                                     const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
113   diagonalCore(nelem, true, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
114 }
115 
116 //------------------------------------------------------------------------------
117 
118 #endif  // CEED_HIP_REF_OPERATOR_ASSEMBLE_DIAGONAL_H
119