xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision a9e65696a8c8214eb82d2dcf9ed1f28a32d2c94e)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <ceed/jit-tools.h>
11 #include <hip/hip_runtime.h>
12 #include <stdbool.h>
13 #include <stddef.h>
14 #include <string.h>
15 
16 #include "../hip/ceed-hip-compile.h"
17 #include "ceed-hip-ref.h"
18 
19 //------------------------------------------------------------------------------
20 // Apply restriction
21 //------------------------------------------------------------------------------
22 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
23   CeedElemRestriction_Hip *impl;
24   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
25   Ceed ceed;
26   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
27   Ceed_Hip *data;
28   CeedCallBackend(CeedGetData(ceed, &data));
29   const CeedInt block_size = 64;
30   const CeedInt num_nodes  = impl->num_nodes;
31   CeedInt       num_elem, elem_size;
32   CeedElemRestrictionGetNumElements(r, &num_elem);
33   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
34   hipFunction_t kernel;
35 
36   // Get vectors
37   const CeedScalar *d_u;
38   CeedScalar       *d_v;
39   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
40   if (t_mode == CEED_TRANSPOSE) {
41     // Sum into for transpose mode, e-vec to l-vec
42     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
43   } else {
44     // Overwrite for notranspose mode, l-vec to e-vec
45     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
46   }
47 
48   // Restrict
49   if (t_mode == CEED_NOTRANSPOSE) {
50     // L-vector -> E-vector
51     if (impl->d_ind) {
52       // -- Offsets provided
53       kernel             = impl->OffsetNoTranspose;
54       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
55       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
56       CeedCallBackend(CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
57     } else {
58       // -- Strided restriction
59       kernel             = impl->StridedNoTranspose;
60       void   *args[]     = {&num_elem, &d_u, &d_v};
61       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
62       CeedCallBackend(CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
63     }
64   } else {
65     // E-vector -> L-vector
66     if (impl->d_ind) {
67       // -- Offsets provided
68       kernel       = impl->OffsetTranspose;
69       void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
70       CeedCallBackend(CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
71     } else {
72       // -- Strided restriction
73       kernel       = impl->StridedTranspose;
74       void *args[] = {&num_elem, &d_u, &d_v};
75       CeedCallBackend(CeedRunKernelHip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
76     }
77   }
78 
79   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
80 
81   // Restore arrays
82   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
83   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
84   return CEED_ERROR_SUCCESS;
85 }
86 
87 //------------------------------------------------------------------------------
88 // Blocked not supported
89 //------------------------------------------------------------------------------
90 int CeedElemRestrictionApplyBlock_Hip(CeedElemRestriction r, CeedInt block, CeedTransposeMode t_mode, CeedVector u, CeedVector v,
91                                       CeedRequest *request) {
92   // LCOV_EXCL_START
93   Ceed ceed;
94   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
95   return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement blocked restrictions");
96   // LCOV_EXCL_STOP
97 }
98 
99 //------------------------------------------------------------------------------
100 // Get offsets
101 //------------------------------------------------------------------------------
102 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mtype, const CeedInt **offsets) {
103   CeedElemRestriction_Hip *impl;
104   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
105 
106   switch (mtype) {
107     case CEED_MEM_HOST:
108       *offsets = impl->h_ind;
109       break;
110     case CEED_MEM_DEVICE:
111       *offsets = impl->d_ind;
112       break;
113   }
114   return CEED_ERROR_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Destroy restriction
119 //------------------------------------------------------------------------------
120 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
121   CeedElemRestriction_Hip *impl;
122   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
123 
124   Ceed ceed;
125   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
126   CeedCallHip(ceed, hipModuleUnload(impl->module));
127   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
128   CeedCallHip(ceed, hipFree(impl->d_ind_allocated));
129   CeedCallHip(ceed, hipFree(impl->d_t_offsets));
130   CeedCallHip(ceed, hipFree(impl->d_t_indices));
131   CeedCallHip(ceed, hipFree(impl->d_l_vec_indices));
132   CeedCallBackend(CeedFree(&impl));
133 
134   return CEED_ERROR_SUCCESS;
135 }
136 
137 //------------------------------------------------------------------------------
138 // Create transpose offsets and indices
139 //------------------------------------------------------------------------------
140 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) {
141   Ceed ceed;
142   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
143   CeedElemRestriction_Hip *impl;
144   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
145   CeedSize l_size;
146   CeedInt  num_elem, elem_size, num_comp;
147   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
148   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
149   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
150   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
151 
152   // Count num_nodes
153   bool *is_node;
154   CeedCallBackend(CeedCalloc(l_size, &is_node));
155   const CeedInt size_indices = num_elem * elem_size;
156   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
157   CeedInt num_nodes = 0;
158   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
159   impl->num_nodes = num_nodes;
160 
161   // L-vector offsets array
162   CeedInt *ind_to_offset, *l_vec_indices;
163   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
164   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
165   CeedInt j = 0;
166   for (CeedInt i = 0; i < l_size; i++) {
167     if (is_node[i]) {
168       l_vec_indices[j] = i;
169       ind_to_offset[i] = j++;
170     }
171   }
172   CeedCallBackend(CeedFree(&is_node));
173 
174   // Compute transpose offsets and indices
175   const CeedInt size_offsets = num_nodes + 1;
176   CeedInt      *t_offsets;
177   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
178   CeedInt *t_indices;
179   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
180   // Count node multiplicity
181   for (CeedInt e = 0; e < num_elem; ++e) {
182     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
183   }
184   // Convert to running sum
185   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
186   // List all E-vec indices associated with L-vec node
187   for (CeedInt e = 0; e < num_elem; ++e) {
188     for (CeedInt i = 0; i < elem_size; ++i) {
189       const CeedInt lid                          = elem_size * e + i;
190       const CeedInt gid                          = indices[lid];
191       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
192     }
193   }
194   // Reset running sum
195   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
196   t_offsets[0] = 0;
197 
198   // Copy data to device
199   // -- L-vector indices
200   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
201   CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
202   // -- Transpose offsets
203   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
204   CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
205   // -- Transpose indices
206   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
207   CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
208 
209   // Cleanup
210   CeedCallBackend(CeedFree(&ind_to_offset));
211   CeedCallBackend(CeedFree(&l_vec_indices));
212   CeedCallBackend(CeedFree(&t_offsets));
213   CeedCallBackend(CeedFree(&t_indices));
214 
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 //------------------------------------------------------------------------------
219 // Create restriction
220 //------------------------------------------------------------------------------
221 int CeedElemRestrictionCreate_Hip(CeedMemType mtype, CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
222   Ceed ceed;
223   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
224   CeedElemRestriction_Hip *impl;
225   CeedCallBackend(CeedCalloc(1, &impl));
226   CeedInt num_elem, num_comp, elem_size;
227   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
228   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
229   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
230   CeedInt size        = num_elem * elem_size;
231   CeedInt strides[3]  = {1, size, elem_size};
232   CeedInt comp_stride = 1;
233 
234   // Stride data
235   bool is_strided;
236   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
237   if (is_strided) {
238     bool has_backend_strides;
239     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
240     if (!has_backend_strides) {
241       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
242     }
243   } else {
244     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
245   }
246 
247   impl->h_ind           = NULL;
248   impl->h_ind_allocated = NULL;
249   impl->d_ind           = NULL;
250   impl->d_ind_allocated = NULL;
251   impl->d_t_indices     = NULL;
252   impl->d_t_offsets     = NULL;
253   impl->num_nodes       = size;
254   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
255   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
256   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
257 
258   // Set up device indices/offset arrays
259   if (mtype == CEED_MEM_HOST) {
260     switch (cmode) {
261       case CEED_OWN_POINTER:
262         impl->h_ind_allocated = (CeedInt *)indices;
263         impl->h_ind           = (CeedInt *)indices;
264         break;
265       case CEED_USE_POINTER:
266         impl->h_ind = (CeedInt *)indices;
267         break;
268       case CEED_COPY_VALUES:
269         if (indices != NULL) {
270           CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
271           memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
272           impl->h_ind = impl->h_ind_allocated;
273         }
274         break;
275     }
276     if (indices != NULL) {
277       CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
278       impl->d_ind_allocated = impl->d_ind;  // We own the device memory
279       CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice));
280       CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
281     }
282   } else if (mtype == CEED_MEM_DEVICE) {
283     switch (cmode) {
284       case CEED_COPY_VALUES:
285         if (indices != NULL) {
286           CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
287           impl->d_ind_allocated = impl->d_ind;  // We own the device memory
288           CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice));
289         }
290         break;
291       case CEED_OWN_POINTER:
292         impl->d_ind           = (CeedInt *)indices;
293         impl->d_ind_allocated = impl->d_ind;
294         break;
295       case CEED_USE_POINTER:
296         impl->d_ind = (CeedInt *)indices;
297     }
298     if (indices != NULL) {
299       CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
300       CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost));
301       impl->h_ind = impl->h_ind_allocated;
302       CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
303     }
304   } else {
305     // LCOV_EXCL_START
306     return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
307     // LCOV_EXCL_STOP
308   }
309 
310   // Compile HIP kernels
311   CeedInt num_nodes = impl->num_nodes;
312   char   *restriction_kernel_path, *restriction_kernel_source;
313   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path));
314   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source -----\n");
315   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
316   CeedDebug256(ceed, 2, "----- Loading Restriction Kernel Source Complete! -----\n");
317   CeedCallBackend(CeedCompileHip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
318                                  "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
319                                  strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
320   CeedCallBackend(CeedGetKernelHip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
321   CeedCallBackend(CeedGetKernelHip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
322   CeedCallBackend(CeedGetKernelHip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
323   CeedCallBackend(CeedGetKernelHip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
324   CeedCallBackend(CeedFree(&restriction_kernel_path));
325   CeedCallBackend(CeedFree(&restriction_kernel_source));
326 
327   // Register backend functions
328   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip));
329   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyBlock", CeedElemRestrictionApplyBlock_Hip));
330   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
331   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip));
332   return CEED_ERROR_SUCCESS;
333 }
334 
335 //------------------------------------------------------------------------------
336 // Blocked not supported
337 //------------------------------------------------------------------------------
338 int CeedElemRestrictionCreateBlocked_Hip(const CeedMemType mtype, const CeedCopyMode cmode, const CeedInt *indices, CeedElemRestriction r) {
339   Ceed ceed;
340   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
341   return CeedError(ceed, CEED_ERROR_BACKEND, "Backend does not implement blocked restrictions");
342 }
343 //------------------------------------------------------------------------------
344