xref: /libCEED/rust/libceed/src/elem_restriction.rs (revision 2459f3f1cd4d7d2e210e1c26d669bd2fde41a0b6)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed ElemRestriction decomposes elements and groups the degrees of freedom
9 //! (dofs) according to the different elements they belong to.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // ElemRestriction option
15 // -----------------------------------------------------------------------------
16 #[derive(Debug)]
17 pub enum ElemRestrictionOpt<'a> {
18     Some(&'a ElemRestriction<'a>),
19     None,
20 }
21 /// Construct a ElemRestrictionOpt reference from a ElemRestriction reference
22 impl<'a> From<&'a ElemRestriction<'_>> for ElemRestrictionOpt<'a> {
23     fn from(restr: &'a ElemRestriction) -> Self {
24         debug_assert!(restr.ptr != unsafe { bind_ceed::CEED_ELEMRESTRICTION_NONE });
25         Self::Some(restr)
26     }
27 }
28 impl<'a> ElemRestrictionOpt<'a> {
29     /// Transform a Rust libCEED ElemRestrictionOpt into C libCEED
30     /// CeedElemRestriction
31     pub(crate) fn to_raw(self) -> bind_ceed::CeedElemRestriction {
32         match self {
33             Self::Some(restr) => restr.ptr,
34             Self::None => unsafe { bind_ceed::CEED_ELEMRESTRICTION_NONE },
35         }
36     }
37 
38     /// Check if an ElemRestrictionOpt is Some
39     ///
40     /// ```
41     /// # use libceed::prelude::*;
42     /// # fn main() -> libceed::Result<()> {
43     /// # let ceed = libceed::Ceed::default_init();
44     /// let nelem = 3;
45     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
46     /// for i in 0..nelem {
47     ///     ind[2 * i + 0] = i as i32;
48     ///     ind[2 * i + 1] = (i + 1) as i32;
49     /// }
50     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
51     /// let r_opt = ElemRestrictionOpt::from(&r);
52     /// assert!(r_opt.is_some(), "Incorrect ElemRestrictionOpt");
53     ///
54     /// let r_opt = ElemRestrictionOpt::None;
55     /// assert!(!r_opt.is_some(), "Incorrect ElemRestrictionOpt");
56     /// # Ok(())
57     /// # }
58     /// ```
59     pub fn is_some(&self) -> bool {
60         match self {
61             Self::Some(_) => true,
62             Self::None => false,
63         }
64     }
65 
66     /// Check if an ElemRestrictionOpt is None
67     ///
68     /// ```
69     /// # use libceed::prelude::*;
70     /// # fn main() -> libceed::Result<()> {
71     /// # let ceed = libceed::Ceed::default_init();
72     /// let nelem = 3;
73     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
74     /// for i in 0..nelem {
75     ///     ind[2 * i + 0] = i as i32;
76     ///     ind[2 * i + 1] = (i + 1) as i32;
77     /// }
78     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
79     /// let r_opt = ElemRestrictionOpt::from(&r);
80     /// assert!(!r_opt.is_none(), "Incorrect ElemRestrictionOpt");
81     ///
82     /// let r_opt = ElemRestrictionOpt::None;
83     /// assert!(r_opt.is_none(), "Incorrect ElemRestrictionOpt");
84     /// # Ok(())
85     /// # }
86     /// ```
87     pub fn is_none(&self) -> bool {
88         match self {
89             Self::Some(_) => false,
90             Self::None => true,
91         }
92     }
93 }
94 
95 // -----------------------------------------------------------------------------
96 // ElemRestriction context wrapper
97 // -----------------------------------------------------------------------------
/// Rust wrapper around a C libCEED `CeedElemRestriction` handle
///
/// The C object is destroyed when this value is dropped, unless the handle is
/// `CEED_ELEMRESTRICTION_NONE`.
#[derive(Debug)]
pub struct ElemRestriction<'a> {
    // Raw C handle; readable anywhere in the crate so it can be passed to libCEED
    pub(crate) ptr: bind_ceed::CeedElemRestriction,
    // Zero-sized marker carrying lifetime `'a`; presumably ties this object to
    // the lifetime of borrowed data it depends on — TODO confirm
    _lifeline: PhantomData<&'a ()>,
}
103 
104 // -----------------------------------------------------------------------------
105 // Destructor
106 // -----------------------------------------------------------------------------
107 impl<'a> Drop for ElemRestriction<'a> {
108     fn drop(&mut self) {
109         unsafe {
110             if self.ptr != bind_ceed::CEED_ELEMRESTRICTION_NONE {
111                 bind_ceed::CeedElemRestrictionDestroy(&mut self.ptr);
112             }
113         }
114     }
115 }
116 
117 // -----------------------------------------------------------------------------
118 // Display
119 // -----------------------------------------------------------------------------
120 impl<'a> fmt::Display for ElemRestriction<'a> {
121     /// View an ElemRestriction
122     ///
123     /// ```
124     /// # use libceed::prelude::*;
125     /// # fn main() -> libceed::Result<()> {
126     /// # let ceed = libceed::Ceed::default_init();
127     /// let nelem = 3;
128     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
129     /// for i in 0..nelem {
130     ///     ind[2 * i + 0] = i as i32;
131     ///     ind[2 * i + 1] = (i + 1) as i32;
132     /// }
133     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
134     /// println!("{}", r);
135     /// # Ok(())
136     /// # }
137     /// ```
138     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
139         let mut ptr = std::ptr::null_mut();
140         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
141         let cstring = unsafe {
142             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
143             bind_ceed::CeedElemRestrictionView(self.ptr, file);
144             bind_ceed::fclose(file);
145             CString::from_raw(ptr)
146         };
147         cstring.to_string_lossy().fmt(f)
148     }
149 }
150 
151 // -----------------------------------------------------------------------------
152 // Implementations
153 // -----------------------------------------------------------------------------
154 impl<'a> ElemRestriction<'a> {
155     // Constructors
156     pub fn create(
157         ceed: &crate::Ceed,
158         nelem: usize,
159         elemsize: usize,
160         ncomp: usize,
161         compstride: usize,
162         lsize: usize,
163         mtype: crate::MemType,
164         offsets: &[i32],
165     ) -> crate::Result<Self> {
166         let mut ptr = std::ptr::null_mut();
167         let (nelem, elemsize, ncomp, compstride, lsize, mtype) = (
168             i32::try_from(nelem).unwrap(),
169             i32::try_from(elemsize).unwrap(),
170             i32::try_from(ncomp).unwrap(),
171             i32::try_from(compstride).unwrap(),
172             isize::try_from(lsize).unwrap(),
173             mtype as bind_ceed::CeedMemType,
174         );
175         let ierr = unsafe {
176             bind_ceed::CeedElemRestrictionCreate(
177                 ceed.ptr,
178                 nelem,
179                 elemsize,
180                 ncomp,
181                 compstride,
182                 lsize,
183                 mtype,
184                 crate::CopyMode::CopyValues as bind_ceed::CeedCopyMode,
185                 offsets.as_ptr(),
186                 &mut ptr,
187             )
188         };
189         ceed.check_error(ierr)?;
190         Ok(Self {
191             ptr,
192             _lifeline: PhantomData,
193         })
194     }
195 
196     pub fn create_strided(
197         ceed: &crate::Ceed,
198         nelem: usize,
199         elemsize: usize,
200         ncomp: usize,
201         lsize: usize,
202         strides: [i32; 3],
203     ) -> crate::Result<Self> {
204         let mut ptr = std::ptr::null_mut();
205         let (nelem, elemsize, ncomp, lsize) = (
206             i32::try_from(nelem).unwrap(),
207             i32::try_from(elemsize).unwrap(),
208             i32::try_from(ncomp).unwrap(),
209             isize::try_from(lsize).unwrap(),
210         );
211         let ierr = unsafe {
212             bind_ceed::CeedElemRestrictionCreateStrided(
213                 ceed.ptr,
214                 nelem,
215                 elemsize,
216                 ncomp,
217                 lsize,
218                 strides.as_ptr(),
219                 &mut ptr,
220             )
221         };
222         ceed.check_error(ierr)?;
223         Ok(Self {
224             ptr,
225             _lifeline: PhantomData,
226         })
227     }
228 
229     // Error handling
230     #[doc(hidden)]
231     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
232         let mut ptr = std::ptr::null_mut();
233         unsafe {
234             bind_ceed::CeedElemRestrictionGetCeed(self.ptr, &mut ptr);
235         }
236         crate::check_error(ptr, ierr)
237     }
238 
239     /// Create an Lvector for an ElemRestriction
240     ///
241     /// ```
242     /// # use libceed::prelude::*;
243     /// # fn main() -> libceed::Result<()> {
244     /// # let ceed = libceed::Ceed::default_init();
245     /// let nelem = 3;
246     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
247     /// for i in 0..nelem {
248     ///     ind[2 * i + 0] = i as i32;
249     ///     ind[2 * i + 1] = (i + 1) as i32;
250     /// }
251     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
252     ///
253     /// let lvector = r.create_lvector()?;
254     ///
255     /// assert_eq!(lvector.length(), nelem + 1, "Incorrect Lvector size");
256     /// # Ok(())
257     /// # }
258     /// ```
259     pub fn create_lvector<'b>(&self) -> crate::Result<Vector<'b>> {
260         let mut ptr_lvector = std::ptr::null_mut();
261         let null = std::ptr::null_mut() as *mut _;
262         let ierr =
263             unsafe { bind_ceed::CeedElemRestrictionCreateVector(self.ptr, &mut ptr_lvector, null) };
264         self.check_error(ierr)?;
265         Vector::from_raw(ptr_lvector)
266     }
267 
268     /// Create an Evector for an ElemRestriction
269     ///
270     /// ```
271     /// # use libceed::prelude::*;
272     /// # fn main() -> libceed::Result<()> {
273     /// # let ceed = libceed::Ceed::default_init();
274     /// let nelem = 3;
275     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
276     /// for i in 0..nelem {
277     ///     ind[2 * i + 0] = i as i32;
278     ///     ind[2 * i + 1] = (i + 1) as i32;
279     /// }
280     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
281     ///
282     /// let evector = r.create_evector()?;
283     ///
284     /// assert_eq!(evector.length(), nelem * 2, "Incorrect Evector size");
285     /// # Ok(())
286     /// # }
287     /// ```
288     pub fn create_evector<'b>(&self) -> crate::Result<Vector<'b>> {
289         let mut ptr_evector = std::ptr::null_mut();
290         let null = std::ptr::null_mut() as *mut _;
291         let ierr =
292             unsafe { bind_ceed::CeedElemRestrictionCreateVector(self.ptr, null, &mut ptr_evector) };
293         self.check_error(ierr)?;
294         Vector::from_raw(ptr_evector)
295     }
296 
297     /// Create Vectors for an ElemRestriction
298     ///
299     /// ```
300     /// # use libceed::prelude::*;
301     /// # fn main() -> libceed::Result<()> {
302     /// # let ceed = libceed::Ceed::default_init();
303     /// let nelem = 3;
304     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
305     /// for i in 0..nelem {
306     ///     ind[2 * i + 0] = i as i32;
307     ///     ind[2 * i + 1] = (i + 1) as i32;
308     /// }
309     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
310     ///
311     /// let (lvector, evector) = r.create_vectors()?;
312     ///
313     /// assert_eq!(lvector.length(), nelem + 1, "Incorrect Lvector size");
314     /// assert_eq!(evector.length(), nelem * 2, "Incorrect Evector size");
315     /// # Ok(())
316     /// # }
317     /// ```
318     pub fn create_vectors<'b, 'c>(&self) -> crate::Result<(Vector<'b>, Vector<'c>)> {
319         let mut ptr_lvector = std::ptr::null_mut();
320         let mut ptr_evector = std::ptr::null_mut();
321         let ierr = unsafe {
322             bind_ceed::CeedElemRestrictionCreateVector(self.ptr, &mut ptr_lvector, &mut ptr_evector)
323         };
324         self.check_error(ierr)?;
325         let lvector = Vector::from_raw(ptr_lvector)?;
326         let evector = Vector::from_raw(ptr_evector)?;
327         Ok((lvector, evector))
328     }
329 
330     /// Restrict an Lvector to an Evector or apply its transpose
331     ///
332     /// # arguments
333     ///
334     /// * `tmode` - Apply restriction or transpose
335     /// * `u`     - Input vector (of size `lsize` when `TransposeMode::NoTranspose`)
336     /// * `ru`    - Output vector (of shape `[nelem * elemsize]` when
337     ///               `TransposeMode::NoTranspose`). Ordering of the Evector is
338     ///               decided by the backend.
339     ///
340     /// ```
341     /// # use libceed::prelude::*;
342     /// # fn main() -> libceed::Result<()> {
343     /// # let ceed = libceed::Ceed::default_init();
344     /// let nelem = 3;
345     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
346     /// for i in 0..nelem {
347     ///     ind[2 * i + 0] = i as i32;
348     ///     ind[2 * i + 1] = (i + 1) as i32;
349     /// }
350     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
351     ///
352     /// let x = ceed.vector_from_slice(&[0., 1., 2., 3.])?;
353     /// let mut y = ceed.vector(nelem * 2)?;
354     /// y.set_value(0.0);
355     ///
356     /// r.apply(TransposeMode::NoTranspose, &x, &mut y)?;
357     ///
358     /// for (i, y) in y.view()?.iter().enumerate() {
359     ///     assert_eq!(
360     ///         *y,
361     ///         ((i + 1) / 2) as Scalar,
362     ///         "Incorrect value in restricted vector"
363     ///     );
364     /// }
365     /// # Ok(())
366     /// # }
367     /// ```
368     pub fn apply(&self, tmode: TransposeMode, u: &Vector, ru: &mut Vector) -> crate::Result<i32> {
369         let tmode = tmode as bind_ceed::CeedTransposeMode;
370         let ierr = unsafe {
371             bind_ceed::CeedElemRestrictionApply(
372                 self.ptr,
373                 tmode,
374                 u.ptr,
375                 ru.ptr,
376                 bind_ceed::CEED_REQUEST_IMMEDIATE,
377             )
378         };
379         self.check_error(ierr)
380     }
381 
382     /// Returns the Lvector component stride
383     ///
384     /// ```
385     /// # use libceed::prelude::*;
386     /// # fn main() -> libceed::Result<()> {
387     /// # let ceed = libceed::Ceed::default_init();
388     /// let nelem = 3;
389     /// let compstride = 1;
390     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
391     /// for i in 0..nelem {
392     ///     ind[2 * i + 0] = i as i32;
393     ///     ind[2 * i + 1] = (i + 1) as i32;
394     /// }
395     /// let r = ceed.elem_restriction(nelem, 2, 1, compstride, nelem + 1, MemType::Host, &ind)?;
396     ///
397     /// let c = r.comp_stride();
398     /// assert_eq!(c, compstride, "Incorrect component stride");
399     /// # Ok(())
400     /// # }
401     /// ```
402     pub fn comp_stride(&self) -> usize {
403         let mut compstride = 0;
404         unsafe { bind_ceed::CeedElemRestrictionGetCompStride(self.ptr, &mut compstride) };
405         usize::try_from(compstride).unwrap()
406     }
407 
408     /// Returns the total number of elements in the range of a ElemRestriction
409     ///
410     /// ```
411     /// # use libceed::prelude::*;
412     /// # fn main() -> libceed::Result<()> {
413     /// # let ceed = libceed::Ceed::default_init();
414     /// let nelem = 3;
415     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
416     /// for i in 0..nelem {
417     ///     ind[2 * i + 0] = i as i32;
418     ///     ind[2 * i + 1] = (i + 1) as i32;
419     /// }
420     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
421     ///
422     /// let n = r.num_elements();
423     /// assert_eq!(n, nelem, "Incorrect number of elements");
424     /// # Ok(())
425     /// # }
426     /// ```
427     pub fn num_elements(&self) -> usize {
428         let mut numelem = 0;
429         unsafe { bind_ceed::CeedElemRestrictionGetNumElements(self.ptr, &mut numelem) };
430         usize::try_from(numelem).unwrap()
431     }
432 
433     /// Returns the size of elements in the ElemRestriction
434     ///
435     /// ```
436     /// # use libceed::prelude::*;
437     /// # fn main() -> libceed::Result<()> {
438     /// # let ceed = libceed::Ceed::default_init();
439     /// let nelem = 3;
440     /// let elem_size = 2;
441     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
442     /// for i in 0..nelem {
443     ///     ind[2 * i + 0] = i as i32;
444     ///     ind[2 * i + 1] = (i + 1) as i32;
445     /// }
446     /// let r = ceed.elem_restriction(nelem, elem_size, 1, 1, nelem + 1, MemType::Host, &ind)?;
447     ///
448     /// let e = r.elem_size();
449     /// assert_eq!(e, elem_size, "Incorrect element size");
450     /// # Ok(())
451     /// # }
452     /// ```
453     pub fn elem_size(&self) -> usize {
454         let mut elemsize = 0;
455         unsafe { bind_ceed::CeedElemRestrictionGetElementSize(self.ptr, &mut elemsize) };
456         usize::try_from(elemsize).unwrap()
457     }
458 
459     /// Returns the size of the Lvector for an ElemRestriction
460     ///
461     /// ```
462     /// # use libceed::prelude::*;
463     /// # fn main() -> libceed::Result<()> {
464     /// # let ceed = libceed::Ceed::default_init();
465     /// let nelem = 3;
466     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
467     /// for i in 0..nelem {
468     ///     ind[2 * i + 0] = i as i32;
469     ///     ind[2 * i + 1] = (i + 1) as i32;
470     /// }
471     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
472     ///
473     /// let lsize = r.lvector_size();
474     /// assert_eq!(lsize, nelem + 1);
475     /// # Ok(())
476     /// # }
477     /// ```
478     pub fn lvector_size(&self) -> usize {
479         let mut lsize = 0;
480         unsafe { bind_ceed::CeedElemRestrictionGetLVectorSize(self.ptr, &mut lsize) };
481         usize::try_from(lsize).unwrap()
482     }
483 
484     /// Returns the number of components in the elements of an ElemRestriction
485     ///
486     /// ```
487     /// # use libceed::prelude::*;
488     /// # fn main() -> libceed::Result<()> {
489     /// # let ceed = libceed::Ceed::default_init();
490     /// let nelem = 3;
491     /// let ncomp = 42;
492     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
493     /// for i in 0..nelem {
494     ///     ind[2 * i + 0] = i as i32;
495     ///     ind[2 * i + 1] = (i + 1) as i32;
496     /// }
497     /// let r = ceed.elem_restriction(nelem, 2, 42, 1, ncomp * (nelem + 1), MemType::Host, &ind)?;
498     ///
499     /// let n = r.num_components();
500     /// assert_eq!(n, ncomp, "Incorrect number of components");
501     /// # Ok(())
502     /// # }
503     /// ```
504     pub fn num_components(&self) -> usize {
505         let mut ncomp = 0;
506         unsafe { bind_ceed::CeedElemRestrictionGetNumComponents(self.ptr, &mut ncomp) };
507         usize::try_from(ncomp).unwrap()
508     }
509 
510     /// Returns the multiplicity of nodes in an ElemRestriction
511     ///
512     /// ```
513     /// # use libceed::prelude::*;
514     /// # fn main() -> libceed::Result<()> {
515     /// # let ceed = libceed::Ceed::default_init();
516     /// let nelem = 3;
517     /// let mut ind: Vec<i32> = vec![0; 2 * nelem];
518     /// for i in 0..nelem {
519     ///     ind[2 * i + 0] = i as i32;
520     ///     ind[2 * i + 1] = (i + 1) as i32;
521     /// }
522     /// let r = ceed.elem_restriction(nelem, 2, 1, 1, nelem + 1, MemType::Host, &ind)?;
523     ///
524     /// let mut mult = ceed.vector(nelem + 1)?;
525     /// mult.set_value(0.0);
526     ///
527     /// r.multiplicity(&mut mult)?;
528     ///
529     /// for (i, m) in mult.view()?.iter().enumerate() {
530     ///     assert_eq!(
531     ///         *m,
532     ///         if (i == 0 || i == nelem) { 1. } else { 2. },
533     ///         "Incorrect multiplicity value"
534     ///     );
535     /// }
536     /// # Ok(())
537     /// # }
538     /// ```
539     pub fn multiplicity(&self, mult: &mut Vector) -> crate::Result<i32> {
540         let ierr = unsafe { bind_ceed::CeedElemRestrictionGetMultiplicity(self.ptr, mult.ptr) };
541         self.check_error(ierr)
542     }
543 }
544 
545 // -----------------------------------------------------------------------------
546