xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision ee5a26f213b810c411fd8b16edb331d78b70003c)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <hip/hip_runtime.h>
12 #include <stdbool.h>
13 #include <stddef.h>
14 #include "ceed-hip-ref.h"
15 #include "../hip/ceed-hip-compile.h"
16 
17 //------------------------------------------------------------------------------
18 // Apply restriction
19 //------------------------------------------------------------------------------
20 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r,
21                                         CeedTransposeMode t_mode, CeedVector u,
22                                         CeedVector v, CeedRequest *request) {
23   int ierr;
24   CeedElemRestriction_Hip *impl;
25   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
26   Ceed ceed;
27   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
28   Ceed_Hip *data;
29   ierr = CeedGetData(ceed, &data); CeedChkBackend(ierr);
30   const CeedInt block_size = 64;
31   const CeedInt num_nodes = impl->num_nodes;
32   CeedInt num_elem, elem_size;
33   CeedElemRestrictionGetNumElements(r, &num_elem);
34   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
35   hipFunction_t kernel;
36 
37   // Get vectors
38   const CeedScalar *d_u;
39   CeedScalar *d_v;
40   ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
41   if (t_mode == CEED_TRANSPOSE) {
42     // Sum into for transpose mode, e-vec to l-vec
43     ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
44   } else {
45     // Overwrite for notranspose mode, l-vec to e-vec
46     ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
47   }
48 
49   // Restrict
50   if (t_mode == CEED_NOTRANSPOSE) {
51     // L-vector -> E-vector
52     if (impl->d_ind) {
53       // -- Offsets provided
54       kernel = impl->OffsetNoTranspose;
55       void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
56       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
57       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
58                               block_size, args); CeedChkBackend(ierr);
59     } else {
60       // -- Strided restriction
61       kernel = impl->StridedNoTranspose;
62       void *args[] = {&num_elem, &d_u, &d_v};
63       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
64       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
65                               block_size, args); CeedChkBackend(ierr);
66     }
67   } else {
68     // E-vector -> L-vector
69     if (impl->d_ind) {
70       // -- Offsets provided
71       kernel = impl->OffsetTranspose;
72       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices,
73                       &impl->d_t_offsets, &d_u, &d_v
74                      };
75       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
76                               block_size, args); CeedChkBackend(ierr);
77     } else {
78       // -- Strided restriction
79       kernel = impl->StridedTranspose;
80       void *args[] = {&num_elem, &d_u, &d_v};
81       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
82                               block_size, args); CeedChkBackend(ierr);
83     }
84   }
85 
86   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED)
87     *request = NULL;
88 
89   // Restore arrays
90   ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
91   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
92   return CEED_ERROR_SUCCESS;
93 }
94 
95 //------------------------------------------------------------------------------
96 // Blocked not supported
97 //------------------------------------------------------------------------------
98 int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block,
99                                       CeedTransposeMode t_mode, CeedVector u,
100                                       CeedVector v, CeedRequest *request) {
101   // LCOV_EXCL_START
102   int ierr;
103   Ceed ceed;
104   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
105   return CeedError(ceed, CEED_ERROR_BACKEND,
106                    "Backend does not implement blocked restrictions");
107   // LCOV_EXCL_STOP
108 }
109 
110 //------------------------------------------------------------------------------
111 // Get offsets
112 //------------------------------------------------------------------------------
113 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr,
114     CeedMemType mtype, const CeedInt **offsets) {
115   int ierr;
116   CeedElemRestriction_Hip *impl;
117   ierr = CeedElemRestrictionGetData(rstr, &impl); CeedChkBackend(ierr);
118 
119   switch (mtype) {
120   case CEED_MEM_HOST:
121     *offsets = impl->h_ind;
122     break;
123   case CEED_MEM_DEVICE:
124     *offsets = impl->d_ind;
125     break;
126   }
127   return CEED_ERROR_SUCCESS;
128 }
129 
130 //------------------------------------------------------------------------------
131 // Destroy restriction
132 //------------------------------------------------------------------------------
133 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
134   int ierr;
135   CeedElemRestriction_Hip *impl;
136   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
137 
138   Ceed ceed;
139   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
140   ierr = hipModuleUnload(impl->module); CeedChk_Hip(ceed, ierr);
141   ierr = CeedFree(&impl->h_ind_allocated); CeedChkBackend(ierr);
142   ierr = hipFree(impl->d_ind_allocated); CeedChk_Hip(ceed, ierr);
143   ierr = hipFree(impl->d_t_offsets); CeedChk_Hip(ceed, ierr);
144   ierr = hipFree(impl->d_t_indices); CeedChk_Hip(ceed, ierr);
145   ierr = hipFree(impl->d_l_vec_indices); CeedChk_Hip(ceed, ierr);
146   ierr = CeedFree(&impl); CeedChkBackend(ierr);
147 
148   return CEED_ERROR_SUCCESS;
149 }
150 
151 //------------------------------------------------------------------------------
152 // Create transpose offsets and indices
153 //------------------------------------------------------------------------------
154 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r,
155     const CeedInt *indices) {
156   int ierr;
157   Ceed ceed;
158   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
159   CeedElemRestriction_Hip *impl;
160   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
161   CeedSize l_size;
162   CeedInt num_elem, elem_size, num_comp;
163   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
164   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
165   ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr);
166   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
167 
168   // Count num_nodes
169   bool *is_node;
170   ierr = CeedCalloc(l_size, &is_node); CeedChkBackend(ierr);
171   const CeedInt size_indices = num_elem * elem_size;
172   for (CeedInt i = 0; i < size_indices; i++)
173     is_node[indices[i]] = 1;
174   CeedInt num_nodes = 0;
175   for (CeedInt i = 0; i < l_size; i++)
176     num_nodes += is_node[i];
177   impl->num_nodes = num_nodes;
178 
179   // L-vector offsets array
180   CeedInt *ind_to_offset, *l_vec_indices;
181   ierr = CeedCalloc(l_size, &ind_to_offset); CeedChkBackend(ierr);
182   ierr = CeedCalloc(num_nodes, &l_vec_indices); CeedChkBackend(ierr);
183   CeedInt j = 0;
184   for (CeedInt i = 0; i < l_size; i++)
185     if (is_node[i]) {
186       l_vec_indices[j] = i;
187       ind_to_offset[i] = j++;
188     }
189   ierr = CeedFree(&is_node); CeedChkBackend(ierr);
190 
191   // Compute transpose offsets and indices
192   const CeedInt size_offsets = num_nodes + 1;
193   CeedInt *t_offsets;
194   ierr = CeedCalloc(size_offsets, &t_offsets); CeedChkBackend(ierr);
195   CeedInt *t_indices;
196   ierr = CeedMalloc(size_indices, &t_indices); CeedChkBackend(ierr);
197   // Count node multiplicity
198   for (CeedInt e = 0; e < num_elem; ++e)
199     for (CeedInt i = 0; i < elem_size; ++i)
200       ++t_offsets[ind_to_offset[indices[elem_size*e + i]] + 1];
201   // Convert to running sum
202   for (CeedInt i = 1; i < size_offsets; ++i)
203     t_offsets[i] += t_offsets[i-1];
204   // List all E-vec indices associated with L-vec node
205   for (CeedInt e = 0; e < num_elem; ++e) {
206     for (CeedInt i = 0; i < elem_size; ++i) {
207       const CeedInt lid = elem_size*e + i;
208       const CeedInt gid = indices[lid];
209       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
210     }
211   }
212   // Reset running sum
213   for (int i = size_offsets - 1; i > 0; --i)
214     t_offsets[i] = t_offsets[i - 1];
215   t_offsets[0] = 0;
216 
217   // Copy data to device
218   // -- L-vector indices
219   ierr = hipMalloc((void **)&impl->d_l_vec_indices, num_nodes*sizeof(CeedInt));
220   CeedChk_Hip(ceed, ierr);
221   ierr = hipMemcpy(impl->d_l_vec_indices, l_vec_indices,
222                    num_nodes*sizeof(CeedInt), hipMemcpyHostToDevice);
223   CeedChk_Hip(ceed, ierr);
224   // -- Transpose offsets
225   ierr = hipMalloc((void **)&impl->d_t_offsets, size_offsets*sizeof(CeedInt));
226   CeedChk_Hip(ceed, ierr);
227   ierr = hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets*sizeof(CeedInt),
228                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
229   // -- Transpose indices
230   ierr = hipMalloc((void **)&impl->d_t_indices, size_indices*sizeof(CeedInt));
231   CeedChk_Hip(ceed, ierr);
232   ierr = hipMemcpy(impl->d_t_indices, t_indices, size_indices*sizeof(CeedInt),
233                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
234 
235   // Cleanup
236   ierr = CeedFree(&ind_to_offset); CeedChkBackend(ierr);
237   ierr = CeedFree(&l_vec_indices); CeedChkBackend(ierr);
238   ierr = CeedFree(&t_offsets); CeedChkBackend(ierr);
239   ierr = CeedFree(&t_indices); CeedChkBackend(ierr);
240 
241   return CEED_ERROR_SUCCESS;
242 }
243 
244 //------------------------------------------------------------------------------
245 // Create restriction
246 //------------------------------------------------------------------------------
247 int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode,
248                                   const CeedInt *indices,
249                                   CeedElemRestriction r) {
250   int ierr;
251   Ceed ceed;
252   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
253   CeedElemRestriction_Hip *impl;
254   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
255   CeedInt num_elem, num_comp, elem_size;
256   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
257   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
258   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
259   CeedInt size = num_elem * elem_size;
260   CeedInt strides[3] = {1, size, elem_size};
261   CeedInt comp_stride = 1;
262 
263   // Stride data
264   bool is_strided;
265   ierr = CeedElemRestrictionIsStrided(r, &is_strided); CeedChkBackend(ierr);
266   if (is_strided) {
267     bool has_backend_strides;
268     ierr = CeedElemRestrictionHasBackendStrides(r, &has_backend_strides);
269     CeedChkBackend(ierr);
270     if (!has_backend_strides) {
271       ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr);
272     }
273   } else {
274     ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr);
275   }
276 
277   impl->h_ind           = NULL;
278   impl->h_ind_allocated = NULL;
279   impl->d_ind           = NULL;
280   impl->d_ind_allocated = NULL;
281   impl->d_t_indices     = NULL;
282   impl->d_t_offsets     = NULL;
283   impl->num_nodes = size;
284   ierr = CeedElemRestrictionSetData(r, impl); CeedChkBackend(ierr);
285   CeedInt layout[3] = {1, elem_size*num_elem, elem_size};
286   ierr = CeedElemRestrictionSetELayout(r, layout); CeedChkBackend(ierr);
287 
288   // Set up device indices/offset arrays
289   if (mtype == CEED_MEM_HOST) {
290     switch (cmode) {
291     case CEED_OWN_POINTER:
292       impl->h_ind_allocated = (CeedInt *)indices;
293       impl->h_ind = (CeedInt *)indices;
294       break;
295     case CEED_USE_POINTER:
296       impl->h_ind = (CeedInt *)indices;
297       break;
298     case CEED_COPY_VALUES:
299       break;
300     }
301     if (indices != NULL) {
302       ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
303       CeedChk_Hip(ceed, ierr);
304       impl->d_ind_allocated = impl->d_ind; // We own the device memory
305       ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
306                        hipMemcpyHostToDevice);
307       CeedChk_Hip(ceed, ierr);
308       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
309     }
310   } else if (mtype == CEED_MEM_DEVICE) {
311     switch (cmode) {
312     case CEED_COPY_VALUES:
313       if (indices != NULL) {
314         ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
315         CeedChk_Hip(ceed, ierr);
316         impl->d_ind_allocated = impl->d_ind; // We own the device memory
317         ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
318                          hipMemcpyDeviceToDevice);
319         CeedChk_Hip(ceed, ierr);
320       }
321       break;
322     case CEED_OWN_POINTER:
323       impl->d_ind = (CeedInt *)indices;
324       impl->d_ind_allocated = impl->d_ind;
325       break;
326     case CEED_USE_POINTER:
327       impl->d_ind = (CeedInt *)indices;
328     }
329     if (indices != NULL) {
330       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
331     }
332   } else {
333     // LCOV_EXCL_START
334     return CeedError(ceed, CEED_ERROR_BACKEND,
335                      "Only MemType = HOST or DEVICE supported");
336     // LCOV_EXCL_STOP
337   }
338 
339   // Compile HIP kernels
340   CeedInt num_nodes = impl->num_nodes;
341   char *restriction_kernel_path, *restriction_kernel_source;
342   ierr = CeedGetJitAbsolutePath(ceed,
343                                 "ceed/jit-source/hip/hip-ref-restriction.h",
344                                 &restriction_kernel_path); CeedChkBackend(ierr);
345   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n");
346   ierr = CeedLoadSourceToBuffer(ceed, restriction_kernel_path,
347                                 &restriction_kernel_source);
348   CeedChkBackend(ierr);
349   CeedDebug256(ceed, 2,
350                "----- Loading Restriction Kernel Source Complete! -----\n");
351   ierr = CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8,
352                         "RESTR_ELEM_SIZE", elem_size,
353                         "RESTR_NUM_ELEM", num_elem,
354                         "RESTR_NUM_COMP", num_comp,
355                         "RESTR_NUM_NODES", num_nodes,
356                         "RESTR_COMP_STRIDE", comp_stride,
357                         "RESTR_STRIDE_NODES", strides[0],
358                         "RESTR_STRIDE_COMP", strides[1],
359                         "RESTR_STRIDE_ELEM", strides[2]); CeedChkBackend(ierr);
360   ierr = CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose",
361                           &impl->StridedNoTranspose); CeedChkBackend(ierr);
362   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose",
363                           &impl->OffsetNoTranspose); CeedChkBackend(ierr);
364   ierr = CeedGetKernelHip(ceed, impl->module, "StridedTranspose",
365                           &impl->StridedTranspose); CeedChkBackend(ierr);
366   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetTranspose",
367                           &impl->OffsetTranspose); CeedChkBackend(ierr);
368   ierr = CeedFree(&restriction_kernel_path); CeedChkBackend(ierr);
369   ierr = CeedFree(&restriction_kernel_source); CeedChkBackend(ierr);
370 
371   // Register backend functions
372   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply",
373                                 CeedElemRestrictionApply_Hip);
374   CeedChkBackend(ierr);
375   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock",
376                                 CeedElemRestrictionApplyBlock_Hip);
377   CeedChkBackend(ierr);
378   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets",
379                                 CeedElemRestrictionGetOffsets_Hip);
380   CeedChkBackend(ierr);
381   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy",
382                                 CeedElemRestrictionDestroy_Hip);
383   CeedChkBackend(ierr);
384   return CEED_ERROR_SUCCESS;
385 }
386 
387 //------------------------------------------------------------------------------
388 // Blocked not supported
389 //------------------------------------------------------------------------------
390 int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype,
391     const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
392   int ierr;
393   Ceed ceed;
394   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
395   return CeedError(ceed, CEED_ERROR_BACKEND,
396                    "Backend does not implement blocked restrictions");
397 }
398 //------------------------------------------------------------------------------
399