xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision 437930d19388999b5cc2d76e2fe0d14f58fb41f3)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
#include <ceed/ceed.h>
#include <ceed/backend.h>
#include <ceed/jit-tools.h>
#include <hip/hip_runtime.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>
#include "ceed-hip-ref.h"
#include "../hip/ceed-hip-compile.h"
25 
26 //------------------------------------------------------------------------------
27 // Apply restriction
28 //------------------------------------------------------------------------------
29 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r,
30                                         CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
31   int ierr;
32   CeedElemRestriction_Hip *impl;
33   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
34   Ceed ceed;
35   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
36   Ceed_Hip *data;
37   ierr = CeedGetData(ceed, &data); CeedChkBackend(ierr);
38   const CeedInt block_size = 64;
39   const CeedInt num_nodes = impl->num_nodes;
40   CeedInt num_elem, elem_size;
41   CeedElemRestrictionGetNumElements(r, &num_elem);
42   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
43   hipFunction_t kernel;
44 
45   // Get vectors
46   const CeedScalar *d_u;
47   CeedScalar *d_v;
48   ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
49   if (t_mode == CEED_TRANSPOSE) {
50     // Sum into for transpose mode, e-vec to l-vec
51     ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
52   } else {
53     // Overwrite for notranspose mode, l-vec to e-vec
54     ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
55   }
56 
57   // Restrict
58   if (t_mode == CEED_NOTRANSPOSE) {
59     // L-vector -> E-vector
60     if (impl->d_ind) {
61       // -- Offsets provided
62       kernel = impl->OffsetNoTranspose;
63       void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
64       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
65       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
66                               block_size, args); CeedChkBackend(ierr);
67     } else {
68       // -- Strided restriction
69       kernel = impl->StridedNoTranspose;
70       void *args[] = {&num_elem, &d_u, &d_v};
71       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
72       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
73                               block_size, args); CeedChkBackend(ierr);
74     }
75   } else {
76     // E-vector -> L-vector
77     if (impl->d_ind) {
78       // -- Offsets provided
79       kernel = impl->OffsetTranspose;
80       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices,
81                       &impl->d_t_offsets, &d_u, &d_v
82                      };
83       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
84                               block_size, args); CeedChkBackend(ierr);
85     } else {
86       // -- Strided restriction
87       kernel = impl->StridedTranspose;
88       void *args[] = {&num_elem, &d_u, &d_v};
89       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
90                               block_size, args); CeedChkBackend(ierr);
91     }
92   }
93 
94   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED)
95     *request = NULL;
96 
97   // Restore arrays
98   ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
99   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
100   return CEED_ERROR_SUCCESS;
101 }
102 
103 //------------------------------------------------------------------------------
104 // Blocked not supported
105 //------------------------------------------------------------------------------
106 int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block,
107                                       CeedTransposeMode t_mode, CeedVector u,
108                                       CeedVector v, CeedRequest *request) {
109   // LCOV_EXCL_START
110   int ierr;
111   Ceed ceed;
112   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
113   return CeedError(ceed, CEED_ERROR_BACKEND,
114                    "Backend does not implement blocked restrictions");
115   // LCOV_EXCL_STOP
116 }
117 
118 //------------------------------------------------------------------------------
119 // Get offsets
120 //------------------------------------------------------------------------------
121 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr,
122     CeedMemType mtype, const CeedInt **offsets) {
123   int ierr;
124   CeedElemRestriction_Hip *impl;
125   ierr = CeedElemRestrictionGetData(rstr, &impl); CeedChkBackend(ierr);
126 
127   switch (mtype) {
128   case CEED_MEM_HOST:
129     *offsets = impl->h_ind;
130     break;
131   case CEED_MEM_DEVICE:
132     *offsets = impl->d_ind;
133     break;
134   }
135   return CEED_ERROR_SUCCESS;
136 }
137 
138 //------------------------------------------------------------------------------
139 // Destroy restriction
140 //------------------------------------------------------------------------------
141 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
142   int ierr;
143   CeedElemRestriction_Hip *impl;
144   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
145 
146   Ceed ceed;
147   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
148   ierr = hipModuleUnload(impl->module); CeedChk_Hip(ceed, ierr);
149   ierr = CeedFree(&impl->h_ind_allocated); CeedChkBackend(ierr);
150   ierr = hipFree(impl->d_ind_allocated); CeedChk_Hip(ceed, ierr);
151   ierr = hipFree(impl->d_t_offsets); CeedChk_Hip(ceed, ierr);
152   ierr = hipFree(impl->d_t_indices); CeedChk_Hip(ceed, ierr);
153   ierr = hipFree(impl->d_l_vec_indices); CeedChk_Hip(ceed, ierr);
154   ierr = CeedFree(&impl); CeedChkBackend(ierr);
155 
156   return CEED_ERROR_SUCCESS;
157 }
158 
159 //------------------------------------------------------------------------------
160 // Create transpose offsets and indices
161 //------------------------------------------------------------------------------
162 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r,
163     const CeedInt *indices) {
164   int ierr;
165   Ceed ceed;
166   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
167   CeedElemRestriction_Hip *impl;
168   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
169   CeedInt num_elem, elem_size, l_size, num_comp;
170   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
171   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
172   ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr);
173   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
174 
175   // Count num_nodes
176   bool *is_node;
177   ierr = CeedCalloc(l_size, &is_node); CeedChkBackend(ierr);
178   const CeedInt size_indices = num_elem * elem_size;
179   for (CeedInt i = 0; i < size_indices; i++)
180     is_node[indices[i]] = 1;
181   CeedInt num_nodes = 0;
182   for (CeedInt i = 0; i < l_size; i++)
183     num_nodes += is_node[i];
184   impl->num_nodes = num_nodes;
185 
186   // L-vector offsets array
187   CeedInt *ind_to_offset, *l_vec_indices;
188   ierr = CeedCalloc(l_size, &ind_to_offset); CeedChkBackend(ierr);
189   ierr = CeedCalloc(num_nodes, &l_vec_indices); CeedChkBackend(ierr);
190   CeedInt j = 0;
191   for (CeedInt i = 0; i < l_size; i++)
192     if (is_node[i]) {
193       l_vec_indices[j] = i;
194       ind_to_offset[i] = j++;
195     }
196   ierr = CeedFree(&is_node); CeedChkBackend(ierr);
197 
198   // Compute transpose offsets and indices
199   const CeedInt size_offsets = num_nodes + 1;
200   CeedInt *t_offsets;
201   ierr = CeedCalloc(size_offsets, &t_offsets); CeedChkBackend(ierr);
202   CeedInt *t_indices;
203   ierr = CeedMalloc(size_indices, &t_indices); CeedChkBackend(ierr);
204   // Count node multiplicity
205   for (CeedInt e = 0; e < num_elem; ++e)
206     for (CeedInt i = 0; i < elem_size; ++i)
207       ++t_offsets[ind_to_offset[indices[elem_size*e + i]] + 1];
208   // Convert to running sum
209   for (CeedInt i = 1; i < size_offsets; ++i)
210     t_offsets[i] += t_offsets[i-1];
211   // List all E-vec indices associated with L-vec node
212   for (CeedInt e = 0; e < num_elem; ++e) {
213     for (CeedInt i = 0; i < elem_size; ++i) {
214       const CeedInt lid = elem_size*e + i;
215       const CeedInt gid = indices[lid];
216       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
217     }
218   }
219   // Reset running sum
220   for (int i = size_offsets - 1; i > 0; --i)
221     t_offsets[i] = t_offsets[i - 1];
222   t_offsets[0] = 0;
223 
224   // Copy data to device
225   // -- L-vector indices
226   ierr = hipMalloc((void **)&impl->d_l_vec_indices, num_nodes*sizeof(CeedInt));
227   CeedChk_Hip(ceed, ierr);
228   ierr = hipMemcpy(impl->d_l_vec_indices, l_vec_indices,
229                    num_nodes*sizeof(CeedInt), hipMemcpyHostToDevice);
230   CeedChk_Hip(ceed, ierr);
231   // -- Transpose offsets
232   ierr = hipMalloc((void **)&impl->d_t_offsets, size_offsets*sizeof(CeedInt));
233   CeedChk_Hip(ceed, ierr);
234   ierr = hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets*sizeof(CeedInt),
235                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
236   // -- Transpose indices
237   ierr = hipMalloc((void **)&impl->d_t_indices, size_indices*sizeof(CeedInt));
238   CeedChk_Hip(ceed, ierr);
239   ierr = hipMemcpy(impl->d_t_indices, t_indices, size_indices*sizeof(CeedInt),
240                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
241 
242   // Cleanup
243   ierr = CeedFree(&ind_to_offset); CeedChkBackend(ierr);
244   ierr = CeedFree(&l_vec_indices); CeedChkBackend(ierr);
245   ierr = CeedFree(&t_offsets); CeedChkBackend(ierr);
246   ierr = CeedFree(&t_indices); CeedChkBackend(ierr);
247 
248   return CEED_ERROR_SUCCESS;
249 }
250 
251 //------------------------------------------------------------------------------
252 // Create restriction
253 //------------------------------------------------------------------------------
254 int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode,
255                                   const CeedInt *indices,
256                                   CeedElemRestriction r) {
257   int ierr;
258   Ceed ceed;
259   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
260   CeedElemRestriction_Hip *impl;
261   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
262   CeedInt num_elem, num_comp, elem_size;
263   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
264   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
265   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
266   CeedInt size = num_elem * elem_size;
267   CeedInt strides[3] = {1, size, elem_size};
268   CeedInt comp_stride = 1;
269 
270   // Stride data
271   bool is_strided;
272   ierr = CeedElemRestrictionIsStrided(r, &is_strided); CeedChkBackend(ierr);
273   if (is_strided) {
274     bool has_backend_strides;
275     ierr = CeedElemRestrictionHasBackendStrides(r, &has_backend_strides);
276     CeedChkBackend(ierr);
277     if (!has_backend_strides) {
278       ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr);
279     }
280   } else {
281     ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr);
282   }
283 
284   impl->h_ind           = NULL;
285   impl->h_ind_allocated = NULL;
286   impl->d_ind           = NULL;
287   impl->d_ind_allocated = NULL;
288   impl->d_t_indices     = NULL;
289   impl->d_t_offsets     = NULL;
290   impl->num_nodes = size;
291   ierr = CeedElemRestrictionSetData(r, impl); CeedChkBackend(ierr);
292   CeedInt layout[3] = {1, elem_size*num_elem, elem_size};
293   ierr = CeedElemRestrictionSetELayout(r, layout); CeedChkBackend(ierr);
294 
295   // Set up device indices/offset arrays
296   if (mtype == CEED_MEM_HOST) {
297     switch (cmode) {
298     case CEED_OWN_POINTER:
299       impl->h_ind_allocated = (CeedInt *)indices;
300       impl->h_ind = (CeedInt *)indices;
301       break;
302     case CEED_USE_POINTER:
303       impl->h_ind = (CeedInt *)indices;
304       break;
305     case CEED_COPY_VALUES:
306       break;
307     }
308     if (indices != NULL) {
309       ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
310       CeedChk_Hip(ceed, ierr);
311       impl->d_ind_allocated = impl->d_ind; // We own the device memory
312       ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
313                        hipMemcpyHostToDevice);
314       CeedChk_Hip(ceed, ierr);
315       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
316     }
317   } else if (mtype == CEED_MEM_DEVICE) {
318     switch (cmode) {
319     case CEED_COPY_VALUES:
320       if (indices != NULL) {
321         ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
322         CeedChk_Hip(ceed, ierr);
323         impl->d_ind_allocated = impl->d_ind; // We own the device memory
324         ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
325                          hipMemcpyDeviceToDevice);
326         CeedChk_Hip(ceed, ierr);
327       }
328       break;
329     case CEED_OWN_POINTER:
330       impl->d_ind = (CeedInt *)indices;
331       impl->d_ind_allocated = impl->d_ind;
332       break;
333     case CEED_USE_POINTER:
334       impl->d_ind = (CeedInt *)indices;
335     }
336     if (indices != NULL) {
337       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
338     }
339   } else {
340     // LCOV_EXCL_START
341     return CeedError(ceed, CEED_ERROR_BACKEND,
342                      "Only MemType = HOST or DEVICE supported");
343     // LCOV_EXCL_STOP
344   }
345 
346   // Compile HIP kernels
347   CeedInt num_nodes = impl->num_nodes;
348   char *restriction_kernel_path, *restriction_kernel_source;
349   ierr = CeedPathConcatenate(ceed, __FILE__, "kernels/hip-ref-restriction.h",
350                              &restriction_kernel_path); CeedChkBackend(ierr);
351   ierr = CeedLoadSourceToBuffer(ceed, restriction_kernel_path,
352                                 &restriction_kernel_source);
353   CeedChkBackend(ierr);
354   ierr = CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8,
355                         "RESTRICTION_ELEMSIZE", elem_size,
356                         "RESTRICTION_NELEM", num_elem,
357                         "RESTRICTION_NCOMP", num_comp,
358                         "RESTRICTION_NNODES", num_nodes,
359                         "RESTRICTION_COMPSTRIDE", comp_stride,
360                         "STRIDE_NODES", strides[0],
361                         "STRIDE_COMP", strides[1],
362                         "STRIDE_ELEM", strides[2]); CeedChkBackend(ierr);
363   ierr = CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose",
364                           &impl->StridedNoTranspose); CeedChkBackend(ierr);
365   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose",
366                           &impl->OffsetNoTranspose); CeedChkBackend(ierr);
367   ierr = CeedGetKernelHip(ceed, impl->module, "StridedTranspose",
368                           &impl->StridedTranspose); CeedChkBackend(ierr);
369   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetTranspose",
370                           &impl->OffsetTranspose); CeedChkBackend(ierr);
371   ierr = CeedFree(&restriction_kernel_path); CeedChkBackend(ierr);
372   ierr = CeedFree(&restriction_kernel_source); CeedChkBackend(ierr);
373 
374   // Register backend functions
375   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply",
376                                 CeedElemRestrictionApply_Hip);
377   CeedChkBackend(ierr);
378   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock",
379                                 CeedElemRestrictionApplyBlock_Hip);
380   CeedChkBackend(ierr);
381   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets",
382                                 CeedElemRestrictionGetOffsets_Hip);
383   CeedChkBackend(ierr);
384   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy",
385                                 CeedElemRestrictionDestroy_Hip);
386   CeedChkBackend(ierr);
387   return CEED_ERROR_SUCCESS;
388 }
389 
390 //------------------------------------------------------------------------------
391 // Blocked not supported
392 //------------------------------------------------------------------------------
393 int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype,
394     const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
395   int ierr;
396   Ceed ceed;
397   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
398   return CeedError(ceed, CEED_ERROR_BACKEND,
399                    "Backend does not implement blocked restrictions");
400 }
401 //------------------------------------------------------------------------------
402