xref: /libCEED/rust/libceed/src/basis.rs (revision de84fe537ceda4ab4dffb70159a3c31945f74235)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
/// Optional reference to a [`Basis`]
///
/// `BasisOpt::None` is translated by `to_raw` into libCEED's
/// `CEED_BASIS_NONE` value.
#[derive(Debug)]
pub enum BasisOpt<'a> {
    /// A borrowed Basis
    Some(&'a Basis<'a>),
    /// No basis; maps to `CEED_BASIS_NONE`
    None,
}
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_NONE });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::None => unsafe { bind_ceed::CEED_BASIS_NONE },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::prelude::*;
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::None;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::None => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is None
60     ///
61     /// ```
62     /// # use libceed::prelude::*;
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_none(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::None;
70     /// assert!(b_opt.is_none(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_none(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::None => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
/// Rust wrapper around a libCEED `CeedBasis` handle
///
/// The handle is destroyed with `CeedBasisDestroy` when the wrapper is
/// dropped (unless it is `CEED_BASIS_NONE`).
#[derive(Debug)]
pub struct Basis<'a> {
    /// Raw libCEED basis handle
    pub(crate) ptr: bind_ceed::CeedBasis,
    // Zero-sized marker carrying the `'a` lifetime parameter.
    _lifeline: PhantomData<&'a ()>,
}
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_NONE {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::prelude::*;
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     pub fn create_tensor_H1(
138         ceed: &crate::Ceed,
139         dim: usize,
140         ncomp: usize,
141         P1d: usize,
142         Q1d: usize,
143         interp1d: &[crate::Scalar],
144         grad1d: &[crate::Scalar],
145         qref1d: &[crate::Scalar],
146         qweight1d: &[crate::Scalar],
147     ) -> crate::Result<Self> {
148         let mut ptr = std::ptr::null_mut();
149         let (dim, ncomp, P1d, Q1d) = (
150             i32::try_from(dim).unwrap(),
151             i32::try_from(ncomp).unwrap(),
152             i32::try_from(P1d).unwrap(),
153             i32::try_from(Q1d).unwrap(),
154         );
155         ceed.check_error(unsafe {
156             bind_ceed::CeedBasisCreateTensorH1(
157                 ceed.ptr,
158                 dim,
159                 ncomp,
160                 P1d,
161                 Q1d,
162                 interp1d.as_ptr(),
163                 grad1d.as_ptr(),
164                 qref1d.as_ptr(),
165                 qweight1d.as_ptr(),
166                 &mut ptr,
167             )
168         })?;
169         Ok(Self {
170             ptr,
171             _lifeline: PhantomData,
172         })
173     }
174 
175     pub(crate) fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result<Self> {
176         Ok(Self {
177             ptr,
178             _lifeline: PhantomData,
179         })
180     }
181 
182     pub fn create_tensor_H1_Lagrange(
183         ceed: &crate::Ceed,
184         dim: usize,
185         ncomp: usize,
186         P: usize,
187         Q: usize,
188         qmode: crate::QuadMode,
189     ) -> crate::Result<Self> {
190         let mut ptr = std::ptr::null_mut();
191         let (dim, ncomp, P, Q, qmode) = (
192             i32::try_from(dim).unwrap(),
193             i32::try_from(ncomp).unwrap(),
194             i32::try_from(P).unwrap(),
195             i32::try_from(Q).unwrap(),
196             qmode as bind_ceed::CeedQuadMode,
197         );
198         ceed.check_error(unsafe {
199             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
200         })?;
201         Ok(Self {
202             ptr,
203             _lifeline: PhantomData,
204         })
205     }
206 
207     pub fn create_H1(
208         ceed: &crate::Ceed,
209         topo: crate::ElemTopology,
210         ncomp: usize,
211         nnodes: usize,
212         nqpts: usize,
213         interp: &[crate::Scalar],
214         grad: &[crate::Scalar],
215         qref: &[crate::Scalar],
216         qweight: &[crate::Scalar],
217     ) -> crate::Result<Self> {
218         let mut ptr = std::ptr::null_mut();
219         let (topo, ncomp, nnodes, nqpts) = (
220             topo as bind_ceed::CeedElemTopology,
221             i32::try_from(ncomp).unwrap(),
222             i32::try_from(nnodes).unwrap(),
223             i32::try_from(nqpts).unwrap(),
224         );
225         ceed.check_error(unsafe {
226             bind_ceed::CeedBasisCreateH1(
227                 ceed.ptr,
228                 topo,
229                 ncomp,
230                 nnodes,
231                 nqpts,
232                 interp.as_ptr(),
233                 grad.as_ptr(),
234                 qref.as_ptr(),
235                 qweight.as_ptr(),
236                 &mut ptr,
237             )
238         })?;
239         Ok(Self {
240             ptr,
241             _lifeline: PhantomData,
242         })
243     }
244 
245     pub fn create_Hdiv(
246         ceed: &crate::Ceed,
247         topo: crate::ElemTopology,
248         ncomp: usize,
249         nnodes: usize,
250         nqpts: usize,
251         interp: &[crate::Scalar],
252         div: &[crate::Scalar],
253         qref: &[crate::Scalar],
254         qweight: &[crate::Scalar],
255     ) -> crate::Result<Self> {
256         let mut ptr = std::ptr::null_mut();
257         let (topo, ncomp, nnodes, nqpts) = (
258             topo as bind_ceed::CeedElemTopology,
259             i32::try_from(ncomp).unwrap(),
260             i32::try_from(nnodes).unwrap(),
261             i32::try_from(nqpts).unwrap(),
262         );
263         ceed.check_error(unsafe {
264             bind_ceed::CeedBasisCreateHdiv(
265                 ceed.ptr,
266                 topo,
267                 ncomp,
268                 nnodes,
269                 nqpts,
270                 interp.as_ptr(),
271                 div.as_ptr(),
272                 qref.as_ptr(),
273                 qweight.as_ptr(),
274                 &mut ptr,
275             )
276         })?;
277         Ok(Self {
278             ptr,
279             _lifeline: PhantomData,
280         })
281     }
282 
283     pub fn create_Hcurl(
284         ceed: &crate::Ceed,
285         topo: crate::ElemTopology,
286         ncomp: usize,
287         nnodes: usize,
288         nqpts: usize,
289         interp: &[crate::Scalar],
290         curl: &[crate::Scalar],
291         qref: &[crate::Scalar],
292         qweight: &[crate::Scalar],
293     ) -> crate::Result<Self> {
294         let mut ptr = std::ptr::null_mut();
295         let (topo, ncomp, nnodes, nqpts) = (
296             topo as bind_ceed::CeedElemTopology,
297             i32::try_from(ncomp).unwrap(),
298             i32::try_from(nnodes).unwrap(),
299             i32::try_from(nqpts).unwrap(),
300         );
301         ceed.check_error(unsafe {
302             bind_ceed::CeedBasisCreateHcurl(
303                 ceed.ptr,
304                 topo,
305                 ncomp,
306                 nnodes,
307                 nqpts,
308                 interp.as_ptr(),
309                 curl.as_ptr(),
310                 qref.as_ptr(),
311                 qweight.as_ptr(),
312                 &mut ptr,
313             )
314         })?;
315         Ok(Self {
316             ptr,
317             _lifeline: PhantomData,
318         })
319     }
320 
321     // Raw Ceed for error handling
322     #[doc(hidden)]
323     fn ceed(&self) -> bind_ceed::Ceed {
324         unsafe { bind_ceed::CeedBasisReturnCeed(self.ptr) }
325     }
326 
327     // Error handling
328     #[doc(hidden)]
329     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
330         crate::check_error(|| self.ceed(), ierr)
331     }
332 
333     /// Apply basis evaluation from nodes to quadrature points or vice versa
334     ///
335     /// * `nelem` - The number of elements to apply the basis evaluation to
336     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
337     ///               quadrature points, `TransposeMode::Transpose` to apply the
338     ///               transpose, mapping from quadrature points to nodes
339     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
340     ///               to use interpolated values, `EvalMode::Grad` to use
341     ///               gradients, `EvalMode::Weight` to use quadrature weights
342     /// * `u`     - Input Vector
343     /// * `v`     - Output Vector
344     ///
345     /// ```
346     /// # use libceed::prelude::*;
347     /// # fn main() -> libceed::Result<()> {
348     /// # let ceed = libceed::Ceed::default_init();
349     /// const Q: usize = 6;
350     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
351     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
352     ///
353     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
354     /// let mut x_qpts = ceed.vector(Q)?;
355     /// let mut x_nodes = ceed.vector(Q)?;
356     /// bx.apply(
357     ///     1,
358     ///     TransposeMode::NoTranspose,
359     ///     EvalMode::Interp,
360     ///     &x_corners,
361     ///     &mut x_nodes,
362     /// )?;
363     /// bu.apply(
364     ///     1,
365     ///     TransposeMode::NoTranspose,
366     ///     EvalMode::Interp,
367     ///     &x_nodes,
368     ///     &mut x_qpts,
369     /// )?;
370     ///
371     /// // Create function x^3 + 1 on Gauss Lobatto points
372     /// let mut u_arr = [0.; Q];
373     /// u_arr
374     ///     .iter_mut()
375     ///     .zip(x_nodes.view()?.iter())
376     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
377     /// let u = ceed.vector_from_slice(&u_arr)?;
378     ///
379     /// // Map function to Gauss points
380     /// let mut v = ceed.vector(Q)?;
381     /// v.set_value(0.);
382     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
383     ///
384     /// // Verify results
385     /// v.view()?
386     ///     .iter()
387     ///     .zip(x_qpts.view()?.iter())
388     ///     .for_each(|(v, x)| {
389     ///         let true_value = x * x * x + 1.;
390     ///         assert!(
391     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
392     ///             "Incorrect basis application"
393     ///         );
394     ///     });
395     /// # Ok(())
396     /// # }
397     /// ```
398     pub fn apply(
399         &self,
400         nelem: usize,
401         tmode: TransposeMode,
402         emode: EvalMode,
403         u: &Vector,
404         v: &mut Vector,
405     ) -> crate::Result<i32> {
406         let (nelem, tmode, emode) = (
407             i32::try_from(nelem).unwrap(),
408             tmode as bind_ceed::CeedTransposeMode,
409             emode as bind_ceed::CeedEvalMode,
410         );
411         self.check_error(unsafe {
412             bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr)
413         })
414     }
415 
416     /// Returns the dimension for given Basis
417     ///
418     /// ```
419     /// # use libceed::prelude::*;
420     /// # fn main() -> libceed::Result<()> {
421     /// # let ceed = libceed::Ceed::default_init();
422     /// let dim = 2;
423     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
424     ///
425     /// let d = b.dimension();
426     /// assert_eq!(d, dim, "Incorrect dimension");
427     /// # Ok(())
428     /// # }
429     /// ```
430     pub fn dimension(&self) -> usize {
431         let mut dim = 0;
432         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
433         usize::try_from(dim).unwrap()
434     }
435 
436     /// Returns number of components for given Basis
437     ///
438     /// ```
439     /// # use libceed::prelude::*;
440     /// # fn main() -> libceed::Result<()> {
441     /// # let ceed = libceed::Ceed::default_init();
442     /// let ncomp = 2;
443     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
444     ///
445     /// let n = b.num_components();
446     /// assert_eq!(n, ncomp, "Incorrect number of components");
447     /// # Ok(())
448     /// # }
449     /// ```
450     pub fn num_components(&self) -> usize {
451         let mut ncomp = 0;
452         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
453         usize::try_from(ncomp).unwrap()
454     }
455 
456     /// Returns total number of nodes (in dim dimensions) of a Basis
457     ///
458     /// ```
459     /// # use libceed::prelude::*;
460     /// # fn main() -> libceed::Result<()> {
461     /// # let ceed = libceed::Ceed::default_init();
462     /// let p = 3;
463     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
464     ///
465     /// let nnodes = b.num_nodes();
466     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
467     /// # Ok(())
468     /// # }
469     /// ```
470     pub fn num_nodes(&self) -> usize {
471         let mut nnodes = 0;
472         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
473         usize::try_from(nnodes).unwrap()
474     }
475 
476     /// Returns total number of quadrature points (in dim dimensions) of a
477     /// Basis
478     ///
479     /// ```
480     /// # use libceed::prelude::*;
481     /// # fn main() -> libceed::Result<()> {
482     /// # let ceed = libceed::Ceed::default_init();
483     /// let q = 4;
484     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
485     ///
486     /// let nqpts = b.num_quadrature_points();
487     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
488     /// # Ok(())
489     /// # }
490     /// ```
491     pub fn num_quadrature_points(&self) -> usize {
492         let mut Q = 0;
493         unsafe {
494             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
495         }
496         usize::try_from(Q).unwrap()
497     }
498 
499     /// Create projection from self to specified Basis.
500     ///
501     /// Both bases must have the same quadrature space. The input bases need not
502     /// be nested as function spaces; this interface solves a least squares
503     /// problem to find a representation in the `to` basis that agrees at
504     /// quadrature points with the origin basis. Since the bases need not be
505     /// Lagrange, the resulting projection "basis" will have empty quadrature
506     /// points and weights.
507     ///
508     /// ```
509     /// # use libceed::prelude::*;
510     /// # fn main() -> libceed::Result<()> {
511     /// # let ceed = libceed::Ceed::default_init();
512     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
513     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
514     /// let proj = coarse.create_projection(&fine)?;
515     /// let u = ceed.vector_from_slice(&[1., 2.])?;
516     /// let mut v = ceed.vector(3)?;
517     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
518     /// let expected = [1., 1.5, 2.];
519     /// for (a, b) in v.view()?.iter().zip(expected) {
520     ///     assert!(
521     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
522     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
523     ///     );
524     /// }
525     /// # Ok(())
526     /// # }
527     /// ```
528     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
529         let mut ptr = std::ptr::null_mut();
530         self.check_error(unsafe {
531             bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr)
532         })?;
533         Ok(Self {
534             ptr,
535             _lifeline: PhantomData,
536         })
537     }
538 }
539 
540 // -----------------------------------------------------------------------------
541