xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-basis-nontensor.h (revision b2165e7a2e371018feedcb47974a027ed47e0487)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for CUDA non-tensor product basis
10*b2165e7aSSebastian Grimberg #ifndef _ceed_cuda_ref_basis_nontensor_h
11*b2165e7aSSebastian Grimberg #define _ceed_cuda_ref_basis_nontensor_h
12*b2165e7aSSebastian Grimberg 
13c9c2c079SJeremy L Thompson #include <ceed.h>

//------------------------------------------------------------------------------
// Non-Tensor Basis Kernels
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp
//------------------------------------------------------------------------------
// Interpolation with the non-tensor basis, applied element by element.
//   transpose == 0 : d_V[comp][elem][q] = sum_p B[q][p] * d_U[comp][elem][p]   (run with BASIS_Q threads in x)
//   transpose != 0 : d_V[comp][elem][p] = sum_q B[q][p] * d_U[comp][elem][q]   (run with BASIS_P threads in x)
// Entry B[q][p] is read from d_B[q * BASIS_P + p].
// threadIdx.x selects the output entry; elements are picked from blockIdx.x and threadIdx.z and
// advanced by gridDim.x * blockDim.z until num_elem is reached.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedInt transpose, const CeedScalar *d_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  const CeedInt row         = threadIdx.x;
  const CeedInt elem_stride = gridDim.x * blockDim.z;
  // NOTE(review): d_B could be staged in shared memory when several elements share a block (blockDim.z > 1).

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_stride) {
    for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
      CeedScalar acc = 0.0;

      if (!transpose) {
        // Nodes -> quadrature points: this thread owns quadrature point `row`.
        const CeedScalar *u_e = d_U + (c * num_elem + e) * BASIS_P;
        for (CeedInt p = 0; p < BASIS_P; p++) acc += d_B[row * BASIS_P + p] * u_e[p];
        d_V[(c * num_elem + e) * BASIS_Q + row] = acc;
      } else {
        // Quadrature points -> nodes: this thread owns node `row` (walks a column of B).
        const CeedScalar *u_e = d_U + (c * num_elem + e) * BASIS_Q;
        for (CeedInt q = 0; q < BASIS_Q; q++) acc += d_B[q * BASIS_P + row] * u_e[q];
        d_V[(c * num_elem + e) * BASIS_P + row] = acc;
      }
    }
  }
}

//------------------------------------------------------------------------------
// Grad
//------------------------------------------------------------------------------
// Gradient of the non-tensor basis, applied element by element.
//   transpose == 0 : d_V[dim][comp][elem][q] = sum_p G[dim][q][p] * d_U[comp][elem][p]
//                    (run with BASIS_Q threads in x)
//   transpose != 0 : d_V[comp][elem][p] = sum_dim sum_q G[dim][q][p] * d_U[dim][comp][elem][q]
//                    (run with BASIS_P threads in x)
// Entry G[dim][q][p] is read from d_G[dim * BASIS_P * BASIS_Q + q * BASIS_P + p].
// threadIdx.x selects the output entry; elements are picked from blockIdx.x and threadIdx.z and
// advanced by gridDim.x * blockDim.z until num_elem is reached.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedInt transpose, const CeedScalar *d_G, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  const CeedInt row         = threadIdx.x;
  const CeedInt elem_stride = gridDim.x * blockDim.z;
  // Distance between consecutive derivative directions in the quadrature-point arrays.
  const CeedInt dim_stride  = BASIS_NUM_COMP * num_elem * BASIS_Q;
  // NOTE(review): d_G could be staged in shared memory when several elements share a block (blockDim.z > 1).

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_stride) {
    for (CeedInt c = 0; c < BASIS_NUM_COMP; c++) {
      if (!transpose) {
        // Nodes -> quadrature point `row`: every direction is accumulated in one sweep over the nodes.
        const CeedScalar *u_e = d_U + (c * num_elem + e) * BASIS_P;
        CeedScalar        grad[BASIS_DIM];

        for (CeedInt d = 0; d < BASIS_DIM; d++) grad[d] = 0.0;
        for (CeedInt p = 0; p < BASIS_P; p++) {
          const CeedScalar u_p = u_e[p];
          for (CeedInt d = 0; d < BASIS_DIM; d++) grad[d] += d_G[d * BASIS_P * BASIS_Q + row * BASIS_P + p] * u_p;
        }
        for (CeedInt d = 0; d < BASIS_DIM; d++) d_V[d * dim_stride + (c * num_elem + e) * BASIS_Q + row] = grad[d];
      } else {
        // Quadrature points -> node `row`: the contributions of all directions go into one sum.
        CeedScalar acc = 0.0;

        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          const CeedScalar *u_e = d_U + d * dim_stride + (c * num_elem + e) * BASIS_Q;
          for (CeedInt q = 0; q < BASIS_Q; q++) acc += d_G[d * BASIS_P * BASIS_Q + q * BASIS_P + row] * u_e[q];
        }
        d_V[(c * num_elem + e) * BASIS_P + row] = acc;
      }
    }
  }
}

//------------------------------------------------------------------------------
// Weight
//------------------------------------------------------------------------------
// Copies the quadrature weights into every element: d_V[elem][q] = q_weight[q].
// threadIdx.x selects the quadrature point (presumably launched with BASIS_Q threads in x);
// elements are picked from blockIdx.x and threadIdx.z and advanced by gridDim.x * blockDim.z.
extern "C" __global__ void Weight(const CeedInt num_elem, const CeedScalar *__restrict__ q_weight, CeedScalar *__restrict__ d_V) {
  const CeedInt qpt         = threadIdx.x;
  const CeedInt elem_stride = gridDim.x * blockDim.z;
  // NOTE(review): q_weight could be staged in shared memory when several elements share a block (blockDim.z > 1).

  CeedInt e = blockIdx.x * blockDim.z + threadIdx.z;
  while (e < num_elem) {
    d_V[e * BASIS_Q + qpt] = q_weight[qpt];
    e += elem_stride;
  }
}

//------------------------------------------------------------------------------
98*b2165e7aSSebastian Grimberg 
99*b2165e7aSSebastian Grimberg #endif
100