xref: /libCEED/rust/libceed/src/basis.rs (revision 85938a6d1dd5e68e6deadca612b182f7422c5a77)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
16 #[derive(Debug)]
17 pub enum BasisOpt<'a> {
18     Some(&'a Basis<'a>),
19     None,
20 }
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_NONE });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::None => unsafe { bind_ceed::CEED_BASIS_NONE },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::prelude::*;
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::None;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::None => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is None
60     ///
61     /// ```
62     /// # use libceed::prelude::*;
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_none(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::None;
70     /// assert!(b_opt.is_none(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_none(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::None => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
85 #[derive(Debug)]
86 pub struct Basis<'a> {
87     pub(crate) ptr: bind_ceed::CeedBasis,
88     _lifeline: PhantomData<&'a ()>,
89 }
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_NONE {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::prelude::*;
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     pub fn create_tensor_H1(
138         ceed: &crate::Ceed,
139         dim: usize,
140         ncomp: usize,
141         P1d: usize,
142         Q1d: usize,
143         interp1d: &[crate::Scalar],
144         grad1d: &[crate::Scalar],
145         qref1d: &[crate::Scalar],
146         qweight1d: &[crate::Scalar],
147     ) -> crate::Result<Self> {
148         let mut ptr = std::ptr::null_mut();
149         let (dim, ncomp, P1d, Q1d) = (
150             i32::try_from(dim).unwrap(),
151             i32::try_from(ncomp).unwrap(),
152             i32::try_from(P1d).unwrap(),
153             i32::try_from(Q1d).unwrap(),
154         );
155         let ierr = unsafe {
156             bind_ceed::CeedBasisCreateTensorH1(
157                 ceed.ptr,
158                 dim,
159                 ncomp,
160                 P1d,
161                 Q1d,
162                 interp1d.as_ptr(),
163                 grad1d.as_ptr(),
164                 qref1d.as_ptr(),
165                 qweight1d.as_ptr(),
166                 &mut ptr,
167             )
168         };
169         ceed.check_error(ierr)?;
170         Ok(Self {
171             ptr,
172             _lifeline: PhantomData,
173         })
174     }
175 
176     pub(crate) fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result<Self> {
177         Ok(Self {
178             ptr,
179             _lifeline: PhantomData,
180         })
181     }
182 
183     pub fn create_tensor_H1_Lagrange(
184         ceed: &crate::Ceed,
185         dim: usize,
186         ncomp: usize,
187         P: usize,
188         Q: usize,
189         qmode: crate::QuadMode,
190     ) -> crate::Result<Self> {
191         let mut ptr = std::ptr::null_mut();
192         let (dim, ncomp, P, Q, qmode) = (
193             i32::try_from(dim).unwrap(),
194             i32::try_from(ncomp).unwrap(),
195             i32::try_from(P).unwrap(),
196             i32::try_from(Q).unwrap(),
197             qmode as bind_ceed::CeedQuadMode,
198         );
199         let ierr = unsafe {
200             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
201         };
202         ceed.check_error(ierr)?;
203         Ok(Self {
204             ptr,
205             _lifeline: PhantomData,
206         })
207     }
208 
209     pub fn create_H1(
210         ceed: &crate::Ceed,
211         topo: crate::ElemTopology,
212         ncomp: usize,
213         nnodes: usize,
214         nqpts: usize,
215         interp: &[crate::Scalar],
216         grad: &[crate::Scalar],
217         qref: &[crate::Scalar],
218         qweight: &[crate::Scalar],
219     ) -> crate::Result<Self> {
220         let mut ptr = std::ptr::null_mut();
221         let (topo, ncomp, nnodes, nqpts) = (
222             topo as bind_ceed::CeedElemTopology,
223             i32::try_from(ncomp).unwrap(),
224             i32::try_from(nnodes).unwrap(),
225             i32::try_from(nqpts).unwrap(),
226         );
227         let ierr = unsafe {
228             bind_ceed::CeedBasisCreateH1(
229                 ceed.ptr,
230                 topo,
231                 ncomp,
232                 nnodes,
233                 nqpts,
234                 interp.as_ptr(),
235                 grad.as_ptr(),
236                 qref.as_ptr(),
237                 qweight.as_ptr(),
238                 &mut ptr,
239             )
240         };
241         ceed.check_error(ierr)?;
242         Ok(Self {
243             ptr,
244             _lifeline: PhantomData,
245         })
246     }
247 
248     pub fn create_Hdiv(
249         ceed: &crate::Ceed,
250         topo: crate::ElemTopology,
251         ncomp: usize,
252         nnodes: usize,
253         nqpts: usize,
254         interp: &[crate::Scalar],
255         div: &[crate::Scalar],
256         qref: &[crate::Scalar],
257         qweight: &[crate::Scalar],
258     ) -> crate::Result<Self> {
259         let mut ptr = std::ptr::null_mut();
260         let (topo, ncomp, nnodes, nqpts) = (
261             topo as bind_ceed::CeedElemTopology,
262             i32::try_from(ncomp).unwrap(),
263             i32::try_from(nnodes).unwrap(),
264             i32::try_from(nqpts).unwrap(),
265         );
266         let ierr = unsafe {
267             bind_ceed::CeedBasisCreateHdiv(
268                 ceed.ptr,
269                 topo,
270                 ncomp,
271                 nnodes,
272                 nqpts,
273                 interp.as_ptr(),
274                 div.as_ptr(),
275                 qref.as_ptr(),
276                 qweight.as_ptr(),
277                 &mut ptr,
278             )
279         };
280         ceed.check_error(ierr)?;
281         Ok(Self {
282             ptr,
283             _lifeline: PhantomData,
284         })
285     }
286 
287     pub fn create_Hcurl(
288         ceed: &crate::Ceed,
289         topo: crate::ElemTopology,
290         ncomp: usize,
291         nnodes: usize,
292         nqpts: usize,
293         interp: &[crate::Scalar],
294         curl: &[crate::Scalar],
295         qref: &[crate::Scalar],
296         qweight: &[crate::Scalar],
297     ) -> crate::Result<Self> {
298         let mut ptr = std::ptr::null_mut();
299         let (topo, ncomp, nnodes, nqpts) = (
300             topo as bind_ceed::CeedElemTopology,
301             i32::try_from(ncomp).unwrap(),
302             i32::try_from(nnodes).unwrap(),
303             i32::try_from(nqpts).unwrap(),
304         );
305         let ierr = unsafe {
306             bind_ceed::CeedBasisCreateHcurl(
307                 ceed.ptr,
308                 topo,
309                 ncomp,
310                 nnodes,
311                 nqpts,
312                 interp.as_ptr(),
313                 curl.as_ptr(),
314                 qref.as_ptr(),
315                 qweight.as_ptr(),
316                 &mut ptr,
317             )
318         };
319         ceed.check_error(ierr)?;
320         Ok(Self {
321             ptr,
322             _lifeline: PhantomData,
323         })
324     }
325 
326     // Error handling
327     #[doc(hidden)]
328     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
329         let mut ptr = std::ptr::null_mut();
330         unsafe {
331             bind_ceed::CeedBasisGetCeed(self.ptr, &mut ptr);
332         }
333         crate::check_error(ptr, ierr)
334     }
335 
336     /// Apply basis evaluation from nodes to quadrature points or vice versa
337     ///
338     /// * `nelem` - The number of elements to apply the basis evaluation to
339     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
340     ///               quadrature points, `TransposeMode::Transpose` to apply the
341     ///               transpose, mapping from quadrature points to nodes
342     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
343     ///               to use interpolated values, `EvalMode::Grad` to use
344     ///               gradients, `EvalMode::Weight` to use quadrature weights
345     /// * `u`     - Input Vector
346     /// * `v`     - Output Vector
347     ///
348     /// ```
349     /// # use libceed::prelude::*;
350     /// # fn main() -> libceed::Result<()> {
351     /// # let ceed = libceed::Ceed::default_init();
352     /// const Q: usize = 6;
353     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
354     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
355     ///
356     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
357     /// let mut x_qpts = ceed.vector(Q)?;
358     /// let mut x_nodes = ceed.vector(Q)?;
359     /// bx.apply(
360     ///     1,
361     ///     TransposeMode::NoTranspose,
362     ///     EvalMode::Interp,
363     ///     &x_corners,
364     ///     &mut x_nodes,
365     /// )?;
366     /// bu.apply(
367     ///     1,
368     ///     TransposeMode::NoTranspose,
369     ///     EvalMode::Interp,
370     ///     &x_nodes,
371     ///     &mut x_qpts,
372     /// )?;
373     ///
374     /// // Create function x^3 + 1 on Gauss Lobatto points
375     /// let mut u_arr = [0.; Q];
376     /// u_arr
377     ///     .iter_mut()
378     ///     .zip(x_nodes.view()?.iter())
379     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
380     /// let u = ceed.vector_from_slice(&u_arr)?;
381     ///
382     /// // Map function to Gauss points
383     /// let mut v = ceed.vector(Q)?;
384     /// v.set_value(0.);
385     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
386     ///
387     /// // Verify results
388     /// v.view()?
389     ///     .iter()
390     ///     .zip(x_qpts.view()?.iter())
391     ///     .for_each(|(v, x)| {
392     ///         let true_value = x * x * x + 1.;
393     ///         assert!(
394     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
395     ///             "Incorrect basis application"
396     ///         );
397     ///     });
398     /// # Ok(())
399     /// # }
400     /// ```
401     pub fn apply(
402         &self,
403         nelem: usize,
404         tmode: TransposeMode,
405         emode: EvalMode,
406         u: &Vector,
407         v: &mut Vector,
408     ) -> crate::Result<i32> {
409         let (nelem, tmode, emode) = (
410             i32::try_from(nelem).unwrap(),
411             tmode as bind_ceed::CeedTransposeMode,
412             emode as bind_ceed::CeedEvalMode,
413         );
414         let ierr =
415             unsafe { bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr) };
416         self.check_error(ierr)
417     }
418 
419     /// Returns the dimension for given Basis
420     ///
421     /// ```
422     /// # use libceed::prelude::*;
423     /// # fn main() -> libceed::Result<()> {
424     /// # let ceed = libceed::Ceed::default_init();
425     /// let dim = 2;
426     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
427     ///
428     /// let d = b.dimension();
429     /// assert_eq!(d, dim, "Incorrect dimension");
430     /// # Ok(())
431     /// # }
432     /// ```
433     pub fn dimension(&self) -> usize {
434         let mut dim = 0;
435         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
436         usize::try_from(dim).unwrap()
437     }
438 
439     /// Returns number of components for given Basis
440     ///
441     /// ```
442     /// # use libceed::prelude::*;
443     /// # fn main() -> libceed::Result<()> {
444     /// # let ceed = libceed::Ceed::default_init();
445     /// let ncomp = 2;
446     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
447     ///
448     /// let n = b.num_components();
449     /// assert_eq!(n, ncomp, "Incorrect number of components");
450     /// # Ok(())
451     /// # }
452     /// ```
453     pub fn num_components(&self) -> usize {
454         let mut ncomp = 0;
455         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
456         usize::try_from(ncomp).unwrap()
457     }
458 
459     /// Returns total number of nodes (in dim dimensions) of a Basis
460     ///
461     /// ```
462     /// # use libceed::prelude::*;
463     /// # fn main() -> libceed::Result<()> {
464     /// # let ceed = libceed::Ceed::default_init();
465     /// let p = 3;
466     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
467     ///
468     /// let nnodes = b.num_nodes();
469     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
470     /// # Ok(())
471     /// # }
472     /// ```
473     pub fn num_nodes(&self) -> usize {
474         let mut nnodes = 0;
475         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
476         usize::try_from(nnodes).unwrap()
477     }
478 
479     /// Returns total number of quadrature points (in dim dimensions) of a
480     /// Basis
481     ///
482     /// ```
483     /// # use libceed::prelude::*;
484     /// # fn main() -> libceed::Result<()> {
485     /// # let ceed = libceed::Ceed::default_init();
486     /// let q = 4;
487     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
488     ///
489     /// let nqpts = b.num_quadrature_points();
490     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
491     /// # Ok(())
492     /// # }
493     /// ```
494     pub fn num_quadrature_points(&self) -> usize {
495         let mut Q = 0;
496         unsafe {
497             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
498         }
499         usize::try_from(Q).unwrap()
500     }
501 
502     /// Create projection from self to specified Basis.
503     ///
504     /// Both bases must have the same quadrature space. The input bases need not
505     /// be nested as function spaces; this interface solves a least squares
506     /// problem to find a representation in the `to` basis that agrees at
507     /// quadrature points with the origin basis. Since the bases need not be
508     /// Lagrange, the resulting projection "basis" will have empty quadrature
509     /// points and weights.
510     ///
511     /// ```
512     /// # use libceed::prelude::*;
513     /// # fn main() -> libceed::Result<()> {
514     /// # let ceed = libceed::Ceed::default_init();
515     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
516     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
517     /// let proj = coarse.create_projection(&fine)?;
518     /// let u = ceed.vector_from_slice(&[1., 2.])?;
519     /// let mut v = ceed.vector(3)?;
520     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
521     /// let expected = [1., 1.5, 2.];
522     /// for (a, b) in v.view()?.iter().zip(expected) {
523     ///     assert!(
524     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
525     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
526     ///     );
527     /// }
528     /// # Ok(())
529     /// # }
530     /// ```
531     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
532         let mut ptr = std::ptr::null_mut();
533         let ierr = unsafe { bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr) };
534         self.check_error(ierr)?;
535         Ok(Self {
536             ptr,
537             _lifeline: PhantomData,
538         })
539     }
540 }
541 
542 // -----------------------------------------------------------------------------
543