xref: /libCEED/backends/sycl-ref/ceed-sycl-restriction.sycl.cpp (revision 356036fa84f714fa73ef64c9a80ce2028dde816f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other
2 // CEED contributors. All Rights Reserved. See the top-level LICENSE and NOTICE
3 // files for details.
4 //
5 // SPDX-License-Identifier: BSD-2-Clause
6 //
7 // This file is part of CEED:  http://github.com/ceed
8 
9 #include <ceed/backend.h>
10 #include <ceed/ceed.h>
11 #include <ceed/jit-tools.h>
12 
13 #include <string>
14 #include <sycl/sycl.hpp>
15 
16 #include "../sycl/ceed-sycl-compile.hpp"
17 #include "ceed-sycl-ref.hpp"
18 
// Forward declarations used purely as unique SYCL kernel-name tags for the
// parallel_for launches below (one tag per restriction kernel variant).
class CeedElemRestrSyclStridedNT;  // L-vector -> E-vector, strided
class CeedElemRestrSyclOffsetNT;   // L-vector -> E-vector, offsets provided
class CeedElemRestrSyclStridedT;   // E-vector -> L-vector, strided
class CeedElemRestrSyclOffsetT;    // E-vector -> L-vector, offsets provided
23 
24 //------------------------------------------------------------------------------
25 // Restriction Kernel : L-vector -> E-vector, strided
26 //------------------------------------------------------------------------------
27 static int CeedElemRestrictionStridedNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
28                                                       CeedScalar *v) {
29   const CeedInt  elem_size    = impl->elem_size;
30   const CeedInt  num_elem     = impl->num_elem;
31   const CeedInt  num_comp     = impl->num_comp;
32   const CeedInt  stride_nodes = impl->strides[0];
33   const CeedInt  stride_comp  = impl->strides[1];
34   const CeedInt  stride_elem  = impl->strides[2];
35   sycl::range<1> kernel_range(num_elem * elem_size);
36 
37   // Order queue
38   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
39   sycl_queue.parallel_for<CeedElemRestrSyclStridedNT>(kernel_range, {e}, [=](sycl::id<1> node) {
40     const CeedInt loc_node = node % elem_size;
41     const CeedInt elem     = node / elem_size;
42 
43     for (CeedInt comp = 0; comp < num_comp; comp++) {
44       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem];
45     }
46   });
47   return CEED_ERROR_SUCCESS;
48 }
49 
50 //------------------------------------------------------------------------------
51 // Restriction Kernel : L-vector -> E-vector, offsets provided
52 //------------------------------------------------------------------------------
53 static int CeedElemRestrictionOffsetNoTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
54                                                      CeedScalar *v) {
55   const CeedInt elem_size   = impl->elem_size;
56   const CeedInt num_elem    = impl->num_elem;
57   const CeedInt num_comp    = impl->num_comp;
58   const CeedInt comp_stride = impl->comp_stride;
59 
60   const CeedInt *indices = impl->d_ind;
61 
62   sycl::range<1> kernel_range(num_elem * elem_size);
63 
64   // Order queue
65   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
66   sycl_queue.parallel_for<CeedElemRestrSyclOffsetNT>(kernel_range, {e}, [=](sycl::id<1> node) {
67     const CeedInt ind      = indices[node];
68     const CeedInt loc_node = node % elem_size;
69     const CeedInt elem     = node / elem_size;
70 
71     for (CeedInt comp = 0; comp < num_comp; comp++) {
72       v[loc_node + comp * elem_size * num_elem + elem * elem_size] = u[ind + comp * comp_stride];
73     }
74   });
75   return CEED_ERROR_SUCCESS;
76 }
77 
78 //------------------------------------------------------------------------------
79 // Kernel: E-vector -> L-vector, strided
80 //------------------------------------------------------------------------------
81 static int CeedElemRestrictionStridedTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
82                                                     CeedScalar *v) {
83   const CeedInt elem_size    = impl->elem_size;
84   const CeedInt num_elem     = impl->num_elem;
85   const CeedInt num_comp     = impl->num_comp;
86   const CeedInt stride_nodes = impl->strides[0];
87   const CeedInt stride_comp  = impl->strides[1];
88   const CeedInt stride_elem  = impl->strides[2];
89 
90   sycl::range<1> kernel_range(num_elem * elem_size);
91 
92   // Order queue
93   sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
94   sycl_queue.parallel_for<CeedElemRestrSyclStridedT>(kernel_range, {e}, [=](sycl::id<1> node) {
95     const CeedInt loc_node = node % elem_size;
96     const CeedInt elem     = node / elem_size;
97 
98     for (CeedInt comp = 0; comp < num_comp; comp++) {
99       v[loc_node * stride_nodes + comp * stride_comp + elem * stride_elem] += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
100     }
101   });
102   return CEED_ERROR_SUCCESS;
103 }
104 
105 //------------------------------------------------------------------------------
106 // Kernel: E-vector -> L-vector, offsets provided
107 //------------------------------------------------------------------------------
static int CeedElemRestrictionOffsetTranspose_Sycl(sycl::queue &sycl_queue, const CeedElemRestriction_Sycl *impl, const CeedScalar *u,
                                                   CeedScalar *v) {
  // Gathers E-vector values back into the L-vector using the transpose
  // (CSR-style) maps built in CeedElemRestrictionOffset_Sycl. One work-item
  // handles one (active L-node, component) pair, so each output entry is
  // written by exactly one work-item and no atomics are needed.
  const CeedInt num_nodes   = impl->num_nodes;
  const CeedInt elem_size   = impl->elem_size;
  const CeedInt num_elem    = impl->num_elem;
  const CeedInt num_comp    = impl->num_comp;
  const CeedInt comp_stride = impl->comp_stride;

  const CeedInt *l_vec_indices = impl->d_l_vec_indices;  // active node -> L-vector index
  const CeedInt *t_offsets     = impl->d_t_offsets;      // per-node slice bounds into t_indices
  const CeedInt *t_indices     = impl->d_t_indices;      // E-vector locations feeding each node

  sycl::range<1> kernel_range(num_nodes * num_comp);

  // Order queue
  sycl::event e = sycl_queue.ext_oneapi_submit_barrier();
  sycl_queue.parallel_for<CeedElemRestrSyclOffsetT>(kernel_range, {e}, [=](sycl::id<1> id) {
    const CeedInt node    = id % num_nodes;
    const CeedInt comp    = id / num_nodes;
    const CeedInt ind     = l_vec_indices[node];
    const CeedInt range_1 = t_offsets[node];      // first E-vector entry for this node
    const CeedInt range_N = t_offsets[node + 1];  // one past the last entry

    CeedScalar value = 0.0;

    // Accumulate every E-vector duplicate of this L-vector node
    for (CeedInt j = range_1; j < range_N; j++) {
      const CeedInt t_ind    = t_indices[j];
      CeedInt       loc_node = t_ind % elem_size;
      CeedInt       elem     = t_ind / elem_size;

      value += u[loc_node + comp * elem_size * num_elem + elem * elem_size];
    }
    // Sum into the L-vector (caller acquired v in read-write mode)
    v[ind + comp * comp_stride] += value;
  });
  return CEED_ERROR_SUCCESS;
}
144 
145 //------------------------------------------------------------------------------
146 // Apply restriction
147 //------------------------------------------------------------------------------
148 static int CeedElemRestrictionApply_Sycl(CeedElemRestriction r, CeedTransposeMode t_mode, CeedVector u, CeedVector v, CeedRequest *request) {
149   Ceed ceed;
150   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
151   CeedElemRestriction_Sycl *impl;
152   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
153   Ceed_Sycl *data;
154   CeedCallBackend(CeedGetData(ceed, &data));
155 
156   // Get vectors
157   const CeedScalar *d_u;
158   CeedScalar       *d_v;
159   CeedCallBackend(CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u));
160   if (t_mode == CEED_TRANSPOSE) {
161     // Sum into for transpose mode, e-vec to l-vec
162     CeedCallBackend(CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v));
163   } else {
164     // Overwrite for notranspose mode, l-vec to e-vec
165     CeedCallBackend(CeedVectorGetArrayWrite(v, CEED_MEM_DEVICE, &d_v));
166   }
167 
168   // Restrict
169   if (t_mode == CEED_NOTRANSPOSE) {
170     // L-vector -> E-vector
171     if (impl->d_ind) {
172       // -- Offsets provided
173       CeedCallBackend(CeedElemRestrictionOffsetNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
174     } else {
175       // -- Strided restriction
176       CeedCallBackend(CeedElemRestrictionStridedNoTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
177     }
178   } else {
179     // E-vector -> L-vector
180     if (impl->d_ind) {
181       // -- Offsets provided
182       CeedCallBackend(CeedElemRestrictionOffsetTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
183     } else {
184       // -- Strided restriction
185       CeedCallBackend(CeedElemRestrictionStridedTranspose_Sycl(data->sycl_queue, impl, d_u, d_v));
186     }
187   }
188   // Wait for queues to be completed. NOTE: This may not be necessary
189   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
190 
191   if (request != CEED_REQUEST_IMMEDIATE && request != CEED_REQUEST_ORDERED) *request = NULL;
192 
193   // Restore arrays
194   CeedCallBackend(CeedVectorRestoreArrayRead(u, &d_u));
195   CeedCallBackend(CeedVectorRestoreArray(v, &d_v));
196   return CEED_ERROR_SUCCESS;
197 }
198 
199 //------------------------------------------------------------------------------
200 // Get offsets
201 //------------------------------------------------------------------------------
202 static int CeedElemRestrictionGetOffsets_Sycl(CeedElemRestriction r, CeedMemType m_type, const CeedInt **offsets) {
203   Ceed ceed;
204   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
205   CeedElemRestriction_Sycl *impl;
206   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
207 
208   switch (m_type) {
209     case CEED_MEM_HOST:
210       *offsets = impl->h_ind;
211       break;
212     case CEED_MEM_DEVICE:
213       *offsets = impl->d_ind;
214       break;
215   }
216   return CEED_ERROR_SUCCESS;
217 }
218 
219 //------------------------------------------------------------------------------
220 // Destroy restriction
221 //------------------------------------------------------------------------------
222 static int CeedElemRestrictionDestroy_Sycl(CeedElemRestriction r) {
223   Ceed ceed;
224   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
225   CeedElemRestriction_Sycl *impl;
226   CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
227   Ceed_Sycl *data;
228   CeedCallBackend(CeedGetData(ceed, &data));
229 
230   // Wait for all work to finish before freeing memory
231   CeedCallSycl(ceed, data->sycl_queue.wait_and_throw());
232 
233   CeedCallBackend(CeedFree(&impl->h_ind_allocated));
234   CeedCallSycl(ceed, sycl::free(impl->d_ind_allocated, data->sycl_context));
235   CeedCallSycl(ceed, sycl::free(impl->d_t_offsets, data->sycl_context));
236   CeedCallSycl(ceed, sycl::free(impl->d_t_indices, data->sycl_context));
237   CeedCallSycl(ceed, sycl::free(impl->d_l_vec_indices, data->sycl_context));
238   CeedCallBackend(CeedFree(&impl));
239 
240   return CEED_ERROR_SUCCESS;
241 }
242 
243 //------------------------------------------------------------------------------
244 // Create transpose offsets and indices
245 //------------------------------------------------------------------------------
static int CeedElemRestrictionOffset_Sycl(const CeedElemRestriction r, const CeedInt *indices) {
  // Host-side construction of the transpose (E-vector -> L-vector) maps used by
  // CeedElemRestrictionOffsetTranspose_Sycl, followed by upload to the device:
  //   l_vec_indices[k] : L-vector index of the k-th active node
  //   t_offsets[k]     : start of node k's slice in t_indices (CSR-style, size num_nodes + 1)
  //   t_indices[j]     : E-vector locations (elem * elem_size + loc_node) feeding a node
  // NOTE(review): `indices` is dereferenced on the host here, so it must be a
  // host-accessible pointer.
  Ceed ceed;
  CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
  CeedElemRestriction_Sycl *impl;
  CeedCallBackend(CeedElemRestrictionGetData(r, &impl));
  CeedSize l_size;
  CeedInt  num_elem, elem_size, num_comp;  // num_comp is fetched but unused below
  CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
  CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
  CeedCallBackend(CeedElemRestrictionGetLVectorSize(r, &l_size));
  CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));

  // Count num_nodes: mark every L-vector entry referenced by the offsets,
  // then count the marks to get the number of distinct active nodes
  bool *is_node;
  CeedCallBackend(CeedCalloc(l_size, &is_node));
  const CeedInt size_indices = num_elem * elem_size;
  for (CeedInt i = 0; i < size_indices; i++) is_node[indices[i]] = 1;
  CeedInt num_nodes = 0;
  for (CeedInt i = 0; i < l_size; i++) num_nodes += is_node[i];
  impl->num_nodes = num_nodes;

  // L-vector offsets array: compress the active L-vector indices into the
  // range [0, num_nodes) and record the inverse map
  CeedInt *ind_to_offset, *l_vec_indices;
  CeedCallBackend(CeedCalloc(l_size, &ind_to_offset));
  CeedCallBackend(CeedCalloc(num_nodes, &l_vec_indices));
  CeedInt j = 0;
  for (CeedInt i = 0; i < l_size; i++) {
    if (is_node[i]) {
      l_vec_indices[j] = i;
      ind_to_offset[i] = j++;
    }
  }
  CeedCallBackend(CeedFree(&is_node));

  // Compute transpose offsets and indices
  const CeedInt size_offsets = num_nodes + 1;
  CeedInt      *t_offsets;
  CeedCallBackend(CeedCalloc(size_offsets, &t_offsets));
  CeedInt *t_indices;
  CeedCallBackend(CeedMalloc(size_indices, &t_indices));
  // Count node multiplicity (stored shifted by one slot so the running sum
  // below turns counts into slice start offsets)
  for (CeedInt e = 0; e < num_elem; ++e) {
    for (CeedInt i = 0; i < elem_size; ++i) ++t_offsets[ind_to_offset[indices[elem_size * e + i]] + 1];
  }
  // Convert to running sum
  for (CeedInt i = 1; i < size_offsets; ++i) t_offsets[i] += t_offsets[i - 1];
  // List all E-vec indices associated with L-vec node; as a side effect each
  // t_offsets entry is bumped by its node's multiplicity
  for (CeedInt e = 0; e < num_elem; ++e) {
    for (CeedInt i = 0; i < elem_size; ++i) {
      const CeedInt lid                          = elem_size * e + i;
      const CeedInt gid                          = indices[lid];
      t_indices[t_offsets[ind_to_offset[gid]]++] = lid;
    }
  }
  // Reset running sum: undo the bumps by shifting entries down one slot
  for (int i = size_offsets - 1; i > 0; --i) t_offsets[i] = t_offsets[i - 1];
  t_offsets[0] = 0;

  // Copy data to device
  Ceed_Sycl *data;
  CeedCallBackend(CeedGetData(ceed, &data));

  // Order queue
  sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();

  // -- L-vector indices
  CeedCallSycl(ceed, impl->d_l_vec_indices = sycl::malloc_device<CeedInt>(num_nodes, data->sycl_device, data->sycl_context));
  sycl::event copy_lvec = data->sycl_queue.copy<CeedInt>(l_vec_indices, impl->d_l_vec_indices, num_nodes, {e});
  // -- Transpose offsets
  CeedCallSycl(ceed, impl->d_t_offsets = sycl::malloc_device<CeedInt>(size_offsets, data->sycl_device, data->sycl_context));
  sycl::event copy_offsets = data->sycl_queue.copy<CeedInt>(t_offsets, impl->d_t_offsets, size_offsets, {e});
  // -- Transpose indices
  CeedCallSycl(ceed, impl->d_t_indices = sycl::malloc_device<CeedInt>(size_indices, data->sycl_device, data->sycl_context));
  sycl::event copy_indices = data->sycl_queue.copy<CeedInt>(t_indices, impl->d_t_indices, size_indices, {e});

  // Wait for all copies to complete and handle exceptions
  CeedCallSycl(ceed, sycl::event::wait_and_throw({copy_lvec, copy_offsets, copy_indices}));

  // Cleanup host scratch arrays
  CeedCallBackend(CeedFree(&ind_to_offset));
  CeedCallBackend(CeedFree(&l_vec_indices));
  CeedCallBackend(CeedFree(&t_offsets));
  CeedCallBackend(CeedFree(&t_indices));

  return CEED_ERROR_SUCCESS;
}
332 
333 //------------------------------------------------------------------------------
334 // Create restriction
335 //------------------------------------------------------------------------------
336 int CeedElemRestrictionCreate_Sycl(CeedMemType mem_type, CeedCopyMode copy_mode, const CeedInt *indices, const bool *orients,
337                                    const CeedInt8 *curl_orients, CeedElemRestriction r) {
338   Ceed ceed;
339   CeedCallBackend(CeedElemRestrictionGetCeed(r, &ceed));
340   Ceed_Sycl *data;
341   CeedCallBackend(CeedGetData(ceed, &data));
342   CeedElemRestriction_Sycl *impl;
343   CeedCallBackend(CeedCalloc(1, &impl));
344   CeedInt num_elem, num_comp, elem_size;
345   CeedCallBackend(CeedElemRestrictionGetNumElements(r, &num_elem));
346   CeedCallBackend(CeedElemRestrictionGetNumComponents(r, &num_comp));
347   CeedCallBackend(CeedElemRestrictionGetElementSize(r, &elem_size));
348   CeedInt size        = num_elem * elem_size;
349   CeedInt strides[3]  = {1, size, elem_size};
350   CeedInt comp_stride = 1;
351 
352   CeedRestrictionType rstr_type;
353   CeedCallBackend(CeedElemRestrictionGetType(r, &rstr_type));
354   CeedCheck(rstr_type != CEED_RESTRICTION_ORIENTED && rstr_type != CEED_RESTRICTION_CURL_ORIENTED, ceed, CEED_ERROR_BACKEND,
355             "Backend does not implement CeedElemRestrictionCreateOriented or CeedElemRestrictionCreateCurlOriented");
356 
357   // Stride data
358   bool is_strided;
359   CeedCallBackend(CeedElemRestrictionIsStrided(r, &is_strided));
360   if (is_strided) {
361     bool has_backend_strides;
362     CeedCallBackend(CeedElemRestrictionHasBackendStrides(r, &has_backend_strides));
363     if (!has_backend_strides) {
364       CeedCallBackend(CeedElemRestrictionGetStrides(r, &strides));
365     }
366   } else {
367     CeedCallBackend(CeedElemRestrictionGetCompStride(r, &comp_stride));
368   }
369 
370   impl->h_ind           = NULL;
371   impl->h_ind_allocated = NULL;
372   impl->d_ind           = NULL;
373   impl->d_ind_allocated = NULL;
374   impl->d_t_indices     = NULL;
375   impl->d_t_offsets     = NULL;
376   impl->num_nodes       = size;
377   impl->num_elem        = num_elem;
378   impl->num_comp        = num_comp;
379   impl->elem_size       = elem_size;
380   impl->comp_stride     = comp_stride;
381   impl->strides[0]      = strides[0];
382   impl->strides[1]      = strides[1];
383   impl->strides[2]      = strides[2];
384   CeedCallBackend(CeedElemRestrictionSetData(r, impl));
385   CeedInt layout[3] = {1, elem_size * num_elem, elem_size};
386   CeedCallBackend(CeedElemRestrictionSetELayout(r, layout));
387 
388   // Set up device indices/offset arrays
389   if (m_type == CEED_MEM_HOST) {
390     switch (copy_mode) {
391       case CEED_OWN_POINTER:
392         impl->h_ind_allocated = (CeedInt *)indices;
393         impl->h_ind           = (CeedInt *)indices;
394         break;
395       case CEED_USE_POINTER:
396         impl->h_ind = (CeedInt *)indices;
397         break;
398       case CEED_COPY_VALUES:
399         if (indices != NULL) {
400           CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
401           memcpy(impl->h_ind_allocated, indices, elem_size * num_elem * sizeof(CeedInt));
402           impl->h_ind = impl->h_ind_allocated;
403         }
404         break;
405     }
406     if (indices != NULL) {
407       CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
408       impl->d_ind_allocated = impl->d_ind;  // We own the device memory
409       // Order queue
410       sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
411       // Copy from host to device
412       sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e});
413       // Wait for copy to finish and handle exceptions
414       CeedCallSycl(ceed, copy_event.wait_and_throw());
415       CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices));
416     }
417   } else if (m_type == CEED_MEM_DEVICE) {
418     switch (copy_mode) {
419       case CEED_COPY_VALUES:
420         if (indices != NULL) {
421           CeedCallSycl(ceed, impl->d_ind = sycl::malloc_device<CeedInt>(size, data->sycl_device, data->sycl_context));
422           impl->d_ind_allocated = impl->d_ind;  // We own the device memory
423                                                 // Copy from device to device
424           // Order queue
425           sycl::event e          = data->sycl_queue.ext_oneapi_submit_barrier();
426           sycl::event copy_event = data->sycl_queue.copy<CeedInt>(indices, impl->d_ind, size, {e});
427           // Wait for copy to finish and handle exceptions
428           CeedCallSycl(ceed, copy_event.wait_and_throw());
429         }
430         break;
431       case CEED_OWN_POINTER:
432         impl->d_ind           = (CeedInt *)indices;
433         impl->d_ind_allocated = impl->d_ind;
434         break;
435       case CEED_USE_POINTER:
436         impl->d_ind = (CeedInt *)indices;
437     }
438     if (indices != NULL) {
439       CeedCallBackend(CeedMalloc(elem_size * num_elem, &impl->h_ind_allocated));
440       // Order queue
441       sycl::event e = data->sycl_queue.ext_oneapi_submit_barrier();
442       // Copy from device to host
443       sycl::event copy_event = data->sycl_queue.copy<CeedInt>(impl->d_ind, impl->h_ind_allocated, elem_size * num_elem, {e});
444       CeedCallSycl(ceed, copy_event.wait_and_throw());
445       impl->h_ind = impl->h_ind_allocated;
446       CeedCallBackend(CeedElemRestrictionOffset_Sycl(r, indices));
447     }
448   } else {
449     // LCOV_EXCL_START
450     return CeedError(ceed, CEED_ERROR_BACKEND, "Only MemType = HOST or DEVICE supported");
451     // LCOV_EXCL_STOP
452   }
453 
454   // Register backend functions
455   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Apply", CeedElemRestrictionApply_Sycl));
456   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnsigned", CeedElemRestrictionApply_Sycl));
457   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "ApplyUnoriented", CeedElemRestrictionApply_Sycl));
458   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "GetOffsets", CeedElemRestrictionGetOffsets_Sycl));
459   CeedCallBackend(CeedSetBackendFunctionCpp(ceed, "ElemRestriction", r, "Destroy", CeedElemRestrictionDestroy_Sycl));
460   return CEED_ERROR_SUCCESS;
461 }
462