xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-operator-assemble-diagonal.h (revision fa5adaf534a16ab6b728a4ab386a8a8b7691fe89)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 
10 #if CEEDSIZE
11 typedef CeedSize IndexType;
12 #else
13 typedef CeedInt IndexType;
14 #endif
15 
16 //------------------------------------------------------------------------------
17 // Get Basis Emode Pointer
18 //------------------------------------------------------------------------------
// Select the basis matrix matching an evaluation mode.
//
//   basisptr  [out] set to `identity`, `interp`, or `grad` according to `emode`
//   emode     [in]  evaluation mode to dispatch on
//   identity  [in]  pointer returned for CEED_EVAL_NONE
//   interp    [in]  pointer returned for CEED_EVAL_INTERP
//   grad      [in]  pointer returned for CEED_EVAL_GRAD
//
// For CEED_EVAL_WEIGHT, CEED_EVAL_DIV and CEED_EVAL_CURL (and any value not listed above) `*basisptr` is left unmodified;
// per the original note, those modes are caught by QF Assembly.
extern "C" __device__ void CeedOperatorGetBasisPointer_Cuda(const CeedScalar **basisptr, CeedEvalMode emode, const CeedScalar *identity,
                                                            const CeedScalar *interp, const CeedScalar *grad) {
  if (emode == CEED_EVAL_NONE) {
    *basisptr = identity;
    return;
  }
  if (emode == CEED_EVAL_INTERP) {
    *basisptr = interp;
    return;
  }
  if (emode == CEED_EVAL_GRAD) {
    *basisptr = grad;
    return;
  }
  // CEED_EVAL_WEIGHT / CEED_EVAL_DIV / CEED_EVAL_CURL: nothing to select, caught by QF Assembly
}
37 
38 //------------------------------------------------------------------------------
39 // Core code for diagonal assembly
40 //------------------------------------------------------------------------------
// Core device routine shared by the diagonal and point block diagonal kernels.
//
// Accumulates the diagonal of B^T D B into `elemdiagarray`, one node per thread.
//
// Thread layout used by this routine:
//   threadIdx.x  node index within the element (threads with threadIdx.x >= NNODES exit immediately)
//   threadIdx.z  element slot within the block
//   blockIdx.x   block index; elements are visited as e = blockIdx.x * blockDim.z + threadIdx.z,
//                stepping by gridDim.x * blockDim.z until e >= nelem
//
// Parameters:
//   nelem             number of elements
//   pointBlock        true: assemble every (compOut, compIn) pair; false: assemble only compIn == compOut
//   identity          basis used for CEED_EVAL_NONE, indexed as [q * NNODES + node]
//   interpin/out      basis used for CEED_EVAL_INTERP, indexed as [q * NNODES + node]
//   gradin/out        bases used for CEED_EVAL_GRAD; the k-th GRAD mode encountered uses the slice starting at
//                     k * NQPTS * NNODES
//   emodein/out       arrays of NUMEMODEIN / NUMEMODEOUT evaluation modes
//   assembledqfarray  QFunction values, indexed as
//                     ((((ein * NCOMP + compIn) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q
//   elemdiagarray     output, updated with `+=`, indexed as
//                       point block: ((compOut * NCOMP + compIn) * nelem + e) * NNODES + node
//                       diagonal:    (compOut * nelem + e) * NNODES + node
//                     NOTE(review): since values are accumulated, the caller presumably zero-initializes this array - confirm.
//
// Evaluation modes other than NONE / INTERP / GRAD leave the basis pointer NULL (see CeedOperatorGetBasisPointer_Cuda);
// the original note states these are caught by QF Assembly before this routine runs.
__device__ void diagonalCore(const CeedInt nelem, const bool pointBlock, const CeedScalar *identity, const CeedScalar *interpin,
                             const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
                             const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
  const int tid = threadIdx.x;  // running with P threads, tid is evec node
  if (tid >= NNODES) return;

  // Compute the diagonal of B^T D B
  // Each element
  for (IndexType e = blockIdx.x * blockDim.z + threadIdx.z; e < nelem; e += gridDim.x * blockDim.z) {
    IndexType dout = -1;
    // Each basis eval mode pair
    for (IndexType eout = 0; eout < NUMEMODEOUT; eout++) {
      const CeedScalar *bt = NULL;
      // Only form the gradient slice pointer once a GRAD mode has advanced the counter;
      // indexing with dout == -1 would compute an address before the start of gradout.
      const CeedScalar *gradoutslice = NULL;
      if (emodeout[eout] == CEED_EVAL_GRAD) {
        dout += 1;
        gradoutslice = &gradout[dout * NQPTS * NNODES];
      }
      CeedOperatorGetBasisPointer_Cuda(&bt, emodeout[eout], identity, interpout, gradoutslice);
      IndexType din = -1;
      for (IndexType ein = 0; ein < NUMEMODEIN; ein++) {
        const CeedScalar *b = NULL;
        // Same guard as above for the input gradient slices
        const CeedScalar *gradinslice = NULL;
        if (emodein[ein] == CEED_EVAL_GRAD) {
          din += 1;
          gradinslice = &gradin[din * NQPTS * NNODES];
        }
        CeedOperatorGetBasisPointer_Cuda(&b, emodein[ein], identity, interpin, gradinslice);
        // Each component
        for (IndexType compOut = 0; compOut < NCOMP; compOut++) {
          // Each qpoint/node pair
          if (pointBlock) {
            // Point Block Diagonal
            for (IndexType compIn = 0; compIn < NCOMP; compIn++) {
              CeedScalar evalue = 0.;
              for (IndexType q = 0; q < NQPTS; q++) {
                const CeedScalar qfvalue =
                    assembledqfarray[((((ein * NCOMP + compIn) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q];
                evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
              }
              elemdiagarray[((compOut * NCOMP + compIn) * nelem + e) * NNODES + tid] += evalue;
            }
          } else {
            // Diagonal Only
            CeedScalar evalue = 0.;
            for (IndexType q = 0; q < NQPTS; q++) {
              const CeedScalar qfvalue =
                  assembledqfarray[((((ein * NCOMP + compOut) * NUMEMODEOUT + eout) * NCOMP + compOut) * nelem + e) * NQPTS + q];
              evalue += bt[q * NNODES + tid] * qfvalue * b[q * NNODES + tid];
            }
            elemdiagarray[(compOut * nelem + e) * NNODES + tid] += evalue;
          }
        }
      }
    }
  }
}
90 
91 //------------------------------------------------------------------------------
92 // Linear diagonal
93 //------------------------------------------------------------------------------
// Kernel entry point for diagonal assembly.
// Forwards every argument to diagonalCore with the point block flag cleared, so only the compIn == compOut
// entries are accumulated into `elemdiagarray`.
extern "C" __global__ void linearDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin, const CeedScalar *gradin,
                                          const CeedScalar *interpout, const CeedScalar *gradout, const CeedEvalMode *emodein,
                                          const CeedEvalMode *emodeout, const CeedScalar *__restrict__ assembledqfarray,
                                          CeedScalar *__restrict__ elemdiagarray) {
  const bool usePointBlock = false;  // diagonal only
  diagonalCore(nelem, usePointBlock, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
}
100 
101 //------------------------------------------------------------------------------
102 // Linear point block diagonal
103 //------------------------------------------------------------------------------
// Kernel entry point for point block diagonal assembly.
// Forwards every argument to diagonalCore with the point block flag set, so every (compOut, compIn) pair
// is accumulated into `elemdiagarray`.
extern "C" __global__ void linearPointBlockDiagonal(const CeedInt nelem, const CeedScalar *identity, const CeedScalar *interpin,
                                                    const CeedScalar *gradin, const CeedScalar *interpout, const CeedScalar *gradout,
                                                    const CeedEvalMode *emodein, const CeedEvalMode *emodeout,
                                                    const CeedScalar *__restrict__ assembledqfarray, CeedScalar *__restrict__ elemdiagarray) {
  const bool usePointBlock = true;  // full component-by-component block per node
  diagonalCore(nelem, usePointBlock, identity, interpin, gradin, interpout, gradout, emodein, emodeout, assembledqfarray, elemdiagarray);
}
110 
111 //------------------------------------------------------------------------------
112