xref: /libCEED/rust/libceed/src/basis.rs (revision 8a94a473032dc6ed59a2cf0afe1d886fbdb591f4)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
16 #[derive(Debug)]
17 pub enum BasisOpt<'a> {
18     Some(&'a Basis<'a>),
19     Collocated,
20 }
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_COLLOCATED });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::Collocated => unsafe { bind_ceed::CEED_BASIS_COLLOCATED },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::prelude::*;
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::Collocated;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::Collocated => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is Collocated
60     ///
61     /// ```
62     /// # use libceed::prelude::*;
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_collocated(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::Collocated;
70     /// assert!(b_opt.is_collocated(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_collocated(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::Collocated => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
85 #[derive(Debug)]
86 pub struct Basis<'a> {
87     pub(crate) ptr: bind_ceed::CeedBasis,
88     _lifeline: PhantomData<&'a ()>,
89 }
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_COLLOCATED {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::prelude::*;
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     pub fn create_tensor_H1(
138         ceed: &crate::Ceed,
139         dim: usize,
140         ncomp: usize,
141         P1d: usize,
142         Q1d: usize,
143         interp1d: &[crate::Scalar],
144         grad1d: &[crate::Scalar],
145         qref1d: &[crate::Scalar],
146         qweight1d: &[crate::Scalar],
147     ) -> crate::Result<Self> {
148         let mut ptr = std::ptr::null_mut();
149         let (dim, ncomp, P1d, Q1d) = (
150             i32::try_from(dim).unwrap(),
151             i32::try_from(ncomp).unwrap(),
152             i32::try_from(P1d).unwrap(),
153             i32::try_from(Q1d).unwrap(),
154         );
155         let ierr = unsafe {
156             bind_ceed::CeedBasisCreateTensorH1(
157                 ceed.ptr,
158                 dim,
159                 ncomp,
160                 P1d,
161                 Q1d,
162                 interp1d.as_ptr(),
163                 grad1d.as_ptr(),
164                 qref1d.as_ptr(),
165                 qweight1d.as_ptr(),
166                 &mut ptr,
167             )
168         };
169         ceed.check_error(ierr)?;
170         Ok(Self {
171             ptr,
172             _lifeline: PhantomData,
173         })
174     }
175 
176     pub fn create_tensor_H1_Lagrange(
177         ceed: &crate::Ceed,
178         dim: usize,
179         ncomp: usize,
180         P: usize,
181         Q: usize,
182         qmode: crate::QuadMode,
183     ) -> crate::Result<Self> {
184         let mut ptr = std::ptr::null_mut();
185         let (dim, ncomp, P, Q, qmode) = (
186             i32::try_from(dim).unwrap(),
187             i32::try_from(ncomp).unwrap(),
188             i32::try_from(P).unwrap(),
189             i32::try_from(Q).unwrap(),
190             qmode as bind_ceed::CeedQuadMode,
191         );
192         let ierr = unsafe {
193             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
194         };
195         ceed.check_error(ierr)?;
196         Ok(Self {
197             ptr,
198             _lifeline: PhantomData,
199         })
200     }
201 
202     pub fn create_H1(
203         ceed: &crate::Ceed,
204         topo: crate::ElemTopology,
205         ncomp: usize,
206         nnodes: usize,
207         nqpts: usize,
208         interp: &[crate::Scalar],
209         grad: &[crate::Scalar],
210         qref: &[crate::Scalar],
211         qweight: &[crate::Scalar],
212     ) -> crate::Result<Self> {
213         let mut ptr = std::ptr::null_mut();
214         let (topo, ncomp, nnodes, nqpts) = (
215             topo as bind_ceed::CeedElemTopology,
216             i32::try_from(ncomp).unwrap(),
217             i32::try_from(nnodes).unwrap(),
218             i32::try_from(nqpts).unwrap(),
219         );
220         let ierr = unsafe {
221             bind_ceed::CeedBasisCreateH1(
222                 ceed.ptr,
223                 topo,
224                 ncomp,
225                 nnodes,
226                 nqpts,
227                 interp.as_ptr(),
228                 grad.as_ptr(),
229                 qref.as_ptr(),
230                 qweight.as_ptr(),
231                 &mut ptr,
232             )
233         };
234         ceed.check_error(ierr)?;
235         Ok(Self {
236             ptr,
237             _lifeline: PhantomData,
238         })
239     }
240 
241     // Error handling
242     #[doc(hidden)]
243     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
244         let mut ptr = std::ptr::null_mut();
245         unsafe {
246             bind_ceed::CeedBasisGetCeed(self.ptr, &mut ptr);
247         }
248         crate::check_error(ptr, ierr)
249     }
250 
251     /// Apply basis evaluation from nodes to quadrature points or vice versa
252     ///
253     /// * `nelem` - The number of elements to apply the basis evaluation to
254     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
255     ///               quadrature points, `TransposeMode::Transpose` to apply the
256     ///               transpose, mapping from quadrature points to nodes
257     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
258     ///               to use interpolated values, `EvalMode::Grad` to use
259     ///               gradients, `EvalMode::Weight` to use quadrature weights
260     /// * `u`     - Input Vector
261     /// * `v`     - Output Vector
262     ///
263     /// ```
264     /// # use libceed::prelude::*;
265     /// # fn main() -> libceed::Result<()> {
266     /// # let ceed = libceed::Ceed::default_init();
267     /// const Q: usize = 6;
268     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
269     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
270     ///
271     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
272     /// let mut x_qpts = ceed.vector(Q)?;
273     /// let mut x_nodes = ceed.vector(Q)?;
274     /// bx.apply(
275     ///     1,
276     ///     TransposeMode::NoTranspose,
277     ///     EvalMode::Interp,
278     ///     &x_corners,
279     ///     &mut x_nodes,
280     /// )?;
281     /// bu.apply(
282     ///     1,
283     ///     TransposeMode::NoTranspose,
284     ///     EvalMode::Interp,
285     ///     &x_nodes,
286     ///     &mut x_qpts,
287     /// )?;
288     ///
289     /// // Create function x^3 + 1 on Gauss Lobatto points
290     /// let mut u_arr = [0.; Q];
291     /// u_arr
292     ///     .iter_mut()
293     ///     .zip(x_nodes.view()?.iter())
294     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
295     /// let u = ceed.vector_from_slice(&u_arr)?;
296     ///
297     /// // Map function to Gauss points
298     /// let mut v = ceed.vector(Q)?;
299     /// v.set_value(0.);
300     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
301     ///
302     /// // Verify results
303     /// v.view()?
304     ///     .iter()
305     ///     .zip(x_qpts.view()?.iter())
306     ///     .for_each(|(v, x)| {
307     ///         let true_value = x * x * x + 1.;
308     ///         assert!(
309     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
310     ///             "Incorrect basis application"
311     ///         );
312     ///     });
313     /// # Ok(())
314     /// # }
315     /// ```
316     pub fn apply(
317         &self,
318         nelem: usize,
319         tmode: TransposeMode,
320         emode: EvalMode,
321         u: &Vector,
322         v: &mut Vector,
323     ) -> crate::Result<i32> {
324         let (nelem, tmode, emode) = (
325             i32::try_from(nelem).unwrap(),
326             tmode as bind_ceed::CeedTransposeMode,
327             emode as bind_ceed::CeedEvalMode,
328         );
329         let ierr =
330             unsafe { bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr) };
331         self.check_error(ierr)
332     }
333 
334     /// Returns the dimension for given Basis
335     ///
336     /// ```
337     /// # use libceed::prelude::*;
338     /// # fn main() -> libceed::Result<()> {
339     /// # let ceed = libceed::Ceed::default_init();
340     /// let dim = 2;
341     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
342     ///
343     /// let d = b.dimension();
344     /// assert_eq!(d, dim, "Incorrect dimension");
345     /// # Ok(())
346     /// # }
347     /// ```
348     pub fn dimension(&self) -> usize {
349         let mut dim = 0;
350         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
351         usize::try_from(dim).unwrap()
352     }
353 
354     /// Returns number of components for given Basis
355     ///
356     /// ```
357     /// # use libceed::prelude::*;
358     /// # fn main() -> libceed::Result<()> {
359     /// # let ceed = libceed::Ceed::default_init();
360     /// let ncomp = 2;
361     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
362     ///
363     /// let n = b.num_components();
364     /// assert_eq!(n, ncomp, "Incorrect number of components");
365     /// # Ok(())
366     /// # }
367     /// ```
368     pub fn num_components(&self) -> usize {
369         let mut ncomp = 0;
370         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
371         usize::try_from(ncomp).unwrap()
372     }
373 
374     /// Returns total number of nodes (in dim dimensions) of a Basis
375     ///
376     /// ```
377     /// # use libceed::prelude::*;
378     /// # fn main() -> libceed::Result<()> {
379     /// # let ceed = libceed::Ceed::default_init();
380     /// let p = 3;
381     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
382     ///
383     /// let nnodes = b.num_nodes();
384     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
385     /// # Ok(())
386     /// # }
387     /// ```
388     pub fn num_nodes(&self) -> usize {
389         let mut nnodes = 0;
390         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
391         usize::try_from(nnodes).unwrap()
392     }
393 
394     /// Returns total number of quadrature points (in dim dimensions) of a
395     /// Basis
396     ///
397     /// ```
398     /// # use libceed::prelude::*;
399     /// # fn main() -> libceed::Result<()> {
400     /// # let ceed = libceed::Ceed::default_init();
401     /// let q = 4;
402     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
403     ///
404     /// let nqpts = b.num_quadrature_points();
405     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
406     /// # Ok(())
407     /// # }
408     /// ```
409     pub fn num_quadrature_points(&self) -> usize {
410         let mut Q = 0;
411         unsafe {
412             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
413         }
414         usize::try_from(Q).unwrap()
415     }
416 
417     /// Create projection from self to specified Basis.
418     ///
419     /// Both bases must have the same quadrature space. The input bases need not
420     /// be nested as function spaces; this interface solves a least squares
421     /// probles to find a representation in the `to` basis that agrees at
422     /// quadrature points with the origin basis. Since the bases need not be
423     /// Lagrange, the resulting projection "basis" will have empty quadrature
424     /// points and weights.
425     ///
426     /// ```
427     /// # use libceed::prelude::*;
428     /// # fn main() -> libceed::Result<()> {
429     /// # let ceed = libceed::Ceed::default_init();
430     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
431     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
432     /// let proj = coarse.create_projection(&fine)?;
433     /// let u = ceed.vector_from_slice(&[1., 2.])?;
434     /// let mut v = ceed.vector(3)?;
435     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
436     /// let expected = [1., 1.5, 2.];
437     /// for (a, b) in v.view()?.iter().zip(expected) {
438     ///     assert!(
439     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
440     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
441     ///     );
442     /// }
443     /// # Ok(())
444     /// # }
445     /// ```
446     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
447         let mut ptr = std::ptr::null_mut();
448         let ierr = unsafe { bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr) };
449         self.check_error(ierr)?;
450         Ok(Self {
451             ptr,
452             _lifeline: PhantomData,
453         })
454     }
455 }
456 
457 // -----------------------------------------------------------------------------
458