xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-basis-nontensor.h (revision 2b730f8b5a9c809740a0b3b302db43a719c636b1)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
7a0154adeSJed Brown 
8c9c2c079SJeremy L Thompson #include <ceed.h>
9a0154adeSJed Brown 
//------------------------------------------------------------------------------
// Non-Tensor Basis Kernels
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp
//------------------------------------------------------------------------------
// Apply the interpolation matrix d_B to d_U and store the result in d_V.
//   num_elem  - number of elements to process
//   transpose - nonzero: V[p] = sum_q d_B[q * BASIS_P + p] * U[q]  (run with P threads in x)
//               zero:    V[q] = sum_p d_B[q * BASIS_P + p] * U[p]  (run with Q threads in x)
//   d_U, d_V  - input/output, both indexed as [comp][elem][node or quadrature point]
// Elements are distributed over blockIdx.x and threadIdx.z; threadIdx.x selects the output entry.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedInt transpose, const CeedScalar *d_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  const CeedInt thread     = threadIdx.x;
  const CeedInt elem_start = blockIdx.x * blockDim.z + threadIdx.z;
  const CeedInt elem_step  = gridDim.x * blockDim.z;
  // TODO load B in shared memory if blockDim.z > 1?

  if (transpose) {
    // Quadrature points -> nodes; run with P threads, thread owns node `thread`
    const CeedScalar *B_col = d_B + thread;

    for (CeedInt e = elem_start; e < num_elem; e += elem_step) {
      for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
        const CeedScalar *u_q = d_U + e * BASIS_Q + c * num_elem * BASIS_Q;
        CeedScalar        sum = 0.0;

        for (CeedInt q = 0; q < BASIS_Q; q++) sum += B_col[q * BASIS_P] * u_q[q];
        d_V[e * BASIS_P + c * num_elem * BASIS_P + thread] = sum;
      }
    }
  } else {
    // Nodes -> quadrature points; run with Q threads, thread owns quadrature point `thread`
    const CeedScalar *B_row = d_B + thread * BASIS_P;

    for (CeedInt e = elem_start; e < num_elem; e += elem_step) {
      for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
        const CeedScalar *u_p = d_U + e * BASIS_P + c * num_elem * BASIS_P;
        CeedScalar        sum = 0.0;

        for (CeedInt p = 0; p < BASIS_P; p++) sum += B_row[p] * u_p[p];
        d_V[e * BASIS_Q + c * num_elem * BASIS_Q + thread] = sum;
      }
    }
  }
}
43a0154adeSJed Brown 
//------------------------------------------------------------------------------
// Grad
//------------------------------------------------------------------------------
// Apply the gradient matrix d_G to d_U and store the result in d_V.
//   num_elem  - number of elements to process
//   transpose - nonzero: V[p] = sum_dim sum_q d_G[dim * P * Q + q * P + p] * U[dim][q]  (run with P threads in x)
//               zero:    V[dim][q] = sum_p d_G[dim * P * Q + q * P + p] * U[p]          (run with Q threads in x)
//   Node data is indexed as [comp][elem][node]; quadrature data as [dim][comp][elem][quadrature point].
// Elements are distributed over blockIdx.x and threadIdx.z; threadIdx.x selects the output entry.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedInt transpose, const CeedScalar *d_G, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  const CeedInt thread     = threadIdx.x;
  const CeedInt elem_start = blockIdx.x * blockDim.z + threadIdx.z;
  const CeedInt elem_step  = gridDim.x * blockDim.z;
  // TODO load G in shared memory if blockDim.z > 1?

  if (transpose) {
    // Quadrature points -> nodes; run with P threads, thread owns node `thread`
    for (CeedInt e = elem_start; e < num_elem; e += elem_step) {
      for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
        // Single accumulator: contributions of all dimensions are summed into one nodal value
        CeedScalar sum = 0.0;

        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          const CeedScalar *u_q = d_U + e * BASIS_Q + c * num_elem * BASIS_Q + d * BASIS_NUM_COMP * num_elem * BASIS_Q;
          const CeedScalar *G_d = d_G + thread + d * BASIS_P * BASIS_Q;

          for (CeedInt q = 0; q < BASIS_Q; q++) sum += G_d[q * BASIS_P] * u_q[q];
        }
        d_V[e * BASIS_P + c * num_elem * BASIS_P + thread] = sum;
      }
    }
  } else {
    // Nodes -> quadrature points; run with Q threads, thread owns quadrature point `thread`
    const CeedScalar *G_row = d_G + thread * BASIS_P;

    for (CeedInt e = elem_start; e < num_elem; e += elem_step) {
      for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
        const CeedScalar *u_p = d_U + e * BASIS_P + c * num_elem * BASIS_P;
        CeedScalar        grad[BASIS_DIM];

        for (CeedInt d = 0; d < BASIS_DIM; d++) grad[d] = 0.0;
        // Read each nodal value once and reuse it for every derivative direction
        for (CeedInt p = 0; p < BASIS_P; p++) {
          const CeedScalar u = u_p[p];

          for (CeedInt d = 0; d < BASIS_DIM; d++) grad[d] += G_row[p + d * BASIS_P * BASIS_Q] * u;
        }
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          d_V[e * BASIS_Q + c * num_elem * BASIS_Q + d * BASIS_NUM_COMP * num_elem * BASIS_Q + thread] = grad[d];
        }
      }
    }
  }
}
80a0154adeSJed Brown 
//------------------------------------------------------------------------------
// Weight
//------------------------------------------------------------------------------
// Copy the quadrature weights into d_V for every element.
//   num_elem - number of elements to process
//   q_weight - weights, read as q_weight[threadIdx.x]
//   d_V      - output, written as d_V[elem * BASIS_Q + threadIdx.x]
// Elements are distributed over blockIdx.x and threadIdx.z.
extern "C" __global__ void Weight(const CeedInt num_elem, const CeedScalar *__restrict__ q_weight, CeedScalar *__restrict__ d_V) {
  const CeedInt q = threadIdx.x;
  // TODO load q_weight in shared memory if blockDim.z > 1?
  CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z;

  while (elem < num_elem) {
    CeedScalar *v_elem = d_V + elem * BASIS_Q;

    v_elem[q] = q_weight[q];
    elem += gridDim.x * blockDim.z;
  }
}
91a0154adeSJed Brown 
//------------------------------------------------------------------------------
93