xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-restriction.c (revision e30587cba82ab5bd19eea7b489a50e8de8aeb030)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <cuda.h>
12 #include <cuda_runtime.h>
13 #include <stdbool.h>
14 #include <stddef.h>
15 #include <string.h>
16 
17 #include "../cuda/ceed-cuda-common.h"
18 #include "../cuda/ceed-cuda-compile.h"
19 #include "ceed-cuda-ref.h"
20 
21 //------------------------------------------------------------------------------
22 // Apply restriction
23 //------------------------------------------------------------------------------
24 static int CeedElemRestrictionApply_Cuda(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
25   CeedElemRestriction_Cuda *impl;
26   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
27   Ceed ceed;
28   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
29   Ceed_Cuda *data;
30   CeedCallBackend(CeedGetData(ceed, &data));
31   const CeedInt warp_size  = 32;
32   const CeedInt block_size = warp_size;
33   const CeedInt num_nodes  = impl->num_nodes;
34   CeedInt       num_elem, elem_size;
35   CeedElemRestrictionGetNumElements(r, &num_elem);
36   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
37   CUfunction kernel;
38 
39   // Get vectors
40   const CeedScalar *d_u;
41   CeedScalar       *d_v;
42   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
43   if (t_mode == CEED_TRANSPOSE) {
44     // Sum into for transpose mode, e-vec to l-vec
45     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
46   } else {
47     // Overwrite for notranspose mode, l-vec to e-vec
48     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
49   }
50 
51   // Restrict
52   if (t_mode == CEED_NOTRANSPOSE) {
53     // L-vector -> E-vector
54     if (impl->d_ind) {
55       // -- Offsets provided
56       kernel             = impl->OffsetNoTranspose;
57       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
58       CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024;
59       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
60     } else {
61       // -- Strided restriction
62       kernel             = impl->StridedNoTranspose;
63       void   *args[]     = {&num_elem, &d_u, &d_v};
64       CeedInt block_size = elem_size < 1024 ? (elem_size > 32 ? elem_size : 32) : 1024;
65       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
66     }
67   } else {
68     // E-vector -> L-vector
69     if (impl->d_ind) {
70       // -- Offsets provided
71       kernel       = impl->OffsetTranspose;
72       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
73       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
74     } else {
75       // -- Strided restriction
76       kernel       = impl->StridedTranspose;
77       void *args[] = {&num_elem, &d_u, &d_v};
78       CeedCallBackend(CeedRunKernel_Cuda(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
79     }
80   }
81 
82   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
83 
84   // Restore arrays
85   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
86   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
87   return CEED_ERROR_SUCCESS;
88 }
89 
90 //------------------------------------------------------------------------------
91 // Get offsets
92 //------------------------------------------------------------------------------
93 static int CeedElemRestrictionGetOffsets_Cuda(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
94   CeedElemRestriction_Cuda *impl;
95   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
96 
97   switch (mem_type) {
98     case CEED_MEM_HOST:
99       *offsets = impl->h_ind;
100       break;
101     case CEED_MEM_DEVICE:
102       *offsets = impl->d_ind;
103       break;
104   }
105   return CEED_ERROR_SUCCESS;
106 }
107 
108 //------------------------------------------------------------------------------
109 // Destroy restriction
110 //------------------------------------------------------------------------------
111 static int CeedElemRestrictionDestroy_Cuda(CeedElemRestriction r) {
112   CeedElemRestriction_Cuda *impl;
113   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
114 
115   Ceed ceed;
116   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
117   CeedCallCuda(ceed, cuModuleUnload(impl->module));
118   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
119   CeedCallCuda(ceed, cudaFree(impl->d_ind_allocated));
120   CeedCallCuda(ceed, cudaFree(impl->d_t_offsets));
121   CeedCallCuda(ceed, cudaFree(impl->d_t_indices));
122   CeedCallCuda(ceed, cudaFree(impl->d_l_vec_indices));
123   CeedCallBackend(CeedFree(&impl));
124 
125   return CEED_ERROR_SUCCESS;
126 }
127 
128 //------------------------------------------------------------------------------
129 // Create transpose offsets and indices
130 //------------------------------------------------------------------------------
131 static int CeedElemRestrictionOffset_Cuda(const CeedElemRestriction r, const CeedInt *indices) {
132   Ceed ceed;
133   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
134   CeedElemRestriction_Cuda *impl;
135   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
136   CeedSize l_size;
137   CeedInt  num_elem, elem_size, num_comp;
138   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
139   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
140   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
141   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
142 
143   // Count num_nodes
144   bool *is_node;
145   CeedCallBackend(CeedCalloc(l_size, &is_node));
146   const CeedInt size_indices = num_elem * elem_size;
147   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
148   CeedInt num_nodes = 0;
149   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
150   impl->num_nodes = num_nodes;
151 
152   // L-vector offsets array
153   CeedInt *ind_to_offset, *l_vec_indices;
154   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
155   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
156   CeedInt j = 0;
157   for (CeedInt i = 0; i < l_size; i++) {
158     if (is_node[i]) {
159       l_vec_indices[j] = i;
160       ind_to_offset[i] = j++;
161     }
162   }
163   CeedCallBackend(CeedFree(&is_node));
164 
165   // Compute transpose offsets and indices
166   const CeedInt size_offsets = num_nodes + 1;
167   CeedInt      *t_offsets;
168   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
169   CeedInt *t_indices;
170   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
171   // Count node multiplicity
172   for (CeedInt e = 0; e < num_elem; ++e) {
173     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
174   }
175   // Convert to running sum
176   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
177   // List all E-vec indices associated with L-vec node
178   for (CeedInt e = 0; e < num_elem; ++e) {
179     for (CeedInt i = 0; i < elem_size; ++i) {
180       const CeedInt lid                          = elem_size * e + i;
181       const CeedInt gid                          = indices[lid];
182       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
183     }
184   }
185   // Reset running sum
186   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
187   t_offsets[0] = 0;
188 
189   // Copy data to device
190   // -- L-vector indices
191   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
192   CeedCallCuda(ceed, cudaMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), cudaMemcpyHostToDevice));
193   // -- Transpose offsets
194   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
195   CeedCallCuda(ceed, cudaMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), cudaMemcpyHostToDevice));
196   // -- Transpose indices
197   CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
198   CeedCallCuda(ceed, cudaMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), cudaMemcpyHostToDevice));
199 
200   // Cleanup
201   CeedCallBackend(CeedFree(&ind_to_offset));
202   CeedCallBackend(CeedFree(&l_vec_indices));
203   CeedCallBackend(CeedFree(&t_offsets));
204   CeedCallBackend(CeedFree(&t_indices));
205 
206   return CEED_ERROR_SUCCESS;
207 }
208 
209 //------------------------------------------------------------------------------
210 // Create restriction
211 //------------------------------------------------------------------------------
212 int CeedElemRestrictionCreate_Cuda(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
213                                    const CeedInt8 *curl_orients, CeedElemRestriction r) {
214   Ceed ceed;
215   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
216   CeedElemRestriction_Cuda *impl;
217   CeedCallBackend(CeedCalloc(1, &impl));
218   CeedInt num_elem, num_comp, elem_size;
219   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
220   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
221   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
222   CeedInt size        = num_elem * elem_size;
223   CeedInt strides[3]  = {1, size, elem_size};
224   CeedInt comp_stride = 1;
225 
226   CeedRestrictionType rstr_type;
227   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
228   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
229             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
230 
231   // Stride data
232   bool is_strided;
233   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
234   if (is_strided) {
235     bool has_backend_strides;
236     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
237     if (!has_backend_strides) {
238       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
239     }
240   } else {
241     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
242   }
243 
244   impl->h_ind           = NULL;
245   impl->h_ind_allocated = NULL;
246   impl->d_ind           = NULL;
247   impl->d_ind_allocated = NULL;
248   impl->d_t_indices     = NULL;
249   impl->d_t_offsets     = NULL;
250   impl->num_nodes       = size;
251   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
252   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
253   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
254 
255   // Set up device indices/offset arrays
256   switch (mem_type) {
257     case CEED_MEM_HOST: {
258       switch (copy_mode) {
259         case CEED_OWN_POINTER:
260           impl->h_ind_allocated = (CeedInt *)indices;
261           impl->h_ind           = (CeedInt *)indices;
262           break;
263         case CEED_USE_POINTER:
264           impl->h_ind = (CeedInt *)indices;
265           break;
266         case CEED_COPY_VALUES:
267           if (indices != NULL) {
268             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
269             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
270             impl->h_ind = impl->h_ind_allocated;
271           }
272           break;
273       }
274       if (indices != NULL) {
275         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
276         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
277         CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyHostToDevice));
278         CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices));
279       }
280       break;
281     }
282     case CEED_MEM_DEVICE: {
283       switch (copy_mode) {
284         case CEED_COPY_VALUES:
285           if (indices != NULL) {
286             CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
287             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
288             CeedCallCuda(ceed, cudaMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), cudaMemcpyDeviceToDevice));
289           }
290           break;
291         case CEED_OWN_POINTER:
292           impl->d_ind           = (CeedInt *)indices;
293           impl->d_ind_allocated = impl->d_ind;
294           break;
295         case CEED_USE_POINTER:
296           impl->d_ind = (CeedInt *)indices;
297       }
298       if (indices != NULL) {
299         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
300         CeedCallCuda(ceed, cudaMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), cudaMemcpyDeviceToHost));
301         impl->h_ind = impl->h_ind_allocated;
302         CeedCallBackend(CeedElemRestrictionOffset_Cuda(r, indices));
303       }
304       break;
305     }
306     // LCOV_EXCL_START
307     default:
308       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
309       // LCOV_EXCL_STOP
310   }
311 
312   // Compile CUDA kernels
313   CeedInt num_nodes = impl->num_nodes;
314   char   *restriction_kernel_path, *restriction_kernel_source;
315   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/cuda/cuda-ref-restriction.h", &restriction_kernel_path));
316   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
317   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
318   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
319   CeedCallBackend(CeedCompile_Cuda(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
320                                    "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
321                                    strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
322   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
323   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
324   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
325   CeedCallBackend(CeedGetKernel_Cuda(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
326   CeedCallBackend(CeedFree(&restriction_kernel_path));
327   CeedCallBackend(CeedFree(&restriction_kernel_source));
328 
329   // Register backend functions
330   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Cuda));
331   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Cuda));
332   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Cuda));
333   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Cuda));
334   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Cuda));
335   return CEED_ERROR_SUCCESS;
336 }
337 
338 //------------------------------------------------------------------------------
339