xref: /libCEED/rust/libceed/src/basis.rs (revision a3b195ef6dd39c849072dd5df2f934c50a4df099)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::{prelude::*, vector::Vector, EvalMode, TransposeMode};
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
16 #[derive(Debug)]
17 pub enum BasisOpt<'a> {
18     Some(&'a Basis<'a>),
19     None,
20 }
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_NONE });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(&self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::None => unsafe { bind_ceed::CEED_BASIS_NONE },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::{prelude::*, BasisOpt, QuadMode};
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::None;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::None => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is None
60     ///
61     /// ```
62     /// # use libceed::{prelude::*, BasisOpt, QuadMode};
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_none(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::None;
70     /// assert!(b_opt.is_none(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_none(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::None => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
85 #[derive(Debug)]
86 pub struct Basis<'a> {
87     pub(crate) ptr: bind_ceed::CeedBasis,
88     _lifeline: PhantomData<&'a ()>,
89 }
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_NONE {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::{prelude::*, QuadMode};
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     #[allow(clippy::too_many_arguments)]
138     pub fn create_tensor_H1(
139         ceed: &crate::Ceed,
140         dim: usize,
141         ncomp: usize,
142         P1d: usize,
143         Q1d: usize,
144         interp1d: &[crate::Scalar],
145         grad1d: &[crate::Scalar],
146         qref1d: &[crate::Scalar],
147         qweight1d: &[crate::Scalar],
148     ) -> crate::Result<Self> {
149         let mut ptr = std::ptr::null_mut();
150         let (dim, ncomp, P1d, Q1d) = (
151             i32::try_from(dim).unwrap(),
152             i32::try_from(ncomp).unwrap(),
153             i32::try_from(P1d).unwrap(),
154             i32::try_from(Q1d).unwrap(),
155         );
156         ceed.check_error(unsafe {
157             bind_ceed::CeedBasisCreateTensorH1(
158                 ceed.ptr,
159                 dim,
160                 ncomp,
161                 P1d,
162                 Q1d,
163                 interp1d.as_ptr(),
164                 grad1d.as_ptr(),
165                 qref1d.as_ptr(),
166                 qweight1d.as_ptr(),
167                 &mut ptr,
168             )
169         })?;
170         Ok(Self {
171             ptr,
172             _lifeline: PhantomData,
173         })
174     }
175 
176     pub(crate) unsafe fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result<Self> {
177         Ok(Self {
178             ptr,
179             _lifeline: PhantomData,
180         })
181     }
182 
183     pub fn create_tensor_H1_Lagrange(
184         ceed: &crate::Ceed,
185         dim: usize,
186         ncomp: usize,
187         P: usize,
188         Q: usize,
189         qmode: crate::QuadMode,
190     ) -> crate::Result<Self> {
191         let mut ptr = std::ptr::null_mut();
192         let (dim, ncomp, P, Q, qmode) = (
193             i32::try_from(dim).unwrap(),
194             i32::try_from(ncomp).unwrap(),
195             i32::try_from(P).unwrap(),
196             i32::try_from(Q).unwrap(),
197             qmode as bind_ceed::CeedQuadMode,
198         );
199         ceed.check_error(unsafe {
200             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
201         })?;
202         Ok(Self {
203             ptr,
204             _lifeline: PhantomData,
205         })
206     }
207 
208     #[allow(clippy::too_many_arguments)]
209     pub fn create_H1(
210         ceed: &crate::Ceed,
211         topo: crate::ElemTopology,
212         ncomp: usize,
213         nnodes: usize,
214         nqpts: usize,
215         interp: &[crate::Scalar],
216         grad: &[crate::Scalar],
217         qref: &[crate::Scalar],
218         qweight: &[crate::Scalar],
219     ) -> crate::Result<Self> {
220         let mut ptr = std::ptr::null_mut();
221         let (topo, ncomp, nnodes, nqpts) = (
222             topo as bind_ceed::CeedElemTopology,
223             i32::try_from(ncomp).unwrap(),
224             i32::try_from(nnodes).unwrap(),
225             i32::try_from(nqpts).unwrap(),
226         );
227         ceed.check_error(unsafe {
228             bind_ceed::CeedBasisCreateH1(
229                 ceed.ptr,
230                 topo,
231                 ncomp,
232                 nnodes,
233                 nqpts,
234                 interp.as_ptr(),
235                 grad.as_ptr(),
236                 qref.as_ptr(),
237                 qweight.as_ptr(),
238                 &mut ptr,
239             )
240         })?;
241         Ok(Self {
242             ptr,
243             _lifeline: PhantomData,
244         })
245     }
246 
247     #[allow(clippy::too_many_arguments)]
248     pub fn create_Hdiv(
249         ceed: &crate::Ceed,
250         topo: crate::ElemTopology,
251         ncomp: usize,
252         nnodes: usize,
253         nqpts: usize,
254         interp: &[crate::Scalar],
255         div: &[crate::Scalar],
256         qref: &[crate::Scalar],
257         qweight: &[crate::Scalar],
258     ) -> crate::Result<Self> {
259         let mut ptr = std::ptr::null_mut();
260         let (topo, ncomp, nnodes, nqpts) = (
261             topo as bind_ceed::CeedElemTopology,
262             i32::try_from(ncomp).unwrap(),
263             i32::try_from(nnodes).unwrap(),
264             i32::try_from(nqpts).unwrap(),
265         );
266         ceed.check_error(unsafe {
267             bind_ceed::CeedBasisCreateHdiv(
268                 ceed.ptr,
269                 topo,
270                 ncomp,
271                 nnodes,
272                 nqpts,
273                 interp.as_ptr(),
274                 div.as_ptr(),
275                 qref.as_ptr(),
276                 qweight.as_ptr(),
277                 &mut ptr,
278             )
279         })?;
280         Ok(Self {
281             ptr,
282             _lifeline: PhantomData,
283         })
284     }
285 
286     #[allow(clippy::too_many_arguments)]
287     pub fn create_Hcurl(
288         ceed: &crate::Ceed,
289         topo: crate::ElemTopology,
290         ncomp: usize,
291         nnodes: usize,
292         nqpts: usize,
293         interp: &[crate::Scalar],
294         curl: &[crate::Scalar],
295         qref: &[crate::Scalar],
296         qweight: &[crate::Scalar],
297     ) -> crate::Result<Self> {
298         let mut ptr = std::ptr::null_mut();
299         let (topo, ncomp, nnodes, nqpts) = (
300             topo as bind_ceed::CeedElemTopology,
301             i32::try_from(ncomp).unwrap(),
302             i32::try_from(nnodes).unwrap(),
303             i32::try_from(nqpts).unwrap(),
304         );
305         ceed.check_error(unsafe {
306             bind_ceed::CeedBasisCreateHcurl(
307                 ceed.ptr,
308                 topo,
309                 ncomp,
310                 nnodes,
311                 nqpts,
312                 interp.as_ptr(),
313                 curl.as_ptr(),
314                 qref.as_ptr(),
315                 qweight.as_ptr(),
316                 &mut ptr,
317             )
318         })?;
319         Ok(Self {
320             ptr,
321             _lifeline: PhantomData,
322         })
323     }
324 
325     // Raw Ceed for error handling
326     #[doc(hidden)]
327     fn ceed(&self) -> bind_ceed::Ceed {
328         unsafe { bind_ceed::CeedBasisReturnCeed(self.ptr) }
329     }
330 
331     // Error handling
332     #[doc(hidden)]
333     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
334         crate::check_error(|| self.ceed(), ierr)
335     }
336 
337     /// Apply basis evaluation from nodes to quadrature points or vice versa
338     ///
339     /// * `nelem` - The number of elements to apply the basis evaluation to
340     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
341     ///               quadrature points, `TransposeMode::Transpose` to apply the
342     ///               transpose, mapping from quadrature points to nodes
343     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
344     ///               to use interpolated values, `EvalMode::Grad` to use
345     ///               gradients, `EvalMode::Weight` to use quadrature weights
346     /// * `u`     - Input Vector
347     /// * `v`     - Output Vector
348     ///
349     /// ```
350     /// # use libceed::{prelude::*, EvalMode, TransposeMode, QuadMode};
351     /// # fn main() -> libceed::Result<()> {
352     /// # let ceed = libceed::Ceed::default_init();
353     /// const Q: usize = 6;
354     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
355     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
356     ///
357     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
358     /// let mut x_qpts = ceed.vector(Q)?;
359     /// let mut x_nodes = ceed.vector(Q)?;
360     /// bx.apply(
361     ///     1,
362     ///     TransposeMode::NoTranspose,
363     ///     EvalMode::Interp,
364     ///     &x_corners,
365     ///     &mut x_nodes,
366     /// )?;
367     /// bu.apply(
368     ///     1,
369     ///     TransposeMode::NoTranspose,
370     ///     EvalMode::Interp,
371     ///     &x_nodes,
372     ///     &mut x_qpts,
373     /// )?;
374     ///
375     /// // Create function x^3 + 1 on Gauss Lobatto points
376     /// let mut u_arr = [0.; Q];
377     /// u_arr
378     ///     .iter_mut()
379     ///     .zip(x_nodes.view()?.iter())
380     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
381     /// let u = ceed.vector_from_slice(&u_arr)?;
382     ///
383     /// // Map function to Gauss points
384     /// let mut v = ceed.vector(Q)?;
385     /// v.set_value(0.);
386     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
387     ///
388     /// // Verify results
389     /// v.view()?
390     ///     .iter()
391     ///     .zip(x_qpts.view()?.iter())
392     ///     .for_each(|(v, x)| {
393     ///         let true_value = x * x * x + 1.;
394     ///         assert!(
395     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
396     ///             "Incorrect basis application"
397     ///         );
398     ///     });
399     /// # Ok(())
400     /// # }
401     /// ```
402     pub fn apply(
403         &self,
404         nelem: usize,
405         tmode: TransposeMode,
406         emode: EvalMode,
407         u: &Vector,
408         v: &mut Vector,
409     ) -> crate::Result<i32> {
410         let (nelem, tmode, emode) = (
411             i32::try_from(nelem).unwrap(),
412             tmode as bind_ceed::CeedTransposeMode,
413             emode as bind_ceed::CeedEvalMode,
414         );
415         self.check_error(unsafe {
416             bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr)
417         })
418     }
419 
420     /// Returns the dimension for given Basis
421     ///
422     /// ```
423     /// # use libceed::{prelude::*, QuadMode};
424     /// # fn main() -> libceed::Result<()> {
425     /// # let ceed = libceed::Ceed::default_init();
426     /// let dim = 2;
427     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
428     ///
429     /// let d = b.dimension();
430     /// assert_eq!(d, dim, "Incorrect dimension");
431     /// # Ok(())
432     /// # }
433     /// ```
434     pub fn dimension(&self) -> usize {
435         let mut dim = 0;
436         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
437         usize::try_from(dim).unwrap()
438     }
439 
440     /// Returns number of components for given Basis
441     ///
442     /// ```
443     /// # use libceed::{prelude::*, QuadMode};
444     /// # fn main() -> libceed::Result<()> {
445     /// # let ceed = libceed::Ceed::default_init();
446     /// let ncomp = 2;
447     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
448     ///
449     /// let n = b.num_components();
450     /// assert_eq!(n, ncomp, "Incorrect number of components");
451     /// # Ok(())
452     /// # }
453     /// ```
454     pub fn num_components(&self) -> usize {
455         let mut ncomp = 0;
456         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
457         usize::try_from(ncomp).unwrap()
458     }
459 
460     /// Returns total number of nodes (in dim dimensions) of a Basis
461     ///
462     /// ```
463     /// # use libceed::{prelude::*, QuadMode};
464     /// # fn main() -> libceed::Result<()> {
465     /// # let ceed = libceed::Ceed::default_init();
466     /// let p = 3;
467     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
468     ///
469     /// let nnodes = b.num_nodes();
470     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
471     /// # Ok(())
472     /// # }
473     /// ```
474     pub fn num_nodes(&self) -> usize {
475         let mut nnodes = 0;
476         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
477         usize::try_from(nnodes).unwrap()
478     }
479 
480     /// Returns total number of quadrature points (in dim dimensions) of a
481     /// Basis
482     ///
483     /// ```
484     /// # use libceed::{prelude::*, QuadMode};
485     /// # fn main() -> libceed::Result<()> {
486     /// # let ceed = libceed::Ceed::default_init();
487     /// let q = 4;
488     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
489     ///
490     /// let nqpts = b.num_quadrature_points();
491     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
492     /// # Ok(())
493     /// # }
494     /// ```
495     pub fn num_quadrature_points(&self) -> usize {
496         let mut Q = 0;
497         unsafe {
498             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
499         }
500         usize::try_from(Q).unwrap()
501     }
502 
503     /// Create projection from self to specified Basis.
504     ///
505     /// Both bases must have the same quadrature space. The input bases need not
506     /// be nested as function spaces; this interface solves a least squares
507     /// problem to find a representation in the `to` basis that agrees at
508     /// quadrature points with the origin basis. Since the bases need not be
509     /// Lagrange, the resulting projection "basis" will have empty quadrature
510     /// points and weights.
511     ///
512     /// ```
513     /// # use libceed::{prelude::*, EvalMode, TransposeMode, QuadMode};
514     /// # fn main() -> libceed::Result<()> {
515     /// # let ceed = libceed::Ceed::default_init();
516     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
517     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
518     /// let proj = coarse.create_projection(&fine)?;
519     /// let u = ceed.vector_from_slice(&[1., 2.])?;
520     /// let mut v = ceed.vector(3)?;
521     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
522     /// let expected = [1., 1.5, 2.];
523     /// for (a, b) in v.view()?.iter().zip(expected) {
524     ///     assert!(
525     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
526     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
527     ///     );
528     /// }
529     /// # Ok(())
530     /// # }
531     /// ```
532     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
533         let mut ptr = std::ptr::null_mut();
534         self.check_error(unsafe {
535             bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr)
536         })?;
537         Ok(Self {
538             ptr,
539             _lifeline: PhantomData,
540         })
541     }
542 }
543 
544 // -----------------------------------------------------------------------------
545