xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision 509d4af65d23546c690c9766d8b29e47dc3b3afb)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP operator diagonal assembly
10 
11 #include <ceed.h>
12 
13 #if USE_CEEDSIZE
14 typedef CeedSize IndexType;
15 #else
16 typedef CeedInt IndexType;
17 #endif
18 
19 //------------------------------------------------------------------------------
20 // Get basis pointer
21 //------------------------------------------------------------------------------
22 static __device__ __inline__ void GetBasisPointer(const CeedScalar **basis_ptr, CeedEvalMode eval_modes, const CeedScalar *identity,
23                                                   const CeedScalar *interp, const CeedScalar *grad, const CeedScalar *div, const CeedScalar *curl) {
24   switch (eval_modes) {
25     case CEED_EVAL_NONE:
26       *basis_ptr = identity;
27       break;
28     case CEED_EVAL_INTERP:
29       *basis_ptr = interp;
30       break;
31     case CEED_EVAL_GRAD:
32       *basis_ptr = grad;
33       break;
34     case CEED_EVAL_DIV:
35       *basis_ptr = div;
36       break;
37     case CEED_EVAL_CURL:
38       *basis_ptr = curl;
39       break;
40     case CEED_EVAL_WEIGHT:
41       break;  // Caught by QF assembly
42   }
43 }
44 
45 //------------------------------------------------------------------------------
46 // Core code for diagonal assembly
47 //------------------------------------------------------------------------------
48 extern "C" __launch_bounds__(BLOCK_SIZE) __global__
49     void LinearDiagonal(const CeedInt num_elem, const CeedScalar *identity, const CeedScalar *interp_in, const CeedScalar *grad_in,
50                         const CeedScalar *div_in, const CeedScalar *curl_in, const CeedScalar *interp_out, const CeedScalar *grad_out,
51                         const CeedScalar *div_out, const CeedScalar *curl_out, const CeedEvalMode *eval_modes_in, const CeedEvalMode *eval_modes_out,
52                         const CeedScalar *__restrict__ assembled_qf_array, CeedScalar *__restrict__ elem_diag_array) {
53   const int tid = threadIdx.x;  // Running with P threads
54 
55   if (tid >= NUM_NODES) return;
56 
57   // Compute the diagonal of B^T D B
58   // Each element
59   for (IndexType e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += gridDim.x * blockDim.z) {
60     // Each basis eval mode pair
61     IndexType    d_out               = 0;
62     CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE;
63 
64     for (IndexType e_out = 0; e_out < NUM_EVAL_MODES_OUT; e_out++) {
65       IndexType         d_in               = 0;
66       CeedEvalMode      eval_modes_in_prev = CEED_EVAL_NONE;
67       const CeedScalar *b_t                = NULL;
68 
69       GetBasisPointer(&b_t, eval_modes_out[e_out], identity, interp_out, grad_out, div_out, curl_out);
70       if (e_out == 0 || eval_modes_out[e_out] != eval_modes_out_prev) d_out = 0;
71       else b_t = &b_t[(++d_out) * NUM_QPTS * NUM_NODES];
72       eval_modes_out_prev = eval_modes_out[e_out];
73 
74       for (IndexType e_in = 0; e_in < NUM_EVAL_MODES_IN; e_in++) {
75         const CeedScalar *b = NULL;
76 
77         GetBasisPointer(&b, eval_modes_in[e_in], identity, interp_in, grad_in, div_in, curl_in);
78         if (e_in == 0 || eval_modes_in[e_in] != eval_modes_in_prev) d_in = 0;
79         else b = &b[(++d_in) * NUM_QPTS * NUM_NODES];
80         eval_modes_in_prev = eval_modes_in[e_in];
81 
82         // Each component
83         for (IndexType comp_out = 0; comp_out < NUM_COMP; comp_out++) {
84 #if USE_POINT_BLOCK
85           // Point block diagonal
86           for (IndexType comp_in = 0; comp_in < NUM_COMP; comp_in++) {
87             CeedScalar e_value = 0.;
88 
89             // Each qpoint/node pair
90             for (IndexType q = 0; q < NUM_QPTS; q++) {
91               const CeedScalar qf_value =
92                   assembled_qf_array[((((e_in * NUM_COMP + comp_in) * NUM_EVAL_MODES_OUT + e_out) * NUM_COMP + comp_out) * num_elem + e) * NUM_QPTS +
93                                      q];
94 
95               e_value += b_t[q * NUM_NODES + tid] * qf_value * b[q * NUM_NODES + tid];
96             }
97             elem_diag_array[((comp_out * NUM_COMP + comp_in) * num_elem + e) * NUM_NODES + tid] += e_value;
98           }
99 #else
100           // Diagonal only
101           CeedScalar e_value = 0.;
102 
103           // Each qpoint/node pair
104           for (IndexType q = 0; q < NUM_QPTS; q++) {
105             const CeedScalar qf_value =
106                 assembled_qf_array[((((e_in * NUM_COMP + comp_out) * NUM_EVAL_MODES_OUT + e_out) * NUM_COMP + comp_out) * num_elem + e) * NUM_QPTS +
107                                    q];
108 
109             e_value += b_t[q * NUM_NODES + tid] * qf_value * b[q * NUM_NODES + tid];
110           }
111           elem_diag_array[(comp_out * num_elem + e) * NUM_NODES + tid] += e_value;
112 #endif
113         }
114       }
115     }
116   }
117 }
118