xref: /libCEED/backends/hip-ref/ceed-hip-ref-restriction.c (revision aa67b84255fd38cedae0f40d1566f643808af2e9)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <ceed/jit-tools.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <string.h>
14 #include <hip/hip_runtime.h>
15 
16 #include "../hip/ceed-hip-common.h"
17 #include "../hip/ceed-hip-compile.h"
18 #include "ceed-hip-ref.h"
19 
20 //------------------------------------------------------------------------------
21 // Apply restriction
22 //------------------------------------------------------------------------------
23 static int CeedElemRestrictionApply_Hip(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
24   Ceed                     ceed;
25   Ceed_Hip                *data;
26   CeedInt                  num_elem, elem_size;
27   const CeedScalar        *d_u;
28   CeedScalar              *d_v;
29   CeedElemRestriction_Hip *impl;
30   hipFunction_t            kernel;
31 
32   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
33   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
34   CeedCallBackend(CeedGetData(ceed, &data));
35   CeedElemRestrictionGetNumElements(r, &num_elem);
36   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
37   const CeedInt num_nodes = impl->num_nodes;
38 
39   // Get vectors
40   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
41   if (t_mode == CEED_TRANSPOSE) {
42     // Sum into for transpose mode, e-vec to l-vec
43     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
44   } else {
45     // Overwrite for notranspose mode, l-vec to e-vec
46     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
47   }
48 
49   // Restrict
50   if (t_mode == CEED_NOTRANSPOSE) {
51     // L-vector -> E-vector
52     if (impl->d_ind) {
53       // -- Offsets provided
54       kernel             = impl->OffsetNoTranspose;
55       void   *args[]     = {&num_elem, &impl->d_ind, &d_u, &d_v};
56       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
57 
58       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
59     } else {
60       // -- Strided restriction
61       kernel             = impl->StridedNoTranspose;
62       void   *args[]     = {&num_elem, &d_u, &d_v};
63       CeedInt block_size = elem_size < 256 ? (elem_size > 64 ? elem_size : 64) : 256;
64 
65       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
66     }
67   } else {
68     // E-vector -> L-vector
69     if (impl->d_ind) {
70       // -- Offsets provided
71       CeedInt block_size = 64;
72 
73       if (impl->OffsetTranspose) {
74         kernel       = impl->OffsetTranspose;
75         void *args[] = {&num_elem, &impl->d_ind, &d_u, &d_v};
76 
77         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
78       } else {
79         kernel       = impl->OffsetTransposeDet;
80         void *args[] = {&impl->d_l_vec_indices, &impl->d_t_indices, &impl->d_t_offsets, &d_u, &d_v};
81 
82         CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
83       }
84     } else {
85       // -- Strided restriction
86       kernel             = impl->StridedTranspose;
87       void   *args[]     = {&num_elem, &d_u, &d_v};
88       CeedInt block_size = 64;
89 
90       CeedCallBackend(CeedRunKernel_Hip(ceed, kernel, CeedDivUpInt(num_nodes, block_size), block_size, args));
91     }
92   }
93 
94   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
95 
96   // Restore arrays
97   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
98   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
99   return CEED_ERROR_SUCCESS;
100 }
101 
102 //------------------------------------------------------------------------------
103 // Get offsets
104 //------------------------------------------------------------------------------
105 static int CeedElemRestrictionGetOffsets_Hip(CeedElemRestriction rstr, CeedMemType mem_type, const CeedInt **offsets) {
106   CeedElemRestriction_Hip *impl;
107 
108   CeedCallBackend(CeedElemRestrictionGetData(rstr, &impl));
109   switch (mem_type) {
110     case CEED_MEM_HOST:
111       *offsets = impl->h_ind;
112       break;
113     case CEED_MEM_DEVICE:
114       *offsets = impl->d_ind;
115       break;
116   }
117   return CEED_ERROR_SUCCESS;
118 }
119 
120 //------------------------------------------------------------------------------
121 // Destroy restriction
122 //------------------------------------------------------------------------------
123 static int CeedElemRestrictionDestroy_Hip(CeedElemRestriction r) {
124   Ceed                     ceed;
125   CeedElemRestriction_Hip *impl;
126 
127   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
128   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
129   CeedCallHip(ceed, hipModuleUnload(impl->module));
130   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
131   CeedCallHip(ceed, hipFree(impl->d_ind_allocated));
132   CeedCallHip(ceed, hipFree(impl->d_t_offsets));
133   CeedCallHip(ceed, hipFree(impl->d_t_indices));
134   CeedCallHip(ceed, hipFree(impl->d_l_vec_indices));
135   CeedCallBackend(CeedFree(&impl));
136   return CEED_ERROR_SUCCESS;
137 }
138 
139 //------------------------------------------------------------------------------
140 // Create transpose offsets and indices
141 //------------------------------------------------------------------------------
142 static int CeedElemRestrictionOffset_Hip(const CeedElemRestriction r, const CeedInt *indices) {
143   Ceed                     ceed;
144   bool                    *is_node;
145   CeedSize                 l_size;
146   CeedInt                  num_elem, elem_size, num_comp, num_nodes = 0, *ind_to_offset, *l_vec_indices, *t_offsets, *t_indices;
147   CeedElemRestriction_Hip *impl;
148 
149   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
150   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
151   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
152   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
153   CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
154   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
155   const CeedInt size_indices = num_elem * elem_size;
156 
157   // Count num_nodes
158   CeedCallBackend(CeedCalloc(l_size, &is_node));
159   for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
160   for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
161   impl->num_nodes = num_nodes;
162 
163   // L-vector offsets array
164   CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
165   CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
166   for (CeedInt i = 0, j = 0; i < l_size; i++) {
167     if (is_node[i]) {
168       l_vec_indices[j] = i;
169       ind_to_offset[i] = j++;
170     }
171   }
172   CeedCallBackend(CeedFree(&is_node));
173 
174   // Compute transpose offsets and indices
175   const CeedInt size_offsets = num_nodes + 1;
176 
177   CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
178   CeedCallBackend(CeedMalloc(size_indices, &t_indices));
179   // Count node multiplicity
180   for (CeedInt e = 0; e < num_elem; ++e) {
181     for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
182   }
183   // Convert to running sum
184   for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
185   // List all E-vec indices associated with L-vec node
186   for (CeedInt e = 0; e < num_elem; ++e) {
187     for (CeedInt i = 0; i < elem_size; ++i) {
188       const CeedInt lid = elem_size * e + i;
189       const CeedInt gid = indices[lid];
190 
191       t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
192     }
193   }
194   // Reset running sum
195   for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
196   t_offsets[0] = 0;
197 
198   // Copy data to device
199   // -- L-vector indices
200   CeedCallHip(ceed, hipMalloc((void **)&impl->d_l_vec_indices, num_nodes * sizeof(CeedInt)));
201   CeedCallHip(ceed, hipMemcpy(impl->d_l_vec_indices, l_vec_indices, num_nodes * sizeof(CeedInt), hipMemcpyHostToDevice));
202   // -- Transpose offsets
203   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_offsets, size_offsets * sizeof(CeedInt)));
204   CeedCallHip(ceed, hipMemcpy(impl->d_t_offsets, t_offsets, size_offsets * sizeof(CeedInt), hipMemcpyHostToDevice));
205   // -- Transpose indices
206   CeedCallHip(ceed, hipMalloc((void **)&impl->d_t_indices, size_indices * sizeof(CeedInt)));
207   CeedCallHip(ceed, hipMemcpy(impl->d_t_indices, t_indices, size_indices * sizeof(CeedInt), hipMemcpyHostToDevice));
208 
209   // Cleanup
210   CeedCallBackend(CeedFree(&ind_to_offset));
211   CeedCallBackend(CeedFree(&l_vec_indices));
212   CeedCallBackend(CeedFree(&t_offsets));
213   CeedCallBackend(CeedFree(&t_indices));
214   return CEED_ERROR_SUCCESS;
215 }
216 
217 //------------------------------------------------------------------------------
218 // Create restriction
219 //------------------------------------------------------------------------------
220 int CeedElemRestrictionCreate_Hip(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
221                                   const CeedInt8 *curl_orients, CeedElemRestriction r) {
222   Ceed                     ceed, ceed_parent;
223   bool                     is_deterministic, is_strided;
224   char                    *restriction_kernel_path, *restriction_kernel_source;
225   CeedInt                  num_elem, num_comp, elem_size, comp_stride = 1;
226   CeedRestrictionType      rstr_type;
227   CeedElemRestriction_Hip *impl;
228 
229   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
230   CeedCallBackend(CeedCalloc(1, &impl));
231   CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
232   CeedCallBackend(CeedIsDeterministic(ceed_parent, &is_deterministic));
233   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
234   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
235   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
236   CeedInt size       = num_elem * elem_size;
237   CeedInt strides[3] = {1, size, elem_size};
238   CeedInt layout[3]  = {1, elem_size * num_elem, elem_size};
239 
240   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
241   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
242             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
243 
244   // Stride data
245   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
246   if (is_strided) {
247     bool has_backend_strides;
248 
249     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
250     if (!has_backend_strides) {
251       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
252     }
253   } else {
254     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
255   }
256 
257   impl->h_ind           = NULL;
258   impl->h_ind_allocated = NULL;
259   impl->d_ind           = NULL;
260   impl->d_ind_allocated = NULL;
261   impl->d_t_indices     = NULL;
262   impl->d_t_offsets     = NULL;
263   impl->num_nodes       = size;
264   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
265   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
266 
267   // Set up device indices/offset arrays
268   switch (mem_type) {
269     case CEED_MEM_HOST: {
270       switch (copy_mode) {
271         case CEED_OWN_POINTER:
272           impl->h_ind_allocated = (CeedInt *)indices;
273           impl->h_ind           = (CeedInt *)indices;
274           break;
275         case CEED_USE_POINTER:
276           impl->h_ind = (CeedInt *)indices;
277           break;
278         case CEED_COPY_VALUES:
279           if (indices != NULL) {
280             CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
281             memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
282             impl->h_ind = impl->h_ind_allocated;
283           }
284           break;
285       }
286       if (indices != NULL) {
287         CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
288         impl->d_ind_allocated = impl->d_ind;  // We own the device memory
289         CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyHostToDevice));
290         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
291       }
292       break;
293     }
294     case CEED_MEM_DEVICE: {
295       switch (copy_mode) {
296         case CEED_COPY_VALUES:
297           if (indices != NULL) {
298             CeedCallHip(ceed, hipMalloc((void **)&impl->d_ind, size * sizeof(CeedInt)));
299             impl->d_ind_allocated = impl->d_ind;  // We own the device memory
300             CeedCallHip(ceed, hipMemcpy(impl->d_ind, indices, size * sizeof(CeedInt), hipMemcpyDeviceToDevice));
301           }
302           break;
303         case CEED_OWN_POINTER:
304           impl->d_ind           = (CeedInt *)indices;
305           impl->d_ind_allocated = impl->d_ind;
306           break;
307         case CEED_USE_POINTER:
308           impl->d_ind = (CeedInt *)indices;
309       }
310       if (indices != NULL) {
311         CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
312         CeedCallHip(ceed, hipMemcpy(impl->h_ind_allocated, impl->d_ind, elem_size * num_elem * sizeof(CeedInt), hipMemcpyDeviceToHost));
313         impl->h_ind = impl->h_ind_allocated;
314         if (is_deterministic) CeedCallBackend(CeedElemRestrictionOffset_Hip(r, indices));
315       }
316       break;
317     }
318     // LCOV_EXCL_START
319     default:
320       return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
321       // LCOV_EXCL_STOP
322   }
323 
324   // Compile HIP kernels
325   CeedInt num_nodes = impl->num_nodes;
326 
327   CeedCallBackend(CeedGetJitAbsolutePath(ceed, "ceed/jit-source/hip/hip-ref-restriction.h", &restriction_kernel_path));
328   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source -----\n");
329   CeedCallBackend(CeedLoadSourceToBuffer(ceed, restriction_kernel_path, &restriction_kernel_source));
330   CeedDebug256(ceed, CEED_DEBUG_COLOR_SUCCESS, "----- Loading Restriction Kernel Source Complete! -----\n");
331   CeedCallBackend(CeedCompile_Hip(ceed, restriction_kernel_source, &impl->module, 8, "RESTR_ELEM_SIZE", elem_size, "RESTR_NUM_ELEM", num_elem,
332                                   "RESTR_NUM_COMP", num_comp, "RESTR_NUM_NODES", num_nodes, "RESTR_COMP_STRIDE", comp_stride, "RESTR_STRIDE_NODES",
333                                   strides[0], "RESTR_STRIDE_COMP", strides[1], "RESTR_STRIDE_ELEM", strides[2]));
334   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedNoTranspose", &impl->StridedNoTranspose));
335   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "StridedTranspose", &impl->StridedTranspose));
336   CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetNoTranspose", &impl->OffsetNoTranspose));
337   if (!is_deterministic) {
338     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTranspose", &impl->OffsetTranspose));
339   } else {
340     CeedCallBackend(CeedGetKernel_Hip(ceed, impl->module, "OffsetTransposeDet", &impl->OffsetTransposeDet));
341   }
342   CeedCallBackend(CeedFree(&restriction_kernel_path));
343   CeedCallBackend(CeedFree(&restriction_kernel_source));
344 
345   // Register backend functions
346   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Hip));
347   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Hip));
348   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Hip));
349   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Hip));
350   CeedCallBackend(CeedSetBackendFunction(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Hip));
351   return CEED_ERROR_SUCCESS;
352 }
353 
354 //------------------------------------------------------------------------------
355