xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision e26ec02599109ff3515b5118893f538030923aec)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <hip/hip_runtime.h>
12 #include <stdbool.h>
13 #include <stddef.h>
14 #include <string.h>
15 #include "ceed-hip-ref.h"
16 #include "../hip/ceed-hip-compile.h"
17 
18 //------------------------------------------------------------------------------
19 // Apply restriction
20 //------------------------------------------------------------------------------
21 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r,
22                                         CeedTransposeMode t_mode, CeedVector u,
23                                         CeedVector v, CeedRequest *request) {
24   int ierr;
25   CeedElemRestriction_Hip *impl;
26   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
27   Ceed ceed;
28   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
29   Ceed_Hip *data;
30   ierr = CeedGetData(ceed, &data); CeedChkBackend(ierr);
31   const CeedInt block_size = 64;
32   const CeedInt num_nodes = impl->num_nodes;
33   CeedInt num_elem, elem_size;
34   CeedElemRestrictionGetNumElements(r, &num_elem);
35   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
36   hipFunction_t kernel;
37 
38   // Get vectors
39   const CeedScalar *d_u;
40   CeedScalar *d_v;
41   ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
42   if (t_mode == CEED_TRANSPOSE) {
43     // Sum into for transpose mode, e-vec to l-vec
44     ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
45   } else {
46     // Overwrite for notranspose mode, l-vec to e-vec
47     ierr = CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
48   }
49 
50   // Restrict
51   if (t_mode == CEED_NOTRANSPOSE) {
52     // L-vector -> E-vector
53     if (impl->d_ind) {
54       // -- Offsets provided
55       kernel = impl->OffsetNoTranspose;
56       void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
57       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
58       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
59                               block_size, args); CeedChkBackend(ierr);
60     } else {
61       // -- Strided restriction
62       kernel = impl->StridedNoTranspose;
63       void *args[] = {&num_elem, &d_u, &d_v};
64       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
65       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
66                               block_size, args); CeedChkBackend(ierr);
67     }
68   } else {
69     // E-vector -> L-vector
70     if (impl->d_ind) {
71       // -- Offsets provided
72       kernel = impl->OffsetTranspose;
73       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices,
74                       &impl->d_t_offsets, &d_u, &d_v
75                      };
76       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
77                               block_size, args); CeedChkBackend(ierr);
78     } else {
79       // -- Strided restriction
80       kernel = impl->StridedTranspose;
81       void *args[] = {&num_elem, &d_u, &d_v};
82       ierr = CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size),
83                               block_size, args); CeedChkBackend(ierr);
84     }
85   }
86 
87   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED)
88     *request = NULL;
89 
90   // Restore arrays
91   ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
92   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
93   return CEED_ERROR_SUCCESS;
94 }
95 
96 //------------------------------------------------------------------------------
97 // Blocked not supported
98 //------------------------------------------------------------------------------
99 int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block,
100                                       CeedTransposeMode t_mode, CeedVector u,
101                                       CeedVector v, CeedRequest *request) {
102   // LCOV_EXCL_START
103   int ierr;
104   Ceed ceed;
105   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
106   return CeedError(ceed, CEED_ERROR_BACKEND,
107                    "Backend does not implement blocked restrictions");
108   // LCOV_EXCL_STOP
109 }
110 
111 //------------------------------------------------------------------------------
112 // Get offsets
113 //------------------------------------------------------------------------------
114 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr,
115     CeedMemType mtype, const CeedInt **offsets) {
116   int ierr;
117   CeedElemRestriction_Hip *impl;
118   ierr = CeedElemRestrictionGetData(rstr, &impl); CeedChkBackend(ierr);
119 
120   switch (mtype) {
121   case CEED_MEM_HOST:
122     *offsets = impl->h_ind;
123     break;
124   case CEED_MEM_DEVICE:
125     *offsets = impl->d_ind;
126     break;
127   }
128   return CEED_ERROR_SUCCESS;
129 }
130 
131 //------------------------------------------------------------------------------
132 // Destroy restriction
133 //------------------------------------------------------------------------------
134 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
135   int ierr;
136   CeedElemRestriction_Hip *impl;
137   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
138 
139   Ceed ceed;
140   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
141   ierr = hipModuleUnload(impl->module); CeedChk_Hip(ceed, ierr);
142   ierr = CeedFree(&impl->h_ind_allocated); CeedChkBackend(ierr);
143   ierr = hipFree(impl->d_ind_allocated); CeedChk_Hip(ceed, ierr);
144   ierr = hipFree(impl->d_t_offsets); CeedChk_Hip(ceed, ierr);
145   ierr = hipFree(impl->d_t_indices); CeedChk_Hip(ceed, ierr);
146   ierr = hipFree(impl->d_l_vec_indices); CeedChk_Hip(ceed, ierr);
147   ierr = CeedFree(&impl); CeedChkBackend(ierr);
148 
149   return CEED_ERROR_SUCCESS;
150 }
151 
152 //------------------------------------------------------------------------------
153 // Create transpose offsets and indices
154 //------------------------------------------------------------------------------
155 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r,
156     const CeedInt *indices) {
157   int ierr;
158   Ceed ceed;
159   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
160   CeedElemRestriction_Hip *impl;
161   ierr = CeedElemRestrictionGetData(r, &impl); CeedChkBackend(ierr);
162   CeedSize l_size;
163   CeedInt num_elem, elem_size, num_comp;
164   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
165   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
166   ierr = CeedElemRestrictionGetLVectorSize(r, &l_size); CeedChkBackend(ierr);
167   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
168 
169   // Count num_nodes
170   bool *is_node;
171   ierr = CeedCalloc(l_size, &is_node); CeedChkBackend(ierr);
172   const CeedInt size_indices = num_elem * elem_size;
173   for (CeedInt i = 0; i < size_indices; i++)
174     is_node[indices[i]] = 1;
175   CeedInt num_nodes = 0;
176   for (CeedInt i = 0; i < l_size; i++)
177     num_nodes += is_node[i];
178   impl->num_nodes = num_nodes;
179 
180   // L-vector offsets array
181   CeedInt *ind_to_offset, *l_vec_indices;
182   ierr = CeedCalloc(l_size, &ind_to_offset); CeedChkBackend(ierr);
183   ierr = CeedCalloc(num_nodes, &l_vec_indices); CeedChkBackend(ierr);
184   CeedInt j = 0;
185   for (CeedInt i = 0; i < l_size; i++)
186     if (is_node[i]) {
187       l_vec_indices[j] = i;
188       ind_to_offset[i] = j++;
189     }
190   ierr = CeedFree(&is_node); CeedChkBackend(ierr);
191 
192   // Compute transpose offsets and indices
193   const CeedInt size_offsets = num_nodes + 1;
194   CeedInt *t_offsets;
195   ierr = CeedCalloc(size_offsets, &t_offsets); CeedChkBackend(ierr);
196   CeedInt *t_indices;
197   ierr = CeedMalloc(size_indices, &t_indices); CeedChkBackend(ierr);
198   // Count node multiplicity
199   for (CeedInt e = 0; e < num_elem; ++e)
200     for (CeedInt i = 0; i < elem_size; ++i)
201       ++t_offsets[ind_to_offset[indices[elem_size*e + i]] + 1];
202   // Convert to running sum
203   for (CeedInt i = 1; i < size_offsets; ++i)
204     t_offsets[i] += t_offsets[i-1];
205   // List all E-vec indices associated with L-vec node
206   for (CeedInt e = 0; e < num_elem; ++e) {
207     for (CeedInt i = 0; i < elem_size; ++i) {
208       const CeedInt lid = elem_size*e + i;
209       const CeedInt gid = indices[lid];
210       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
211     }
212   }
213   // Reset running sum
214   for (int i = size_offsets - 1; i > 0; --i)
215     t_offsets[i] = t_offsets[i - 1];
216   t_offsets[0] = 0;
217 
218   // Copy data to device
219   // -- L-vector indices
220   ierr = hipMalloc((void **)&impl->d_l_vec_indices, num_nodes*sizeof(CeedInt));
221   CeedChk_Hip(ceed, ierr);
222   ierr = hipMemcpy(impl->d_l_vec_indices, l_vec_indices,
223                    num_nodes*sizeof(CeedInt), hipMemcpyHostToDevice);
224   CeedChk_Hip(ceed, ierr);
225   // -- Transpose offsets
226   ierr = hipMalloc((void **)&impl->d_t_offsets, size_offsets*sizeof(CeedInt));
227   CeedChk_Hip(ceed, ierr);
228   ierr = hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets*sizeof(CeedInt),
229                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
230   // -- Transpose indices
231   ierr = hipMalloc((void **)&impl->d_t_indices, size_indices*sizeof(CeedInt));
232   CeedChk_Hip(ceed, ierr);
233   ierr = hipMemcpy(impl->d_t_indices, t_indices, size_indices*sizeof(CeedInt),
234                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
235 
236   // Cleanup
237   ierr = CeedFree(&ind_to_offset); CeedChkBackend(ierr);
238   ierr = CeedFree(&l_vec_indices); CeedChkBackend(ierr);
239   ierr = CeedFree(&t_offsets); CeedChkBackend(ierr);
240   ierr = CeedFree(&t_indices); CeedChkBackend(ierr);
241 
242   return CEED_ERROR_SUCCESS;
243 }
244 
245 //------------------------------------------------------------------------------
246 // Create restriction
247 //------------------------------------------------------------------------------
248 int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode,
249                                   const CeedInt *indices,
250                                   CeedElemRestriction r) {
251   int ierr;
252   Ceed ceed;
253   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
254   CeedElemRestriction_Hip *impl;
255   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
256   CeedInt num_elem, num_comp, elem_size;
257   ierr = CeedElemRestrictionGetNumElements(r, &num_elem); CeedChkBackend(ierr);
258   ierr = CeedElemRestrictionGetNumComponents(r, &num_comp); CeedChkBackend(ierr);
259   ierr = CeedElemRestrictionGetElementSize(r, &elem_size); CeedChkBackend(ierr);
260   CeedInt size = num_elem * elem_size;
261   CeedInt strides[3] = {1, size, elem_size};
262   CeedInt comp_stride = 1;
263 
264   // Stride data
265   bool is_strided;
266   ierr = CeedElemRestrictionIsStrided(r, &is_strided); CeedChkBackend(ierr);
267   if (is_strided) {
268     bool has_backend_strides;
269     ierr = CeedElemRestrictionHasBackendStrides(r, &has_backend_strides);
270     CeedChkBackend(ierr);
271     if (!has_backend_strides) {
272       ierr = CeedElemRestrictionGetStrides(r, &strides); CeedChkBackend(ierr);
273     }
274   } else {
275     ierr = CeedElemRestrictionGetCompStride(r, &comp_stride); CeedChkBackend(ierr);
276   }
277 
278   impl->h_ind           = NULL;
279   impl->h_ind_allocated = NULL;
280   impl->d_ind           = NULL;
281   impl->d_ind_allocated = NULL;
282   impl->d_t_indices     = NULL;
283   impl->d_t_offsets     = NULL;
284   impl->num_nodes = size;
285   ierr = CeedElemRestrictionSetData(r, impl); CeedChkBackend(ierr);
286   CeedInt layout[3] = {1, elem_size*num_elem, elem_size};
287   ierr = CeedElemRestrictionSetELayout(r, layout); CeedChkBackend(ierr);
288 
289   // Set up device indices/offset arrays
290   if (mtype == CEED_MEM_HOST) {
291     switch (cmode) {
292     case CEED_OWN_POINTER:
293       impl->h_ind_allocated = (CeedInt *)indices;
294       impl->h_ind = (CeedInt *)indices;
295       break;
296     case CEED_USE_POINTER:
297       impl->h_ind = (CeedInt *)indices;
298       break;
299     case CEED_COPY_VALUES:
300       if (indices != NULL) {
301         ierr = CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated);
302         CeedChkBackend(ierr);
303         memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
304         impl->h_ind = impl->h_ind_allocated;
305       }
306       break;
307     }
308     if (indices != NULL) {
309       ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
310       CeedChk_Hip(ceed, ierr);
311       impl->d_ind_allocated = impl->d_ind; // We own the device memory
312       ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
313                        hipMemcpyHostToDevice);
314       CeedChk_Hip(ceed, ierr);
315       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
316     }
317   } else if (mtype == CEED_MEM_DEVICE) {
318     switch (cmode) {
319     case CEED_COPY_VALUES:
320       if (indices != NULL) {
321         ierr = hipMalloc( (void **)&impl->d_ind, size * sizeof(CeedInt));
322         CeedChk_Hip(ceed, ierr);
323         impl->d_ind_allocated = impl->d_ind; // We own the device memory
324         ierr = hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt),
325                          hipMemcpyDeviceToDevice);
326         CeedChk_Hip(ceed, ierr);
327       }
328       break;
329     case CEED_OWN_POINTER:
330       impl->d_ind = (CeedInt *)indices;
331       impl->d_ind_allocated = impl->d_ind;
332       break;
333     case CEED_USE_POINTER:
334       impl->d_ind = (CeedInt *)indices;
335     }
336     if (indices != NULL) {
337       ierr = CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated);
338       CeedChkBackend(ierr);
339       ierr = hipMemcpy(impl->h_ind_allocated, impl->d_ind,
340                        elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost);
341       CeedChk_Hip(ceed, ierr);
342       impl->h_ind = impl->h_ind_allocated;
343       ierr = CeedElemRestrictionOffset_Hip(r, indices); CeedChkBackend(ierr);
344     }
345   } else {
346     // LCOV_EXCL_START
347     return CeedError(ceed, CEED_ERROR_BACKEND,
348                      "Only MemType = HOST or DEVICE supported");
349     // LCOV_EXCL_STOP
350   }
351 
352   // Compile HIP kernels
353   CeedInt num_nodes = impl->num_nodes;
354   char *restriction_kernel_path, *restriction_kernel_source;
355   ierr = CeedGetJitAbsolutePath(ceed,
356                                 "ceed/jit-source/hip/hip-ref-restriction.h",
357                                 &restriction_kernel_path); CeedChkBackend(ierr);
358   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n");
359   ierr = CeedLoadSourceToBuffer(ceed, restriction_kernel_path,
360                                 &restriction_kernel_source);
361   CeedChkBackend(ierr);
362   CeedDebug256(ceed, 2,
363                "----- Loading Restriction Kernel Source Complete! -----\n");
364   ierr = CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8,
365                         "RESTR_ELEM_SIZE", elem_size,
366                         "RESTR_NUM_ELEM", num_elem,
367                         "RESTR_NUM_COMP", num_comp,
368                         "RESTR_NUM_NODES", num_nodes,
369                         "RESTR_COMP_STRIDE", comp_stride,
370                         "RESTR_STRIDE_NODES", strides[0],
371                         "RESTR_STRIDE_COMP", strides[1],
372                         "RESTR_STRIDE_ELEM", strides[2]); CeedChkBackend(ierr);
373   ierr = CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose",
374                           &impl->StridedNoTranspose); CeedChkBackend(ierr);
375   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose",
376                           &impl->OffsetNoTranspose); CeedChkBackend(ierr);
377   ierr = CeedGetKernelHip(ceed, impl->module, "StridedTranspose",
378                           &impl->StridedTranspose); CeedChkBackend(ierr);
379   ierr = CeedGetKernelHip(ceed, impl->module, "OffsetTranspose",
380                           &impl->OffsetTranspose); CeedChkBackend(ierr);
381   ierr = CeedFree(&restriction_kernel_path); CeedChkBackend(ierr);
382   ierr = CeedFree(&restriction_kernel_source); CeedChkBackend(ierr);
383 
384   // Register backend functions
385   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply",
386                                 CeedElemRestrictionApply_Hip);
387   CeedChkBackend(ierr);
388   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock",
389                                 CeedElemRestrictionApplyBlock_Hip);
390   CeedChkBackend(ierr);
391   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets",
392                                 CeedElemRestrictionGetOffsets_Hip);
393   CeedChkBackend(ierr);
394   ierr = CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy",
395                                 CeedElemRestrictionDestroy_Hip);
396   CeedChkBackend(ierr);
397   return CEED_ERROR_SUCCESS;
398 }
399 
400 //------------------------------------------------------------------------------
401 // Blocked not supported
402 //------------------------------------------------------------------------------
403 int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype,
404     const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
405   int ierr;
406   Ceed ceed;
407   ierr = CeedElemRestrictionGetCeed(r, &ceed); CeedChkBackend(ierr);
408   return CeedError(ceed, CEED_ERROR_BACKEND,
409                    "Backend does not implement blocked restrictions");
410 }
411 //------------------------------------------------------------------------------
412