xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision 004e49868906b3e3ec4a252ac682c88f9414881a)
107b31e0eSJeremy L Thompson // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
207b31e0eSJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
307b31e0eSJeremy L Thompson //
407b31e0eSJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
507b31e0eSJeremy L Thompson //
607b31e0eSJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
707b31e0eSJeremy L Thompson 
8b2165e7aSSebastian Grimberg /// @file
9b2165e7aSSebastian Grimberg /// Internal header for HIP operator diagonal assembly
1094b7b29bSJeremy L Thompson #ifndef CEED_HIP_REF_OPERATOR_ASSEMBLE_DIAGONAL_H
1194b7b29bSJeremy L Thompson #define CEED_HIP_REF_OPERATOR_ASSEMBLE_DIAGONAL_H
12b2165e7aSSebastian Grimberg 
13c9c2c079SJeremy L Thompson #include <ceed.h>
1407b31e0eSJeremy L Thompson 
15*004e4986SSebastian Grimberg #if USE_CEEDSIZE
169330daecSnbeams typedef CeedSize IndexType;
179330daecSnbeams #else
189330daecSnbeams typedef CeedInt IndexType;
199330daecSnbeams #endif
209330daecSnbeams 
2107b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
22*004e4986SSebastian Grimberg // Get basis pointer
2307b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
24*004e4986SSebastian Grimberg static __device__ __inline__ void GetBasisPointer(const CeedScalar **basis_ptr, CeedEvalMode eval_modes, const CeedScalar *identity,
25*004e4986SSebastian Grimberg                                                   const CeedScalar *interp, const CeedScalar *grad, const CeedScalar *div, const CeedScalar *curl) {
26*004e4986SSebastian Grimberg   switch (eval_modes) {
2707b31e0eSJeremy L Thompson     case CEED_EVAL_NONE:
28*004e4986SSebastian Grimberg       *basis_ptr = identity;
2907b31e0eSJeremy L Thompson       break;
3007b31e0eSJeremy L Thompson     case CEED_EVAL_INTERP:
31*004e4986SSebastian Grimberg       *basis_ptr = interp;
3207b31e0eSJeremy L Thompson       break;
3307b31e0eSJeremy L Thompson     case CEED_EVAL_GRAD:
34*004e4986SSebastian Grimberg       *basis_ptr = grad;
35*004e4986SSebastian Grimberg       break;
36*004e4986SSebastian Grimberg     case CEED_EVAL_DIV:
37*004e4986SSebastian Grimberg       *basis_ptr = div;
38*004e4986SSebastian Grimberg       break;
39*004e4986SSebastian Grimberg     case CEED_EVAL_CURL:
40*004e4986SSebastian Grimberg       *basis_ptr = curl;
4107b31e0eSJeremy L Thompson       break;
4207b31e0eSJeremy L Thompson     case CEED_EVAL_WEIGHT:
43*004e4986SSebastian Grimberg       break;  // Caught by QF assembly
4407b31e0eSJeremy L Thompson   }
4507b31e0eSJeremy L Thompson }
4607b31e0eSJeremy L Thompson 
4707b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
4807b31e0eSJeremy L Thompson // Core code for diagonal assembly
4907b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
50*004e4986SSebastian Grimberg static __device__ __inline__ void DiagonalCore(const CeedInt num_elem, const bool is_point_block, const CeedScalar *identity,
51*004e4986SSebastian Grimberg                                                const CeedScalar *interp_in, const CeedScalar *grad_in, const CeedScalar *div_in,
52*004e4986SSebastian Grimberg                                                const CeedScalar *curl_in, const CeedScalar *interp_out, const CeedScalar *grad_out,
53*004e4986SSebastian Grimberg                                                const CeedScalar *div_out, const CeedScalar *curl_out, const CeedEvalMode *eval_modes_in,
54*004e4986SSebastian Grimberg                                                const CeedEvalMode *eval_modes_out, const CeedScalar *__restrict__ assembled_qf_array,
55*004e4986SSebastian Grimberg                                                CeedScalar *__restrict__ elem_diag_array) {
56*004e4986SSebastian Grimberg   const int tid = threadIdx.x;  // Running with P threads
57*004e4986SSebastian Grimberg 
58*004e4986SSebastian Grimberg   if (tid >= NUM_NODES) return;
5907b31e0eSJeremy L Thompson 
6007b31e0eSJeremy L Thompson   // Compute the diagonal of B^T D B
6107b31e0eSJeremy L Thompson   // Each element
62*004e4986SSebastian Grimberg   for (IndexType e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += gridDim.x * blockDim.z) {
6307b31e0eSJeremy L Thompson     // Each basis eval mode pair
64*004e4986SSebastian Grimberg     IndexType    d_out               = 0;
65*004e4986SSebastian Grimberg     CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE;
66*004e4986SSebastian Grimberg 
67*004e4986SSebastian Grimberg     for (IndexType e_out = 0; e_out < NUM_EVAL_MODES_OUT; e_out++) {
68*004e4986SSebastian Grimberg       IndexType         d_in               = 0;
69*004e4986SSebastian Grimberg       CeedEvalMode      eval_modes_in_prev = CEED_EVAL_NONE;
70*004e4986SSebastian Grimberg       const CeedScalar *b_t                = NULL;
71*004e4986SSebastian Grimberg 
72*004e4986SSebastian Grimberg       GetBasisPointer(&b_t, eval_modes_out[e_out], identity, interp_out, grad_out, div_out, curl_out);
73*004e4986SSebastian Grimberg       if (e_out == 0 || eval_modes_out[e_out] != eval_modes_out_prev) d_out = 0;
74*004e4986SSebastian Grimberg       else b_t = &b_t[(++d_out) * NUM_QPTS * NUM_NODES];
75*004e4986SSebastian Grimberg       eval_modes_out_prev = eval_modes_out[e_out];
76*004e4986SSebastian Grimberg 
77*004e4986SSebastian Grimberg       for (IndexType e_in = 0; e_in < NUM_EVAL_MODES_IN; e_in++) {
7807b31e0eSJeremy L Thompson         const CeedScalar *b = NULL;
79*004e4986SSebastian Grimberg 
80*004e4986SSebastian Grimberg         GetBasisPointer(&b, eval_modes_in[e_in], identity, interp_in, grad_in, div_in, curl_in);
81*004e4986SSebastian Grimberg         if (e_in == 0 || eval_modes_in[e_in] != eval_modes_in_prev) d_in = 0;
82*004e4986SSebastian Grimberg         else b = &b[(++d_in) * NUM_QPTS * NUM_NODES];
83*004e4986SSebastian Grimberg         eval_modes_in_prev = eval_modes_in[e_in];
84*004e4986SSebastian Grimberg 
8507b31e0eSJeremy L Thompson         // Each component
86*004e4986SSebastian Grimberg         for (IndexType comp_out = 0; comp_out < NUM_COMP; comp_out++) {
8707b31e0eSJeremy L Thompson           // Each qpoint/node pair
88*004e4986SSebastian Grimberg           if (is_point_block) {
89*004e4986SSebastian Grimberg             // Point block diagonal
90*004e4986SSebastian Grimberg             for (IndexType comp_in = 0; comp_in < NUM_COMP; comp_in++) {
91*004e4986SSebastian Grimberg               CeedScalar e_value = 0.;
92*004e4986SSebastian Grimberg 
93*004e4986SSebastian Grimberg               for (IndexType q = 0; q < NUM_QPTS; q++) {
94*004e4986SSebastian Grimberg                 const CeedScalar qf_value =
95*004e4986SSebastian Grimberg                     assembled_qf_array[((((e_in * NUM_COMP + comp_in) * NUM_EVAL_MODES_OUT + e_out) * NUM_COMP + comp_out) * num_elem + e) *
96*004e4986SSebastian Grimberg                                            NUM_QPTS +
97*004e4986SSebastian Grimberg                                        q];
98*004e4986SSebastian Grimberg 
99*004e4986SSebastian Grimberg                 e_value += b_t[q * NUM_NODES + tid] * qf_value * b[q * NUM_NODES + tid];
10007b31e0eSJeremy L Thompson               }
101*004e4986SSebastian Grimberg               elem_diag_array[((comp_out * NUM_COMP + comp_in) * num_elem + e) * NUM_NODES + tid] += e_value;
10207b31e0eSJeremy L Thompson             }
10307b31e0eSJeremy L Thompson           } else {
104*004e4986SSebastian Grimberg             // Diagonal only
105*004e4986SSebastian Grimberg             CeedScalar e_value = 0.;
106*004e4986SSebastian Grimberg 
107*004e4986SSebastian Grimberg             for (IndexType q = 0; q < NUM_QPTS; q++) {
108*004e4986SSebastian Grimberg               const CeedScalar qf_value =
109*004e4986SSebastian Grimberg                   assembled_qf_array[((((e_in * NUM_COMP + comp_out) * NUM_EVAL_MODES_OUT + e_out) * NUM_COMP + comp_out) * num_elem + e) * NUM_QPTS +
110*004e4986SSebastian Grimberg                                      q];
111*004e4986SSebastian Grimberg 
112*004e4986SSebastian Grimberg               e_value += b_t[q * NUM_NODES + tid] * qf_value * b[q * NUM_NODES + tid];
11307b31e0eSJeremy L Thompson             }
114*004e4986SSebastian Grimberg             elem_diag_array[(comp_out * num_elem + e) * NUM_NODES + tid] += e_value;
11507b31e0eSJeremy L Thompson           }
11607b31e0eSJeremy L Thompson         }
11707b31e0eSJeremy L Thompson       }
11807b31e0eSJeremy L Thompson     }
11907b31e0eSJeremy L Thompson   }
12007b31e0eSJeremy L Thompson }
12107b31e0eSJeremy L Thompson 
12207b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
12307b31e0eSJeremy L Thompson // Linear diagonal
12407b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
125*004e4986SSebastian Grimberg extern "C" __global__ void LinearDiagonal(const CeedInt num_elem, const CeedScalar *identity, const CeedScalar *interp_in, const CeedScalar *grad_in,
126*004e4986SSebastian Grimberg                                           const CeedScalar *div_in, const CeedScalar *curl_in, const CeedScalar *interp_out,
127*004e4986SSebastian Grimberg                                           const CeedScalar *grad_out, const CeedScalar *div_out, const CeedScalar *curl_out,
128*004e4986SSebastian Grimberg                                           const CeedEvalMode *eval_modes_in, const CeedEvalMode *eval_modes_out,
129*004e4986SSebastian Grimberg                                           const CeedScalar *__restrict__ assembled_qf_array, CeedScalar *__restrict__ elem_diag_array) {
130*004e4986SSebastian Grimberg   DiagonalCore(num_elem, false, identity, interp_in, grad_in, div_in, curl_in, interp_out, grad_out, div_out, curl_out, eval_modes_in, eval_modes_out,
131*004e4986SSebastian Grimberg                assembled_qf_array, elem_diag_array);
13207b31e0eSJeremy L Thompson }
13307b31e0eSJeremy L Thompson 
13407b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
13507b31e0eSJeremy L Thompson // Linear point block diagonal
13607b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
137*004e4986SSebastian Grimberg extern "C" __global__ void LinearPointBlockDiagonal(const CeedInt num_elem, const CeedScalar *identity, const CeedScalar *interp_in,
138*004e4986SSebastian Grimberg                                                     const CeedScalar *grad_in, const CeedScalar *div_in, const CeedScalar *curl_in,
139*004e4986SSebastian Grimberg                                                     const CeedScalar *interp_out, const CeedScalar *grad_out, const CeedScalar *div_out,
140*004e4986SSebastian Grimberg                                                     const CeedScalar *curl_out, const CeedEvalMode *eval_modes_in, const CeedEvalMode *eval_modes_out,
141*004e4986SSebastian Grimberg                                                     const CeedScalar *__restrict__ assembled_qf_array, CeedScalar *__restrict__ elem_diag_array) {
142*004e4986SSebastian Grimberg   DiagonalCore(num_elem, true, identity, interp_in, grad_in, div_in, curl_in, interp_out, grad_out, div_out, curl_out, eval_modes_in, eval_modes_out,
143*004e4986SSebastian Grimberg                assembled_qf_array, elem_diag_array);
14407b31e0eSJeremy L Thompson }
14507b31e0eSJeremy L Thompson 
14607b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
147b2165e7aSSebastian Grimberg 
14894b7b29bSJeremy L Thompson #endif  // CEED_HIP_REF_OPERATOR_ASSEMBLE_DIAGONAL_H
149