xref: /libCEED/include/ceed/jit-source/cuda/cuda-shared-basis-nontensor.h (revision 9ff05d55386b4e6413be60b7231511258906fd9f)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA shared memory non-tensor basis
10 #include <ceed/types.h>
11 
12 #include "cuda-shared-basis-nontensor-templates.h"
13 #include "cuda-shared-basis-read-write-templates.h"
14 
15 //------------------------------------------------------------------------------
// Interp kernels
17 //------------------------------------------------------------------------------
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Apply the non-tensor interpolation matrix c_B to every element of d_U,
  // writing the result to d_V. Elements are distributed blockDim.z per block
  // in a grid-stride fashion; each z-layer owns a slice of T_1D scalars in
  // dynamic shared memory.
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  // Per-thread registers: BASIS_P nodal inputs -> BASIS_Q quadrature outputs
  CeedScalar r_in[BASIS_NUM_COMP];
  CeedScalar r_out[BASIS_NUM_COMP];

  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P>(data, elem, 1, BASIS_P * num_elem, BASIS_P, d_U, r_in);
    InterpNonTensor<BASIS_NUM_COMP, BASIS_P, BASIS_Q>(data, r_in, c_B, r_out);
    WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, r_out, d_V);
    elem += elem_stride;
  }
}
37 
extern "C" __global__ void InterpTranspose(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                           CeedScalar *__restrict__ d_V) {
  // Apply the transpose of the non-tensor interpolation matrix c_B, mapping
  // BASIS_Q quadrature values per component back to BASIS_P nodal values.
  // Output d_V is overwritten (see InterpTransposeAdd for accumulation).
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  CeedScalar r_in[BASIS_NUM_COMP];
  CeedScalar r_out[BASIS_NUM_COMP];

  // Grid-stride loop over elements, blockDim.z elements handled per block
  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, d_U, r_in);
    InterpTransposeNonTensor<BASIS_NUM_COMP, BASIS_P, BASIS_Q>(data, r_in, c_B, r_out);
    WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P>(data, elem, 1, BASIS_P * num_elem, BASIS_P, r_out, d_V);
    elem += elem_stride;
  }
}
58 
extern "C" __global__ void InterpTransposeAdd(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                              CeedScalar *__restrict__ d_V) {
  // Same as InterpTranspose, except results are summed into d_V
  // (SumElementStrided1d) rather than overwriting it.
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  CeedScalar r_in[BASIS_NUM_COMP];
  CeedScalar r_out[BASIS_NUM_COMP];

  // Grid-stride loop over elements, blockDim.z elements handled per block
  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, d_U, r_in);
    InterpTransposeNonTensor<BASIS_NUM_COMP, BASIS_P, BASIS_Q>(data, r_in, c_B, r_out);
    SumElementStrided1d<BASIS_NUM_COMP, BASIS_P>(data, elem, 1, BASIS_P * num_elem, BASIS_P, r_out, d_V);
    elem += elem_stride;
  }
}
79 
80 //------------------------------------------------------------------------------
// Grad kernels
82 //------------------------------------------------------------------------------
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Apply the non-tensor gradient matrix c_G: BASIS_P nodal values per
  // component in, BASIS_DIM * BASIS_Q quadrature derivatives per component
  // out. Elements are distributed blockDim.z per block, grid-stride.
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  CeedScalar r_in[BASIS_NUM_COMP];
  CeedScalar r_out[BASIS_NUM_COMP * BASIS_DIM];  // one derivative set per dimension

  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P>(data, elem, 1, BASIS_P * num_elem, BASIS_P, d_U, r_in);
    GradNonTensor<BASIS_NUM_COMP, BASIS_DIM, BASIS_P, BASIS_Q>(data, r_in, c_G, r_out);
    WriteElementStrided1d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, r_out, d_V);
    elem += elem_stride;
  }
}
102 
extern "C" __global__ void GradTranspose(const CeedInt num_elem, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                         CeedScalar *__restrict__ d_V) {
  // Apply the transpose of the non-tensor gradient matrix c_G, contracting
  // BASIS_DIM * BASIS_Q quadrature derivatives per component down to BASIS_P
  // nodal values. Output d_V is overwritten (see GradTransposeAdd to sum).
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  CeedScalar r_in[BASIS_NUM_COMP * BASIS_DIM];  // one derivative set per dimension
  CeedScalar r_out[BASIS_NUM_COMP];

  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    ReadElementStrided1d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, d_U, r_in);
    GradTransposeNonTensor<BASIS_NUM_COMP, BASIS_DIM, BASIS_P, BASIS_Q>(data, r_in, c_G, r_out);
    WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P>(data, elem, 1, BASIS_P * num_elem, BASIS_P, r_out, d_V);
    elem += elem_stride;
  }
}
123 
extern "C" __global__ void GradTransposeAdd(const CeedInt num_elem, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                            CeedScalar *__restrict__ d_V) {
  // Same as GradTranspose, except results are summed into d_V
  // (SumElementStrided1d) rather than overwriting it.
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  CeedScalar r_in[BASIS_NUM_COMP * BASIS_DIM];  // one derivative set per dimension
  CeedScalar r_out[BASIS_NUM_COMP];

  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    ReadElementStrided1d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, d_U, r_in);
    GradTransposeNonTensor<BASIS_NUM_COMP, BASIS_DIM, BASIS_P, BASIS_Q>(data, r_in, c_G, r_out);
    SumElementStrided1d<BASIS_NUM_COMP, BASIS_P>(data, elem, 1, BASIS_P * num_elem, BASIS_P, r_out, d_V);
    elem += elem_stride;
  }
}
144 
145 //------------------------------------------------------------------------------
// Weight kernel
147 //------------------------------------------------------------------------------
extern "C" __global__ void Weight(const CeedInt num_elem, const CeedScalar *__restrict__ q_weight, CeedScalar *__restrict__ d_W) {
  // Copy the BASIS_Q quadrature weights q_weight into d_W for every element,
  // using the same blockDim.z-per-block grid-stride element distribution as
  // the other kernels in this file.
  extern __shared__ CeedScalar slice[];

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  data.slice  = &slice[data.t_id_z * T_1D];

  CeedScalar r_weight[1];  // single weight per thread

  const CeedInt elem_stride = gridDim.x * blockDim.z;
  CeedInt       elem        = blockIdx.x * blockDim.z + threadIdx.z;
  while (elem < num_elem) {
    // NOTE(review): the weight load is element-independent; it is kept inside
    // the loop to match the original control flow exactly.
    WeightNonTensor<BASIS_Q>(data, q_weight, r_weight);
    WriteElementStrided1d<1, BASIS_Q>(data, elem, 1, BASIS_Q * num_elem, BASIS_Q, r_weight, d_W);
    elem += elem_stride;
  }
}
165