xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision 5047007e18629437ece4e0627cb87e8530a4075f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <string.h>
14 #include <hip/hip_runtime.h>
15 
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-ref.h"
19 
20 //------------------------------------------------------------------------------
21 // Apply restriction
22 //------------------------------------------------------------------------------
23 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
24   CeedElemRestriction_Hip *impl;
25   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
26   Ceed ceed;
27   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
28   Ceed_Hip *data;
29   CeedCallBackend(CeedGetData(ceed, &data));
30   const CeedInt block_size = 64;
31   const CeedInt num_nodes  = impl->num_nodes;
32   CeedInt       num_elem, elem_size;
33   CeedElemRestrictionGetNumElements(r, &num_elem);
34   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
35   hipFunction_t kernel;
36 
37   // Get vectors
38   const CeedScalar *d_u;
39   CeedScalar       *d_v;
40   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
41   if (t_mode == CEED_TRANSPOSE) {
42     // Sum into for transpose mode, e-vec to l-vec
43     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
44   } else {
45     // Overwrite for notranspose mode, l-vec to e-vec
46     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
47   }
48 
49   // Restrict
50   if (t_mode == CEED_NOTRANSPOSE) {
51     // L-vector -> E-vector
52     if (impl->d_ind) {
53       // -- Offsets provided
54       kernel             = impl->OffsetNoTranspose;
55       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
56       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
57       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
58     } else {
59       // -- Strided restriction
60       kernel             = impl->StridedNoTranspose;
61       void   *args[]     = {&num_elem, &d_u, &d_v};
62       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
63       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
64     }
65   } else {
66     // E-vector -> L-vector
67     if (impl->d_ind) {
68       // -- Offsets provided
69       kernel       = impl->OffsetTranspose;
70       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
71       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
72     } else {
73       // -- Strided restriction
74       kernel       = impl->StridedTranspose;
75       void *args[] = {&num_elem, &d_u, &d_v};
76       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
77     }
78   }
79 
80   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
81 
82   // Restore arrays
83   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
84   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
85   return CEED_ERROR_SUCCESS;
86 }
87 
88 //------------------------------------------------------------------------------
89 // Get offsets
90 //------------------------------------------------------------------------------
91 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
92   CeedElemRestriction_Hip *impl;
93   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
94 
95   switch (mem_type) {
96     case CEED_MEM_HOST:
97       *offsets = impl->h_ind;
98       break;
99     case CEED_MEM_DEVICE:
100       *offsets = impl->d_ind;
101       break;
102   }
103   return CEED_ERROR_SUCCESS;
104 }
105 
106 //------------------------------------------------------------------------------
107 // Destroy restriction
108 //------------------------------------------------------------------------------
109 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
110   CeedElemRestriction_Hip *impl;
111   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
112 
113   Ceed ceed;
114   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
115   CeedCallHip(ceed, hipModuleUnload(impl->module));
116   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
117   CeedCallHip(ceed, hipFree(impl->d_ind_allocated));
118   CeedCallHip(ceed, hipFree(impl->d_t_offsets));
119   CeedCallHip(ceed, hipFree(impl->d_t_indices));
120   CeedCallHip(ceed, hipFree(impl->d_l_vec_indices));
121   CeedCallBackend(CeedFree(&impl));
122 
123   return CEED_ERROR_SUCCESS;
124 }
125 
126 //------------------------------------------------------------------------------
127 // Create transpose offsets and indices
128 //------------------------------------------------------------------------------
129 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) {
130   Ceed ceed;
131   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
132   CeedElemRestriction_Hip *impl;
133   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
134   CeedSize l_size;
135   CeedInt  num_elem, elem_size, num_comp;
136   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
137   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
138   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
139   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
140 
141   // Count num_nodes
142   bool *is_node;
143   CeedCallBackend(CeedCalloc(l_size, &is_node));
144   const CeedInt size_indices = num_elem * elem_size;
145   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
146   CeedInt num_nodes = 0;
147   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
148   impl->num_nodes = num_nodes;
149 
150   // L-vector offsets array
151   CeedInt *ind_to_offset, *l_vec_indices;
152   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
153   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
154   CeedInt j = 0;
155   for (CeedInt i = 0; i < l_size; i++) {
156     if (is_node[i]) {
157       l_vec_indices[j] = i;
158       ind_to_offset[i] = j++;
159     }
160   }
161   CeedCallBackend(CeedFree(&is_node));
162 
163   // Compute transpose offsets and indices
164   const CeedInt size_offsets = num_nodes + 1;
165   CeedInt      *t_offsets;
166   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
167   CeedInt *t_indices;
168   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
169   // Count node multiplicity
170   for (CeedInt e = 0; e < num_elem; ++e) {
171     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
172   }
173   // Convert to running sum
174   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
175   // List all E-vec indices associated with L-vec node
176   for (CeedInt e = 0; e < num_elem; ++e) {
177     for (CeedInt i = 0; i < elem_size; ++i) {
178       const CeedInt lid                          = elem_size * e + i;
179       const CeedInt gid                          = indices[lid];
180       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
181     }
182   }
183   // Reset running sum
184   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
185   t_offsets[0] = 0;
186 
187   // Copy data to device
188   // -- L-vector indices
189   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
190   CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
191   // -- Transpose offsets
192   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
193   CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
194   // -- Transpose indices
195   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
196   CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
197 
198   // Cleanup
199   CeedCallBackend(CeedFree(&ind_to_offset));
200   CeedCallBackend(CeedFree(&l_vec_indices));
201   CeedCallBackend(CeedFree(&t_offsets));
202   CeedCallBackend(CeedFree(&t_indices));
203 
204   return CEED_ERROR_SUCCESS;
205 }
206 
207 //------------------------------------------------------------------------------
208 // Create restriction
209 //------------------------------------------------------------------------------
210 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
211                                   const CeedInt8 *curl_orients, CeedElemRestriction r) {
212   Ceed ceed;
213   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
214   CeedElemRestriction_Hip *impl;
215   CeedCallBackend(CeedCalloc(1, &impl));
216   CeedInt num_elem, num_comp, elem_size;
217   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
218   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
219   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
220   CeedInt size        = num_elem * elem_size;
221   CeedInt strides[3]  = {1, size, elem_size};
222   CeedInt comp_stride = 1;
223 
224   CeedRestrictionType rstr_type;
225   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
226   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
227             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
228 
229   // Stride data
230   bool is_strided;
231   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
232   if (is_strided) {
233     bool has_backend_strides;
234     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
235     if (!has_backend_strides) {
236       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
237     }
238   } else {
239     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
240   }
241 
242   impl->h_ind           = NULL;
243   impl->h_ind_allocated = NULL;
244   impl->d_ind           = NULL;
245   impl->d_ind_allocated = NULL;
246   impl->d_t_indices     = NULL;
247   impl->d_t_offsets     = NULL;
248   impl->num_nodes       = size;
249   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
250   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
251   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
252 
253   // Set up device indices/offset arrays
254   switch (mem_type) {
255     case CEED_MEM_HOST: {
256       switch (copy_mode) {
257         case CEED_OWN_POINTER:
258           impl->h_ind_allocated = (CeedInt *)indices;
259           impl->h_ind           = (CeedInt *)indices;
260           break;
261         case CEED_USE_POINTER:
262           impl->h_ind = (CeedInt *)indices;
263           break;
264         case CEED_COPY_VALUES:
265           if (indices != NULL) {
266             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
267             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
268             impl->h_ind = impl->h_ind_allocated;
269           }
270           break;
271       }
272       if (indices != NULL) {
273         CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
274         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
275         CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice));
276         CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
277       }
278       break;
279     }
280     case CEED_MEM_DEVICE: {
281       switch (copy_mode) {
282         case CEED_COPY_VALUES:
283           if (indices != NULL) {
284             CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
285             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
286             CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice));
287           }
288           break;
289         case CEED_OWN_POINTER:
290           impl->d_ind           = (CeedInt *)indices;
291           impl->d_ind_allocated = impl->d_ind;
292           break;
293         case CEED_USE_POINTER:
294           impl->d_ind = (CeedInt *)indices;
295       }
296       if (indices != NULL) {
297         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
298         CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost));
299         impl->h_ind = impl->h_ind_allocated;
300         CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
301       }
302       break;
303     }
304     // LCOV_EXCL_START
305     default:
306       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
307       // LCOV_EXCL_STOP
308   }
309 
310   // Compile HIP kernels
311   CeedInt num_nodes = impl->num_nodes;
312   char   *restriction_kernel_path, *restriction_kernel_source;
313   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path));
314   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
315   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
316   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
317   CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
318                                   "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
319                                   strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
320   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
321   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
322   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
323   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
324   CeedCallBackend(CeedFree(&restriction_kernel_path));
325   CeedCallBackend(CeedFree(&restriction_kernel_source));
326 
327   // Register backend functions
328   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip));
329   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip));
330   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip));
331   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
332   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip));
333   return CEED_ERROR_SUCCESS;
334 }
335 
336 //------------------------------------------------------------------------------
337