xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision 53f7acb178914a16137d9c91c84843f149b8f9af)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <string.h>
14 #include <hip/hip_runtime.h>
15 
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-ref.h"
19 
20 //------------------------------------------------------------------------------
21 // Apply restriction
22 //------------------------------------------------------------------------------
23 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
24   CeedElemRestriction_Hip *impl;
25   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
26   Ceed ceed;
27   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
28   Ceed_Hip *data;
29   CeedCallBackend(CeedGetData(ceed, &data));
30   const CeedInt num_nodes = impl->num_nodes;
31   CeedInt       num_elem, elem_size;
32   CeedElemRestrictionGetNumElements(r, &num_elem);
33   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
34   hipFunction_t kernel;
35 
36   // Get vectors
37   const CeedScalar *d_u;
38   CeedScalar       *d_v;
39   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
40   if (t_mode == CEED_TRANSPOSE) {
41     // Sum into for transpose mode, e-vec to l-vec
42     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
43   } else {
44     // Overwrite for notranspose mode, l-vec to e-vec
45     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
46   }
47 
48   // Restrict
49   if (t_mode == CEED_NOTRANSPOSE) {
50     // L-vector -> E-vector
51     if (impl->d_ind) {
52       // -- Offsets provided
53       kernel             = impl->OffsetNoTranspose;
54       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
55       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
56 
57       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
58     } else {
59       // -- Strided restriction
60       kernel             = impl->StridedNoTranspose;
61       void   *args[]     = {&num_elem, &d_u, &d_v};
62       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
63 
64       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
65     }
66   } else {
67     // E-vector -> L-vector
68     if (impl->d_ind) {
69       // -- Offsets provided
70       CeedInt block_size = 64;
71       if (impl->OffsetTranspose) {
72         kernel       = impl->OffsetTranspose;
73         void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
74 
75         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
76       } else {
77         kernel       = impl->OffsetTransposeDet;
78         void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
79 
80         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
81       }
82     } else {
83       // -- Strided restriction
84       kernel             = impl->StridedTranspose;
85       void   *args[]     = {&num_elem, &d_u, &d_v};
86       CeedInt block_size = 64;
87 
88       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
89     }
90   }
91 
92   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
93 
94   // Restore arrays
95   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
96   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
97   return CEED_ERROR_SUCCESS;
98 }
99 
100 //------------------------------------------------------------------------------
101 // Get offsets
102 //------------------------------------------------------------------------------
103 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
104   CeedElemRestriction_Hip *impl;
105   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
106 
107   switch (mem_type) {
108     case CEED_MEM_HOST:
109       *offsets = impl->h_ind;
110       break;
111     case CEED_MEM_DEVICE:
112       *offsets = impl->d_ind;
113       break;
114   }
115   return CEED_ERROR_SUCCESS;
116 }
117 
118 //------------------------------------------------------------------------------
119 // Destroy restriction
120 //------------------------------------------------------------------------------
121 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
122   CeedElemRestriction_Hip *impl;
123   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
124 
125   Ceed ceed;
126   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
127   CeedCallHip(ceed, hipModuleUnload(impl->module));
128   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
129   CeedCallHip(ceed, hipFree(impl->d_ind_allocated));
130   CeedCallHip(ceed, hipFree(impl->d_t_offsets));
131   CeedCallHip(ceed, hipFree(impl->d_t_indices));
132   CeedCallHip(ceed, hipFree(impl->d_l_vec_indices));
133   CeedCallBackend(CeedFree(&impl));
134 
135   return CEED_ERROR_SUCCESS;
136 }
137 
138 //------------------------------------------------------------------------------
139 // Create transpose offsets and indices
140 //------------------------------------------------------------------------------
141 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) {
142   Ceed ceed;
143   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
144   CeedElemRestriction_Hip *impl;
145   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
146   CeedSize l_size;
147   CeedInt  num_elem, elem_size, num_comp;
148   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
149   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
150   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
151   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
152 
153   // Count num_nodes
154   bool *is_node;
155   CeedCallBackend(CeedCalloc(l_size, &is_node));
156   const CeedInt size_indices = num_elem * elem_size;
157   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
158   CeedInt num_nodes = 0;
159   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
160   impl->num_nodes = num_nodes;
161 
162   // L-vector offsets array
163   CeedInt *ind_to_offset, *l_vec_indices;
164   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
165   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
166   CeedInt j = 0;
167   for (CeedInt i = 0; i < l_size; i++) {
168     if (is_node[i]) {
169       l_vec_indices[j] = i;
170       ind_to_offset[i] = j++;
171     }
172   }
173   CeedCallBackend(CeedFree(&is_node));
174 
175   // Compute transpose offsets and indices
176   const CeedInt size_offsets = num_nodes + 1;
177   CeedInt      *t_offsets;
178   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
179   CeedInt *t_indices;
180   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
181   // Count node multiplicity
182   for (CeedInt e = 0; e < num_elem; ++e) {
183     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
184   }
185   // Convert to running sum
186   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
187   // List all E-vec indices associated with L-vec node
188   for (CeedInt e = 0; e < num_elem; ++e) {
189     for (CeedInt i = 0; i < elem_size; ++i) {
190       const CeedInt lid                          = elem_size * e + i;
191       const CeedInt gid                          = indices[lid];
192       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
193     }
194   }
195   // Reset running sum
196   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
197   t_offsets[0] = 0;
198 
199   // Copy data to device
200   // -- L-vector indices
201   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
202   CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
203   // -- Transpose offsets
204   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
205   CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
206   // -- Transpose indices
207   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
208   CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
209 
210   // Cleanup
211   CeedCallBackend(CeedFree(&ind_to_offset));
212   CeedCallBackend(CeedFree(&l_vec_indices));
213   CeedCallBackend(CeedFree(&t_offsets));
214   CeedCallBackend(CeedFree(&t_indices));
215 
216   return CEED_ERROR_SUCCESS;
217 }
218 
219 //------------------------------------------------------------------------------
220 // Create restriction
221 //------------------------------------------------------------------------------
222 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
223                                   const CeedInt8 *curl_orients, CeedElemRestriction r) {
224   Ceed ceed;
225   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
226   CeedElemRestriction_Hip *impl;
227   CeedCallBackend(CeedCalloc(1, &impl));
228   Ceed parent;
229   CeedCallBackend(CeedGetParent(ceed, &parent));
230   bool is_deterministic;
231   CeedCallBackend(CeedIsDeterministic(parent, &is_deterministic));
232   CeedInt num_elem, num_comp, elem_size;
233   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
234   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
235   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
236   CeedInt size        = num_elem * elem_size;
237   CeedInt strides[3]  = {1, size, elem_size};
238   CeedInt comp_stride = 1;
239 
240   CeedRestrictionType rstr_type;
241   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
242   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
243             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
244 
245   // Stride data
246   bool is_strided;
247   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
248   if (is_strided) {
249     bool has_backend_strides;
250     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
251     if (!has_backend_strides) {
252       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
253     }
254   } else {
255     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
256   }
257 
258   impl->h_ind           = NULL;
259   impl->h_ind_allocated = NULL;
260   impl->d_ind           = NULL;
261   impl->d_ind_allocated = NULL;
262   impl->d_t_indices     = NULL;
263   impl->d_t_offsets     = NULL;
264   impl->num_nodes       = size;
265   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
266   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
267   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
268 
269   // Set up device indices/offset arrays
270   switch (mem_type) {
271     case CEED_MEM_HOST: {
272       switch (copy_mode) {
273         case CEED_OWN_POINTER:
274           impl->h_ind_allocated = (CeedInt *)indices;
275           impl->h_ind           = (CeedInt *)indices;
276           break;
277         case CEED_USE_POINTER:
278           impl->h_ind = (CeedInt *)indices;
279           break;
280         case CEED_COPY_VALUES:
281           if (indices != NULL) {
282             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
283             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
284             impl->h_ind = impl->h_ind_allocated;
285           }
286           break;
287       }
288       if (indices != NULL) {
289         CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
290         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
291         CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice));
292         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
293       }
294       break;
295     }
296     case CEED_MEM_DEVICE: {
297       switch (copy_mode) {
298         case CEED_COPY_VALUES:
299           if (indices != NULL) {
300             CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
301             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
302             CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice));
303           }
304           break;
305         case CEED_OWN_POINTER:
306           impl->d_ind           = (CeedInt *)indices;
307           impl->d_ind_allocated = impl->d_ind;
308           break;
309         case CEED_USE_POINTER:
310           impl->d_ind = (CeedInt *)indices;
311       }
312       if (indices != NULL) {
313         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
314         CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost));
315         impl->h_ind = impl->h_ind_allocated;
316         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
317       }
318       break;
319     }
320     // LCOV_EXCL_START
321     default:
322       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
323       // LCOV_EXCL_STOP
324   }
325 
326   // Compile HIP kernels
327   CeedInt num_nodes = impl->num_nodes;
328   char   *restriction_kernel_path, *restriction_kernel_source;
329   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path));
330   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
331   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
332   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
333   CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
334                                   "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
335                                   strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
336   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
337   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
338   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
339   if (!is_deterministic) {
340     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
341   } else {
342     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet));
343   }
344   CeedCallBackend(CeedFree(&restriction_kernel_path));
345   CeedCallBackend(CeedFree(&restriction_kernel_source));
346 
347   // Register backend functions
348   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip));
349   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip));
350   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip));
351   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
352   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip));
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
357