xref: /libCEED/include/ceed/jit-source/hip/hip-ref-operator-assemble-diagonal.h (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
207b31e0eSJeremy L Thompson // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
307b31e0eSJeremy L Thompson //
407b31e0eSJeremy L Thompson // SPDX-License-Identifier: BSD-2-Clause
507b31e0eSJeremy L Thompson //
607b31e0eSJeremy L Thompson // This file is part of CEED:  http://github.com/ceed
707b31e0eSJeremy L Thompson 
8b2165e7aSSebastian Grimberg /// @file
9b2165e7aSSebastian Grimberg /// Internal header for HIP operator diagonal assembly
10c0b5abf0SJeremy L Thompson #include <ceed/types.h>
1107b31e0eSJeremy L Thompson 
12004e4986SSebastian Grimberg #if USE_CEEDSIZE
139330daecSnbeams typedef CeedSize IndexType;
149330daecSnbeams #else
159330daecSnbeams typedef CeedInt IndexType;
169330daecSnbeams #endif
179330daecSnbeams 
1807b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
19004e4986SSebastian Grimberg // Get basis pointer
2007b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
GetBasisPointer(const CeedScalar ** basis_ptr,CeedEvalMode eval_modes,const CeedScalar * identity,const CeedScalar * interp,const CeedScalar * grad,const CeedScalar * div,const CeedScalar * curl)21004e4986SSebastian Grimberg static __device__ __inline__ void GetBasisPointer(const CeedScalar **basis_ptr, CeedEvalMode eval_modes, const CeedScalar *identity,
22004e4986SSebastian Grimberg                                                   const CeedScalar *interp, const CeedScalar *grad, const CeedScalar *div, const CeedScalar *curl) {
23004e4986SSebastian Grimberg   switch (eval_modes) {
2407b31e0eSJeremy L Thompson     case CEED_EVAL_NONE:
25004e4986SSebastian Grimberg       *basis_ptr = identity;
2607b31e0eSJeremy L Thompson       break;
2707b31e0eSJeremy L Thompson     case CEED_EVAL_INTERP:
28004e4986SSebastian Grimberg       *basis_ptr = interp;
2907b31e0eSJeremy L Thompson       break;
3007b31e0eSJeremy L Thompson     case CEED_EVAL_GRAD:
31004e4986SSebastian Grimberg       *basis_ptr = grad;
32004e4986SSebastian Grimberg       break;
33004e4986SSebastian Grimberg     case CEED_EVAL_DIV:
34004e4986SSebastian Grimberg       *basis_ptr = div;
35004e4986SSebastian Grimberg       break;
36004e4986SSebastian Grimberg     case CEED_EVAL_CURL:
37004e4986SSebastian Grimberg       *basis_ptr = curl;
3807b31e0eSJeremy L Thompson       break;
3907b31e0eSJeremy L Thompson     case CEED_EVAL_WEIGHT:
40004e4986SSebastian Grimberg       break;  // Caught by QF assembly
4107b31e0eSJeremy L Thompson   }
4207b31e0eSJeremy L Thompson }
4307b31e0eSJeremy L Thompson 
4407b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
4507b31e0eSJeremy L Thompson // Core code for diagonal assembly
4607b31e0eSJeremy L Thompson //------------------------------------------------------------------------------
__launch_bounds__(BLOCK_SIZE)47cbfe683aSSebastian Grimberg extern "C" __launch_bounds__(BLOCK_SIZE) __global__
48cbfe683aSSebastian Grimberg     void LinearDiagonal(const CeedInt num_elem, const CeedScalar *identity, const CeedScalar *interp_in, const CeedScalar *grad_in,
49cbfe683aSSebastian Grimberg                         const CeedScalar *div_in, const CeedScalar *curl_in, const CeedScalar *interp_out, const CeedScalar *grad_out,
50cbfe683aSSebastian Grimberg                         const CeedScalar *div_out, const CeedScalar *curl_out, const CeedEvalMode *eval_modes_in, const CeedEvalMode *eval_modes_out,
51cbfe683aSSebastian Grimberg                         const CeedScalar *__restrict__ assembled_qf_array, CeedScalar *__restrict__ elem_diag_array) {
52004e4986SSebastian Grimberg   const int tid = threadIdx.x;  // Running with P threads
53004e4986SSebastian Grimberg 
54004e4986SSebastian Grimberg   if (tid >= NUM_NODES) return;
5507b31e0eSJeremy L Thompson 
5607b31e0eSJeremy L Thompson   // Compute the diagonal of B^T D B
5707b31e0eSJeremy L Thompson   // Each element
58004e4986SSebastian Grimberg   for (IndexType e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += gridDim.x * blockDim.z) {
5907b31e0eSJeremy L Thompson     // Each basis eval mode pair
60004e4986SSebastian Grimberg     IndexType    d_out               = 0;
61004e4986SSebastian Grimberg     CeedEvalMode eval_modes_out_prev = CEED_EVAL_NONE;
62004e4986SSebastian Grimberg 
63004e4986SSebastian Grimberg     for (IndexType e_out = 0; e_out < NUM_EVAL_MODES_OUT; e_out++) {
64004e4986SSebastian Grimberg       IndexType         d_in               = 0;
65004e4986SSebastian Grimberg       CeedEvalMode      eval_modes_in_prev = CEED_EVAL_NONE;
66004e4986SSebastian Grimberg       const CeedScalar *b_t                = NULL;
67004e4986SSebastian Grimberg 
68004e4986SSebastian Grimberg       GetBasisPointer(&b_t, eval_modes_out[e_out], identity, interp_out, grad_out, div_out, curl_out);
69004e4986SSebastian Grimberg       if (e_out == 0 || eval_modes_out[e_out] != eval_modes_out_prev) d_out = 0;
70004e4986SSebastian Grimberg       else b_t = &b_t[(++d_out) * NUM_QPTS * NUM_NODES];
71004e4986SSebastian Grimberg       eval_modes_out_prev = eval_modes_out[e_out];
72004e4986SSebastian Grimberg 
73004e4986SSebastian Grimberg       for (IndexType e_in = 0; e_in < NUM_EVAL_MODES_IN; e_in++) {
7407b31e0eSJeremy L Thompson         const CeedScalar *b = NULL;
75004e4986SSebastian Grimberg 
76004e4986SSebastian Grimberg         GetBasisPointer(&b, eval_modes_in[e_in], identity, interp_in, grad_in, div_in, curl_in);
77004e4986SSebastian Grimberg         if (e_in == 0 || eval_modes_in[e_in] != eval_modes_in_prev) d_in = 0;
78004e4986SSebastian Grimberg         else b = &b[(++d_in) * NUM_QPTS * NUM_NODES];
79004e4986SSebastian Grimberg         eval_modes_in_prev = eval_modes_in[e_in];
80004e4986SSebastian Grimberg 
8107b31e0eSJeremy L Thompson         // Each component
82004e4986SSebastian Grimberg         for (IndexType comp_out = 0; comp_out < NUM_COMP; comp_out++) {
83cbfe683aSSebastian Grimberg #if USE_POINT_BLOCK
84004e4986SSebastian Grimberg           // Point block diagonal
85004e4986SSebastian Grimberg           for (IndexType comp_in = 0; comp_in < NUM_COMP; comp_in++) {
86004e4986SSebastian Grimberg             CeedScalar e_value = 0.;
87004e4986SSebastian Grimberg 
88cbfe683aSSebastian Grimberg             // Each qpoint/node pair
89004e4986SSebastian Grimberg             for (IndexType q = 0; q < NUM_QPTS; q++) {
90004e4986SSebastian Grimberg               const CeedScalar qf_value =
91cbfe683aSSebastian Grimberg                   assembled_qf_array[((((e_in * NUM_COMP + comp_in) * NUM_EVAL_MODES_OUT + e_out) * NUM_COMP + comp_out) * num_elem + e) * NUM_QPTS +
92004e4986SSebastian Grimberg                                      q];
93004e4986SSebastian Grimberg 
94004e4986SSebastian Grimberg               e_value += b_t[q * NUM_NODES + tid] * qf_value * b[q * NUM_NODES + tid];
9507b31e0eSJeremy L Thompson             }
96004e4986SSebastian Grimberg             elem_diag_array[((comp_out * NUM_COMP + comp_in) * num_elem + e) * NUM_NODES + tid] += e_value;
9707b31e0eSJeremy L Thompson           }
98cbfe683aSSebastian Grimberg #else
99004e4986SSebastian Grimberg           // Diagonal only
100004e4986SSebastian Grimberg           CeedScalar e_value = 0.;
101004e4986SSebastian Grimberg 
102cbfe683aSSebastian Grimberg           // Each qpoint/node pair
103004e4986SSebastian Grimberg           for (IndexType q = 0; q < NUM_QPTS; q++) {
104004e4986SSebastian Grimberg             const CeedScalar qf_value =
105004e4986SSebastian Grimberg                 assembled_qf_array[((((e_in * NUM_COMP + comp_out) * NUM_EVAL_MODES_OUT + e_out) * NUM_COMP + comp_out) * num_elem + e) * NUM_QPTS +
106004e4986SSebastian Grimberg                                    q];
107004e4986SSebastian Grimberg 
108004e4986SSebastian Grimberg             e_value += b_t[q * NUM_NODES + tid] * qf_value * b[q * NUM_NODES + tid];
10907b31e0eSJeremy L Thompson           }
110004e4986SSebastian Grimberg           elem_diag_array[(comp_out * num_elem + e) * NUM_NODES + tid] += e_value;
111cbfe683aSSebastian Grimberg #endif
11207b31e0eSJeremy L Thompson         }
11307b31e0eSJeremy L Thompson       }
11407b31e0eSJeremy L Thompson     }
11507b31e0eSJeremy L Thompson   }
11607b31e0eSJeremy L Thompson }
117