xref: /libCEED/include/ceed/jit-source/hip/hip-ref-basis-nontensor.h (revision c9c2c07970382857cc7b4a28d359710237b91a3e)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
7a0154adeSJed Brown 
8*c9c2c079SJeremy L Thompson #include <ceed.h>
9a0154adeSJed Brown 
//------------------------------------------------------------------------------
// Non-Tensor Basis Kernels
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp
//------------------------------------------------------------------------------
// Interp: multiply each (element, component) input vector by the matrix d_B,
// or by its transpose.
//
//   num_elem  - number of elements to process
//   transpose - nonzero: read BASIS_Q values, write BASIS_P values;
//               zero:    read BASIS_P values, write BASIS_Q values
//   d_B       - matrix indexed as d_B[node + quad*BASIS_P]
//   d_U       - input,  indexed [comp][elem][entry]
//   d_V       - output, indexed [comp][elem][entry]; entries are overwritten
//
// threadIdx.x selects the output entry and is not bounds-checked here, so the
// launch is expected to use BASIS_P threads in x for the transpose case and
// BASIS_Q threads in x otherwise. Elements are distributed over blockIdx.x and
// threadIdx.z.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedInt transpose,
                                  const CeedScalar *d_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  const CeedInt lane        = threadIdx.x;
  const CeedInt elem_stride = gridDim.x*blockDim.z;

  //TODO load B in shared memory if blockDim.z > 1?

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < num_elem; e += elem_stride) {
    for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
      // Position of the (component c, element e) vector, counted in whole vectors
      const CeedInt slab = c*num_elem + e;

      if (!transpose) {
        // Forward action, run with Q threads: dot row `lane` of B with the node values
        const CeedScalar *nodes = d_U + slab*BASIS_P;
        const CeedScalar *b_row = d_B + lane*BASIS_P;
        CeedScalar        sum   = 0.0;

        for (CeedInt p = 0; p < BASIS_P; p++) {
          sum += b_row[p]*nodes[p];
        }
        d_V[slab*BASIS_Q + lane] = sum;
      } else {
        // Transpose action, run with P threads: dot column `lane` of B with the
        // quadrature-point values (column entries are BASIS_P apart)
        const CeedScalar *quads = d_U + slab*BASIS_Q;
        const CeedScalar *b_col = d_B + lane;
        CeedScalar        sum   = 0.0;

        for (CeedInt q = 0; q < BASIS_Q; q++) {
          sum += b_col[q*BASIS_P]*quads[q];
        }
        d_V[slab*BASIS_P + lane] = sum;
      }
    }
  }
}
48a0154adeSJed Brown 
//------------------------------------------------------------------------------
// Grad
//------------------------------------------------------------------------------
// Grad: apply the BASIS_DIM stacked matrices in d_G, or their transposes, to
// each (element, component) input vector.
//
//   num_elem  - number of elements to process
//   transpose - nonzero: read BASIS_DIM vectors of BASIS_Q values and sum their
//                        contributions into one vector of BASIS_P values;
//               zero:    read one vector of BASIS_P values and write BASIS_DIM
//                        vectors of BASIS_Q values
//   d_G       - matrices indexed as d_G[node + quad*BASIS_P + dim*BASIS_P*BASIS_Q]
//   d_U       - input;  quadrature-side data is indexed [dim][comp][elem][quad],
//               node-side data is indexed [comp][elem][node]
//   d_V       - output, same layouts as d_U; entries are overwritten
//
// threadIdx.x selects the output entry and is not bounds-checked here, so the
// launch is expected to use BASIS_P threads in x for the transpose case and
// BASIS_Q threads in x otherwise. Elements are distributed over blockIdx.x and
// threadIdx.z.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedInt transpose,
                                const CeedScalar *d_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  const CeedInt lane        = threadIdx.x;
  const CeedInt elem_stride = gridDim.x*blockDim.z;

  //TODO load G in shared memory if blockDim.z > 1?

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < num_elem; e += elem_stride) {
    for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
      // Position of the (component c, element e) vector, counted in whole vectors
      const CeedInt slab = c*num_elem + e;

      if (!transpose) {
        // Forward action, run with Q threads: one accumulator per derivative direction
        const CeedScalar *nodes = d_U + slab*BASIS_P;
        const CeedScalar *g_row = d_G + lane*BASIS_P;
        CeedScalar        grad[BASIS_DIM];

        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          grad[d] = 0.0;
        }

        // Read each node value once and reuse it for every direction
        for (CeedInt p = 0; p < BASIS_P; p++) {
          const CeedScalar u_p = nodes[p];

          for (CeedInt d = 0; d < BASIS_DIM; d++) {
            grad[d] += g_row[p + d*BASIS_P*BASIS_Q]*u_p;
          }
        }

        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          d_V[(d*BASIS_NUM_COMP*num_elem + slab)*BASIS_Q + lane] = grad[d];
        }
      } else {
        // Transpose action, run with P threads: accumulate over all directions
        // and quadrature points into a single node value
        CeedScalar sum = 0.0;

        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          const CeedScalar *quads = d_U + (d*BASIS_NUM_COMP*num_elem + slab)*BASIS_Q;
          const CeedScalar *g_col = d_G + d*BASIS_P*BASIS_Q + lane;

          for (CeedInt q = 0; q < BASIS_Q; q++) {
            sum += g_col[q*BASIS_P]*quads[q];
          }
        }
        d_V[slab*BASIS_P + lane] = sum;
      }
    }
  }
}
92a0154adeSJed Brown 
//------------------------------------------------------------------------------
// Weight
//------------------------------------------------------------------------------
// Weight: copy the weights in qweight into the output slot of every element.
//
//   num_elem - number of elements to fill
//   qweight  - weights, indexed by threadIdx.x
//   d_V      - output, indexed [elem][quad] with BASIS_Q entries per element
//
// threadIdx.x selects the entry and is not bounds-checked here; the indexing
// d_V[elem*BASIS_Q + threadIdx.x] implies a launch with BASIS_Q threads in x.
// Elements are distributed over blockIdx.x and threadIdx.z.
extern "C" __global__ void Weight(const CeedInt num_elem,
                                  const CeedScalar *__restrict__ qweight,
                                  CeedScalar *__restrict__ d_V) {
  const CeedInt lane        = threadIdx.x;
  const CeedInt elem_stride = gridDim.x*blockDim.z;

  //TODO load qweight in shared memory if blockDim.z > 1?
  CeedInt e = blockIdx.x*blockDim.z + threadIdx.z;

  while (e < num_elem) {
    d_V[e*BASIS_Q + lane] = qweight[lane];
    e += elem_stride;
  }
}
106a0154adeSJed Brown 
107a0154adeSJed Brown //------------------------------------------------------------------------------
108