xref: /libCEED/include/ceed/jit-source/hip/hip-ref-basis-nontensor.h (revision a0154adecfab8547cdc0febbbf40ac009dbe9d1d)
1*a0154adeSJed Brown // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*a0154adeSJed Brown // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*a0154adeSJed Brown //
4*a0154adeSJed Brown // SPDX-License-Identifier: BSD-2-Clause
5*a0154adeSJed Brown //
6*a0154adeSJed Brown // This file is part of CEED:  http://github.com/ceed
7*a0154adeSJed Brown 
8*a0154adeSJed Brown #include <ceed/ceed.h>
9*a0154adeSJed Brown 
10*a0154adeSJed Brown //------------------------------------------------------------------------------
11*a0154adeSJed Brown // Non-Tensor Basis Kernels
12*a0154adeSJed Brown //------------------------------------------------------------------------------
13*a0154adeSJed Brown 
14*a0154adeSJed Brown //------------------------------------------------------------------------------
15*a0154adeSJed Brown // Interp
16*a0154adeSJed Brown //------------------------------------------------------------------------------
// Apply the dense non-tensor interpolation matrix (or its transpose) to every
// component of every element.
//
// d_B is read as a BASIS_Q x BASIS_P row-major matrix: entry (q, p) is
// d_B[p + q*BASIS_P].
//   transpose == 0: d_U holds BASIS_P values per (comp, elem) and d_V receives
//                   BASIS_Q values, V[q] = sum_p B(q, p) * U[p].
//   transpose != 0: d_U holds BASIS_Q values and d_V receives BASIS_P values,
//                   V[p] = sum_q B(q, p) * U[q].
// Value i of component comp on element elem is stored at
//   elem*N + comp*num_elem*N + i   (N = BASIS_P or BASIS_Q, per side).
//
// Launch layout: threadIdx.z selects an element within the block, and elements
// are visited with stride gridDim.x*blockDim.z. Threads along x loop over the
// output nodes with stride blockDim.x, so any blockDim.x >= 1 is correct.
// blockDim.x == BASIS_P (transpose) or BASIS_Q (no transpose) gives exactly one
// node per thread. Threads whose threadIdx.x is past the node count do no work
// instead of indexing out of bounds.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedInt transpose,
                                  const CeedScalar *__restrict__ d_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  const CeedInt t_id = threadIdx.x;
  //TODO load B in shared memory if blockDim.z > 1?

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem;
       elem += gridDim.x*blockDim.z) {
    for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
      if (transpose) {
        // Quadrature values in, nodal values out.
        const CeedScalar *U = d_U + elem*BASIS_Q + comp*num_elem*BASIS_Q;
        CeedScalar *V = d_V + elem*BASIS_P + comp*num_elem*BASIS_P;
        for (CeedInt node = t_id; node < BASIS_P; node += blockDim.x) {
          CeedScalar sum = 0.0;
          for (CeedInt i = 0; i < BASIS_Q; i++)
            sum += d_B[node + i*BASIS_P]*U[i];
          V[node] = sum;
        }
      } else {
        // Nodal values in, quadrature values out.
        const CeedScalar *U = d_U + elem*BASIS_P + comp*num_elem*BASIS_P;
        CeedScalar *V = d_V + elem*BASIS_Q + comp*num_elem*BASIS_Q;
        for (CeedInt node = t_id; node < BASIS_Q; node += blockDim.x) {
          CeedScalar sum = 0.0;
          for (CeedInt i = 0; i < BASIS_P; i++)
            sum += d_B[i + node*BASIS_P]*U[i];
          V[node] = sum;
        }
      }
    }
  }
}
48*a0154adeSJed Brown 
49*a0154adeSJed Brown //------------------------------------------------------------------------------
50*a0154adeSJed Brown // Grad
51*a0154adeSJed Brown //------------------------------------------------------------------------------
// Apply the dense non-tensor gradient matrices (or their transpose) to every
// component of every element.
//
// d_G is read as BASIS_DIM stacked BASIS_Q x BASIS_P row-major matrices: entry
// (dim, q, p) is d_G[p + q*BASIS_P + dim*BASIS_P*BASIS_Q].
//   transpose == 0: d_U holds BASIS_P nodal values per (comp, elem). d_V receives
//                   BASIS_Q values for each dim, V[dim][q] = sum_p G(dim, q, p) * U[p].
//   transpose != 0: d_U holds BASIS_Q values per (dim, comp, elem). d_V receives
//                   BASIS_P values, V[p] = sum_dim sum_q G(dim, q, p) * U[dim][q].
// Nodal value i of comp on elem is at elem*BASIS_P + comp*num_elem*BASIS_P + i.
// Quadrature value i of (dim, comp) on elem is at
//   elem*BASIS_Q + comp*num_elem*BASIS_Q + dim*BASIS_NUM_COMP*num_elem*BASIS_Q + i.
//
// Launch layout: threadIdx.z selects an element within the block, and elements
// are visited with stride gridDim.x*blockDim.z. Threads along x loop over the
// output nodes with stride blockDim.x, so any blockDim.x >= 1 is correct and
// out-of-range threads do no work.
// NOTE(review): offsets are computed in CeedInt. If CeedInt is 32-bit,
// dim*BASIS_NUM_COMP*num_elem*BASIS_Q may overflow for very large num_elem.
// Confirm against the supported mesh sizes.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedInt transpose,
                                const CeedScalar *__restrict__ d_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  const CeedInt t_id = threadIdx.x;
  //TODO load G in shared memory if blockDim.z > 1?

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem;
       elem += gridDim.x*blockDim.z) {
    for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
      if (transpose) {
        // Quadrature gradients in (one slab per dim), nodal values out.
        CeedScalar *V = d_V + elem*BASIS_P + comp*num_elem*BASIS_P;
        for (CeedInt node = t_id; node < BASIS_P; node += blockDim.x) {
          CeedScalar sum = 0.0;
          for (CeedInt dim = 0; dim < BASIS_DIM; dim++) {
            const CeedScalar *U = d_U + elem*BASIS_Q + comp*num_elem*BASIS_Q +
                                  dim*BASIS_NUM_COMP*num_elem*BASIS_Q;
            for (CeedInt i = 0; i < BASIS_Q; i++)
              sum += d_G[node + i*BASIS_P + dim*BASIS_P*BASIS_Q]*U[i];
          }
          V[node] = sum;
        }
      } else {
        // Nodal values in, one quadrature slab per dim out.
        const CeedScalar *U = d_U + elem*BASIS_P + comp*num_elem*BASIS_P;
        for (CeedInt node = t_id; node < BASIS_Q; node += blockDim.x) {
          CeedScalar sum[BASIS_DIM];
          for (CeedInt dim = 0; dim < BASIS_DIM; dim++)
            sum[dim] = 0.0;

          for (CeedInt i = 0; i < BASIS_P; i++) {
            // Each nodal value is loaded once and reused for all dims.
            const CeedScalar val = U[i];
            for (CeedInt dim = 0; dim < BASIS_DIM; dim++)
              sum[dim] += d_G[i + node*BASIS_P + dim*BASIS_P*BASIS_Q]*val;
          }
          for (CeedInt dim = 0; dim < BASIS_DIM; dim++) {
            d_V[elem*BASIS_Q + comp*num_elem*BASIS_Q +
                dim*BASIS_NUM_COMP*num_elem*BASIS_Q + node] = sum[dim];
          }
        }
      }
    }
  }
}
92*a0154adeSJed Brown 
93*a0154adeSJed Brown //------------------------------------------------------------------------------
94*a0154adeSJed Brown // Weight
95*a0154adeSJed Brown //------------------------------------------------------------------------------
// Broadcast the BASIS_Q quadrature weights to every element: for each element,
// d_V[elem*BASIS_Q + q] = qweight[q] for q in [0, BASIS_Q).
//
// Launch layout: threadIdx.z selects an element within the block, and elements
// are visited with stride gridDim.x*blockDim.z. Threads along x loop over the
// quadrature points with stride blockDim.x, so any blockDim.x >= 1 is correct.
// blockDim.x == BASIS_Q gives one point per thread, and threads past BASIS_Q do
// no work instead of reading/writing out of bounds.
extern "C" __global__ void Weight(const CeedInt num_elem,
                                  const CeedScalar *__restrict__ qweight,
                                  CeedScalar *__restrict__ d_V) {
  const CeedInt t_id = threadIdx.x;
  //TODO load qweight in shared memory if blockDim.z > 1?
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < num_elem;
       elem += gridDim.x*blockDim.z) {
    for (CeedInt q = t_id; q < BASIS_Q; q += blockDim.x)
      d_V[elem*BASIS_Q + q] = qweight[q];
  }
}
106*a0154adeSJed Brown 
107*a0154adeSJed Brown //------------------------------------------------------------------------------
108