xref: /libCEED/rust/libceed/src/basis.rs (revision 11544396610b36de1cb2f0d18032eefe5c670568)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
/// Optional reference to a [`Basis`].
///
/// `Some` borrows an existing [`Basis`]; `None` stands for "no basis" and is
/// translated to the C sentinel `CEED_BASIS_NONE` when the option is lowered
/// to a raw handle with `to_raw`.
#[derive(Debug)]
pub enum BasisOpt<'a> {
    /// A borrowed, user-created Basis
    Some(&'a Basis<'a>),
    /// No Basis; maps to `bind_ceed::CEED_BASIS_NONE`
    None,
}
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_NONE });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::None => unsafe { bind_ceed::CEED_BASIS_NONE },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::prelude::*;
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::None;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::None => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is None
60     ///
61     /// ```
62     /// # use libceed::prelude::*;
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_none(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::None;
70     /// assert!(b_opt.is_none(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_none(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::None => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
/// Rust wrapper around a raw libCEED `CeedBasis` handle.
///
/// The wrapped handle is released with `CeedBasisDestroy` when the wrapper is
/// dropped, unless it is the `CEED_BASIS_NONE` sentinel.
#[derive(Debug)]
pub struct Basis<'a> {
    /// Raw C handle passed to the `bind_ceed` basis routines
    pub(crate) ptr: bind_ceed::CeedBasis,
    // Zero-sized marker that gives the wrapper its lifetime parameter `'a`.
    // NOTE(review): presumably this keeps a Basis from outliving the object
    // it was created from - confirm against the constructors' callers.
    _lifeline: PhantomData<&'a ()>,
}
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_NONE {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::prelude::*;
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     pub fn create_tensor_H1(
138         ceed: &crate::Ceed,
139         dim: usize,
140         ncomp: usize,
141         P1d: usize,
142         Q1d: usize,
143         interp1d: &[crate::Scalar],
144         grad1d: &[crate::Scalar],
145         qref1d: &[crate::Scalar],
146         qweight1d: &[crate::Scalar],
147     ) -> crate::Result<Self> {
148         let mut ptr = std::ptr::null_mut();
149         let (dim, ncomp, P1d, Q1d) = (
150             i32::try_from(dim).unwrap(),
151             i32::try_from(ncomp).unwrap(),
152             i32::try_from(P1d).unwrap(),
153             i32::try_from(Q1d).unwrap(),
154         );
155         let ierr = unsafe {
156             bind_ceed::CeedBasisCreateTensorH1(
157                 ceed.ptr,
158                 dim,
159                 ncomp,
160                 P1d,
161                 Q1d,
162                 interp1d.as_ptr(),
163                 grad1d.as_ptr(),
164                 qref1d.as_ptr(),
165                 qweight1d.as_ptr(),
166                 &mut ptr,
167             )
168         };
169         ceed.check_error(ierr)?;
170         Ok(Self {
171             ptr,
172             _lifeline: PhantomData,
173         })
174     }
175 
176     pub(crate) fn from_raw(ptr: bind_ceed::CeedBasis) -> crate::Result<Self> {
177         Ok(Self {
178             ptr,
179             _lifeline: PhantomData,
180         })
181     }
182 
183     pub fn create_tensor_H1_Lagrange(
184         ceed: &crate::Ceed,
185         dim: usize,
186         ncomp: usize,
187         P: usize,
188         Q: usize,
189         qmode: crate::QuadMode,
190     ) -> crate::Result<Self> {
191         let mut ptr = std::ptr::null_mut();
192         let (dim, ncomp, P, Q, qmode) = (
193             i32::try_from(dim).unwrap(),
194             i32::try_from(ncomp).unwrap(),
195             i32::try_from(P).unwrap(),
196             i32::try_from(Q).unwrap(),
197             qmode as bind_ceed::CeedQuadMode,
198         );
199         let ierr = unsafe {
200             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
201         };
202         ceed.check_error(ierr)?;
203         Ok(Self {
204             ptr,
205             _lifeline: PhantomData,
206         })
207     }
208 
209     pub fn create_H1(
210         ceed: &crate::Ceed,
211         topo: crate::ElemTopology,
212         ncomp: usize,
213         nnodes: usize,
214         nqpts: usize,
215         interp: &[crate::Scalar],
216         grad: &[crate::Scalar],
217         qref: &[crate::Scalar],
218         qweight: &[crate::Scalar],
219     ) -> crate::Result<Self> {
220         let mut ptr = std::ptr::null_mut();
221         let (topo, ncomp, nnodes, nqpts) = (
222             topo as bind_ceed::CeedElemTopology,
223             i32::try_from(ncomp).unwrap(),
224             i32::try_from(nnodes).unwrap(),
225             i32::try_from(nqpts).unwrap(),
226         );
227         let ierr = unsafe {
228             bind_ceed::CeedBasisCreateH1(
229                 ceed.ptr,
230                 topo,
231                 ncomp,
232                 nnodes,
233                 nqpts,
234                 interp.as_ptr(),
235                 grad.as_ptr(),
236                 qref.as_ptr(),
237                 qweight.as_ptr(),
238                 &mut ptr,
239             )
240         };
241         ceed.check_error(ierr)?;
242         Ok(Self {
243             ptr,
244             _lifeline: PhantomData,
245         })
246     }
247 
248     pub fn create_Hdiv(
249         ceed: &crate::Ceed,
250         topo: crate::ElemTopology,
251         ncomp: usize,
252         nnodes: usize,
253         nqpts: usize,
254         interp: &[crate::Scalar],
255         div: &[crate::Scalar],
256         qref: &[crate::Scalar],
257         qweight: &[crate::Scalar],
258     ) -> crate::Result<Self> {
259         let mut ptr = std::ptr::null_mut();
260         let (topo, ncomp, nnodes, nqpts) = (
261             topo as bind_ceed::CeedElemTopology,
262             i32::try_from(ncomp).unwrap(),
263             i32::try_from(nnodes).unwrap(),
264             i32::try_from(nqpts).unwrap(),
265         );
266         let ierr = unsafe {
267             bind_ceed::CeedBasisCreateHdiv(
268                 ceed.ptr,
269                 topo,
270                 ncomp,
271                 nnodes,
272                 nqpts,
273                 interp.as_ptr(),
274                 div.as_ptr(),
275                 qref.as_ptr(),
276                 qweight.as_ptr(),
277                 &mut ptr,
278             )
279         };
280         ceed.check_error(ierr)?;
281         Ok(Self {
282             ptr,
283             _lifeline: PhantomData,
284         })
285     }
286 
287     pub fn create_Hcurl(
288         ceed: &crate::Ceed,
289         topo: crate::ElemTopology,
290         ncomp: usize,
291         nnodes: usize,
292         nqpts: usize,
293         interp: &[crate::Scalar],
294         curl: &[crate::Scalar],
295         qref: &[crate::Scalar],
296         qweight: &[crate::Scalar],
297     ) -> crate::Result<Self> {
298         let mut ptr = std::ptr::null_mut();
299         let (topo, ncomp, nnodes, nqpts) = (
300             topo as bind_ceed::CeedElemTopology,
301             i32::try_from(ncomp).unwrap(),
302             i32::try_from(nnodes).unwrap(),
303             i32::try_from(nqpts).unwrap(),
304         );
305         let ierr = unsafe {
306             bind_ceed::CeedBasisCreateHcurl(
307                 ceed.ptr,
308                 topo,
309                 ncomp,
310                 nnodes,
311                 nqpts,
312                 interp.as_ptr(),
313                 curl.as_ptr(),
314                 qref.as_ptr(),
315                 qweight.as_ptr(),
316                 &mut ptr,
317             )
318         };
319         ceed.check_error(ierr)?;
320         Ok(Self {
321             ptr,
322             _lifeline: PhantomData,
323         })
324     }
325 
326     // Raw Ceed for error handling
327     #[doc(hidden)]
328     fn ceed(&self) -> bind_ceed::Ceed {
329         unsafe { bind_ceed::CeedBasisReturnCeed(self.ptr) }
330     }
331 
332     // Error handling
333     #[doc(hidden)]
334     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
335         crate::check_error(|| self.ceed(), ierr)
336     }
337 
338     /// Apply basis evaluation from nodes to quadrature points or vice versa
339     ///
340     /// * `nelem` - The number of elements to apply the basis evaluation to
341     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
342     ///               quadrature points, `TransposeMode::Transpose` to apply the
343     ///               transpose, mapping from quadrature points to nodes
344     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
345     ///               to use interpolated values, `EvalMode::Grad` to use
346     ///               gradients, `EvalMode::Weight` to use quadrature weights
347     /// * `u`     - Input Vector
348     /// * `v`     - Output Vector
349     ///
350     /// ```
351     /// # use libceed::prelude::*;
352     /// # fn main() -> libceed::Result<()> {
353     /// # let ceed = libceed::Ceed::default_init();
354     /// const Q: usize = 6;
355     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
356     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
357     ///
358     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
359     /// let mut x_qpts = ceed.vector(Q)?;
360     /// let mut x_nodes = ceed.vector(Q)?;
361     /// bx.apply(
362     ///     1,
363     ///     TransposeMode::NoTranspose,
364     ///     EvalMode::Interp,
365     ///     &x_corners,
366     ///     &mut x_nodes,
367     /// )?;
368     /// bu.apply(
369     ///     1,
370     ///     TransposeMode::NoTranspose,
371     ///     EvalMode::Interp,
372     ///     &x_nodes,
373     ///     &mut x_qpts,
374     /// )?;
375     ///
376     /// // Create function x^3 + 1 on Gauss Lobatto points
377     /// let mut u_arr = [0.; Q];
378     /// u_arr
379     ///     .iter_mut()
380     ///     .zip(x_nodes.view()?.iter())
381     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
382     /// let u = ceed.vector_from_slice(&u_arr)?;
383     ///
384     /// // Map function to Gauss points
385     /// let mut v = ceed.vector(Q)?;
386     /// v.set_value(0.);
387     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
388     ///
389     /// // Verify results
390     /// v.view()?
391     ///     .iter()
392     ///     .zip(x_qpts.view()?.iter())
393     ///     .for_each(|(v, x)| {
394     ///         let true_value = x * x * x + 1.;
395     ///         assert!(
396     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
397     ///             "Incorrect basis application"
398     ///         );
399     ///     });
400     /// # Ok(())
401     /// # }
402     /// ```
403     pub fn apply(
404         &self,
405         nelem: usize,
406         tmode: TransposeMode,
407         emode: EvalMode,
408         u: &Vector,
409         v: &mut Vector,
410     ) -> crate::Result<i32> {
411         let (nelem, tmode, emode) = (
412             i32::try_from(nelem).unwrap(),
413             tmode as bind_ceed::CeedTransposeMode,
414             emode as bind_ceed::CeedEvalMode,
415         );
416         let ierr =
417             unsafe { bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr) };
418         self.check_error(ierr)
419     }
420 
421     /// Returns the dimension for given Basis
422     ///
423     /// ```
424     /// # use libceed::prelude::*;
425     /// # fn main() -> libceed::Result<()> {
426     /// # let ceed = libceed::Ceed::default_init();
427     /// let dim = 2;
428     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
429     ///
430     /// let d = b.dimension();
431     /// assert_eq!(d, dim, "Incorrect dimension");
432     /// # Ok(())
433     /// # }
434     /// ```
435     pub fn dimension(&self) -> usize {
436         let mut dim = 0;
437         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
438         usize::try_from(dim).unwrap()
439     }
440 
441     /// Returns number of components for given Basis
442     ///
443     /// ```
444     /// # use libceed::prelude::*;
445     /// # fn main() -> libceed::Result<()> {
446     /// # let ceed = libceed::Ceed::default_init();
447     /// let ncomp = 2;
448     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
449     ///
450     /// let n = b.num_components();
451     /// assert_eq!(n, ncomp, "Incorrect number of components");
452     /// # Ok(())
453     /// # }
454     /// ```
455     pub fn num_components(&self) -> usize {
456         let mut ncomp = 0;
457         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
458         usize::try_from(ncomp).unwrap()
459     }
460 
461     /// Returns total number of nodes (in dim dimensions) of a Basis
462     ///
463     /// ```
464     /// # use libceed::prelude::*;
465     /// # fn main() -> libceed::Result<()> {
466     /// # let ceed = libceed::Ceed::default_init();
467     /// let p = 3;
468     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
469     ///
470     /// let nnodes = b.num_nodes();
471     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
472     /// # Ok(())
473     /// # }
474     /// ```
475     pub fn num_nodes(&self) -> usize {
476         let mut nnodes = 0;
477         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
478         usize::try_from(nnodes).unwrap()
479     }
480 
481     /// Returns total number of quadrature points (in dim dimensions) of a
482     /// Basis
483     ///
484     /// ```
485     /// # use libceed::prelude::*;
486     /// # fn main() -> libceed::Result<()> {
487     /// # let ceed = libceed::Ceed::default_init();
488     /// let q = 4;
489     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
490     ///
491     /// let nqpts = b.num_quadrature_points();
492     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
493     /// # Ok(())
494     /// # }
495     /// ```
496     pub fn num_quadrature_points(&self) -> usize {
497         let mut Q = 0;
498         unsafe {
499             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
500         }
501         usize::try_from(Q).unwrap()
502     }
503 
504     /// Create projection from self to specified Basis.
505     ///
506     /// Both bases must have the same quadrature space. The input bases need not
507     /// be nested as function spaces; this interface solves a least squares
508     /// problem to find a representation in the `to` basis that agrees at
509     /// quadrature points with the origin basis. Since the bases need not be
510     /// Lagrange, the resulting projection "basis" will have empty quadrature
511     /// points and weights.
512     ///
513     /// ```
514     /// # use libceed::prelude::*;
515     /// # fn main() -> libceed::Result<()> {
516     /// # let ceed = libceed::Ceed::default_init();
517     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
518     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
519     /// let proj = coarse.create_projection(&fine)?;
520     /// let u = ceed.vector_from_slice(&[1., 2.])?;
521     /// let mut v = ceed.vector(3)?;
522     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
523     /// let expected = [1., 1.5, 2.];
524     /// for (a, b) in v.view()?.iter().zip(expected) {
525     ///     assert!(
526     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
527     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
528     ///     );
529     /// }
530     /// # Ok(())
531     /// # }
532     /// ```
533     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
534         let mut ptr = std::ptr::null_mut();
535         let ierr = unsafe { bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr) };
536         self.check_error(ierr)?;
537         Ok(Self {
538             ptr,
539             _lifeline: PhantomData,
540         })
541     }
542 }
543 
544 // -----------------------------------------------------------------------------
545