xref: /libCEED/rust/libceed/src/basis.rs (revision 97c1c57a2210bceee970d471329a94c21686be34)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed Basis defines the discrete finite element basis and associated
9 //! quadrature rule.
10 
11 use crate::prelude::*;
12 
13 // -----------------------------------------------------------------------------
14 // Basis option
15 // -----------------------------------------------------------------------------
16 #[derive(Debug)]
17 pub enum BasisOpt<'a> {
18     Some(&'a Basis<'a>),
19     Collocated,
20 }
21 /// Construct a BasisOpt reference from a Basis reference
22 impl<'a> From<&'a Basis<'_>> for BasisOpt<'a> {
23     fn from(basis: &'a Basis) -> Self {
24         debug_assert!(basis.ptr != unsafe { bind_ceed::CEED_BASIS_COLLOCATED });
25         Self::Some(basis)
26     }
27 }
28 impl<'a> BasisOpt<'a> {
29     /// Transform a Rust libCEED BasisOpt into C libCEED CeedBasis
30     pub(crate) fn to_raw(self) -> bind_ceed::CeedBasis {
31         match self {
32             Self::Some(basis) => basis.ptr,
33             Self::Collocated => unsafe { bind_ceed::CEED_BASIS_COLLOCATED },
34         }
35     }
36 
37     /// Check if a BasisOpt is Some
38     ///
39     /// ```
40     /// # use libceed::prelude::*;
41     /// # fn main() -> libceed::Result<()> {
42     /// # let ceed = libceed::Ceed::default_init();
43     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
44     /// let b_opt = BasisOpt::from(&b);
45     /// assert!(b_opt.is_some(), "Incorrect BasisOpt");
46     ///
47     /// let b_opt = BasisOpt::Collocated;
48     /// assert!(!b_opt.is_some(), "Incorrect BasisOpt");
49     /// # Ok(())
50     /// # }
51     /// ```
52     pub fn is_some(&self) -> bool {
53         match self {
54             Self::Some(_) => true,
55             Self::Collocated => false,
56         }
57     }
58 
59     /// Check if a BasisOpt is Collocated
60     ///
61     /// ```
62     /// # use libceed::prelude::*;
63     /// # fn main() -> libceed::Result<()> {
64     /// # let ceed = libceed::Ceed::default_init();
65     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
66     /// let b_opt = BasisOpt::from(&b);
67     /// assert!(!b_opt.is_collocated(), "Incorrect BasisOpt");
68     ///
69     /// let b_opt = BasisOpt::Collocated;
70     /// assert!(b_opt.is_collocated(), "Incorrect BasisOpt");
71     /// # Ok(())
72     /// # }
73     /// ```
74     pub fn is_collocated(&self) -> bool {
75         match self {
76             Self::Some(_) => false,
77             Self::Collocated => true,
78         }
79     }
80 }
81 
82 // -----------------------------------------------------------------------------
83 // Basis context wrapper
84 // -----------------------------------------------------------------------------
85 #[derive(Debug)]
86 pub struct Basis<'a> {
87     pub(crate) ptr: bind_ceed::CeedBasis,
88     _lifeline: PhantomData<&'a ()>,
89 }
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for Basis<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_BASIS_COLLOCATED {
98                 bind_ceed::CeedBasisDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 // -----------------------------------------------------------------------------
105 // Display
106 // -----------------------------------------------------------------------------
107 impl<'a> fmt::Display for Basis<'a> {
108     /// View a Basis
109     ///
110     /// ```
111     /// # use libceed::prelude::*;
112     /// # fn main() -> libceed::Result<()> {
113     /// # let ceed = libceed::Ceed::default_init();
114     /// let b = ceed.basis_tensor_H1_Lagrange(1, 2, 3, 4, QuadMode::Gauss)?;
115     /// println!("{}", b);
116     /// # Ok(())
117     /// # }
118     /// ```
119     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
120         let mut ptr = std::ptr::null_mut();
121         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
122         let cstring = unsafe {
123             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
124             bind_ceed::CeedBasisView(self.ptr, file);
125             bind_ceed::fclose(file);
126             CString::from_raw(ptr)
127         };
128         cstring.to_string_lossy().fmt(f)
129     }
130 }
131 
132 // -----------------------------------------------------------------------------
133 // Implementations
134 // -----------------------------------------------------------------------------
135 impl<'a> Basis<'a> {
136     // Constructors
137     pub fn create_tensor_H1(
138         ceed: &crate::Ceed,
139         dim: usize,
140         ncomp: usize,
141         P1d: usize,
142         Q1d: usize,
143         interp1d: &[crate::Scalar],
144         grad1d: &[crate::Scalar],
145         qref1d: &[crate::Scalar],
146         qweight1d: &[crate::Scalar],
147     ) -> crate::Result<Self> {
148         let mut ptr = std::ptr::null_mut();
149         let (dim, ncomp, P1d, Q1d) = (
150             i32::try_from(dim).unwrap(),
151             i32::try_from(ncomp).unwrap(),
152             i32::try_from(P1d).unwrap(),
153             i32::try_from(Q1d).unwrap(),
154         );
155         let ierr = unsafe {
156             bind_ceed::CeedBasisCreateTensorH1(
157                 ceed.ptr,
158                 dim,
159                 ncomp,
160                 P1d,
161                 Q1d,
162                 interp1d.as_ptr(),
163                 grad1d.as_ptr(),
164                 qref1d.as_ptr(),
165                 qweight1d.as_ptr(),
166                 &mut ptr,
167             )
168         };
169         ceed.check_error(ierr)?;
170         Ok(Self {
171             ptr,
172             _lifeline: PhantomData,
173         })
174     }
175 
176     pub fn create_tensor_H1_Lagrange(
177         ceed: &crate::Ceed,
178         dim: usize,
179         ncomp: usize,
180         P: usize,
181         Q: usize,
182         qmode: crate::QuadMode,
183     ) -> crate::Result<Self> {
184         let mut ptr = std::ptr::null_mut();
185         let (dim, ncomp, P, Q, qmode) = (
186             i32::try_from(dim).unwrap(),
187             i32::try_from(ncomp).unwrap(),
188             i32::try_from(P).unwrap(),
189             i32::try_from(Q).unwrap(),
190             qmode as bind_ceed::CeedQuadMode,
191         );
192         let ierr = unsafe {
193             bind_ceed::CeedBasisCreateTensorH1Lagrange(ceed.ptr, dim, ncomp, P, Q, qmode, &mut ptr)
194         };
195         ceed.check_error(ierr)?;
196         Ok(Self {
197             ptr,
198             _lifeline: PhantomData,
199         })
200     }
201 
202     pub fn create_H1(
203         ceed: &crate::Ceed,
204         topo: crate::ElemTopology,
205         ncomp: usize,
206         nnodes: usize,
207         nqpts: usize,
208         interp: &[crate::Scalar],
209         grad: &[crate::Scalar],
210         qref: &[crate::Scalar],
211         qweight: &[crate::Scalar],
212     ) -> crate::Result<Self> {
213         let mut ptr = std::ptr::null_mut();
214         let (topo, ncomp, nnodes, nqpts) = (
215             topo as bind_ceed::CeedElemTopology,
216             i32::try_from(ncomp).unwrap(),
217             i32::try_from(nnodes).unwrap(),
218             i32::try_from(nqpts).unwrap(),
219         );
220         let ierr = unsafe {
221             bind_ceed::CeedBasisCreateH1(
222                 ceed.ptr,
223                 topo,
224                 ncomp,
225                 nnodes,
226                 nqpts,
227                 interp.as_ptr(),
228                 grad.as_ptr(),
229                 qref.as_ptr(),
230                 qweight.as_ptr(),
231                 &mut ptr,
232             )
233         };
234         ceed.check_error(ierr)?;
235         Ok(Self {
236             ptr,
237             _lifeline: PhantomData,
238         })
239     }
240 
241     pub fn create_Hdiv(
242         ceed: &crate::Ceed,
243         topo: crate::ElemTopology,
244         ncomp: usize,
245         nnodes: usize,
246         nqpts: usize,
247         interp: &[crate::Scalar],
248         div: &[crate::Scalar],
249         qref: &[crate::Scalar],
250         qweight: &[crate::Scalar],
251     ) -> crate::Result<Self> {
252         let mut ptr = std::ptr::null_mut();
253         let (topo, ncomp, nnodes, nqpts) = (
254             topo as bind_ceed::CeedElemTopology,
255             i32::try_from(ncomp).unwrap(),
256             i32::try_from(nnodes).unwrap(),
257             i32::try_from(nqpts).unwrap(),
258         );
259         let ierr = unsafe {
260             bind_ceed::CeedBasisCreateHdiv(
261                 ceed.ptr,
262                 topo,
263                 ncomp,
264                 nnodes,
265                 nqpts,
266                 interp.as_ptr(),
267                 div.as_ptr(),
268                 qref.as_ptr(),
269                 qweight.as_ptr(),
270                 &mut ptr,
271             )
272         };
273         ceed.check_error(ierr)?;
274         Ok(Self {
275             ptr,
276             _lifeline: PhantomData,
277         })
278     }
279 
280     pub fn create_Hcurl(
281         ceed: &crate::Ceed,
282         topo: crate::ElemTopology,
283         ncomp: usize,
284         nnodes: usize,
285         nqpts: usize,
286         interp: &[crate::Scalar],
287         curl: &[crate::Scalar],
288         qref: &[crate::Scalar],
289         qweight: &[crate::Scalar],
290     ) -> crate::Result<Self> {
291         let mut ptr = std::ptr::null_mut();
292         let (topo, ncomp, nnodes, nqpts) = (
293             topo as bind_ceed::CeedElemTopology,
294             i32::try_from(ncomp).unwrap(),
295             i32::try_from(nnodes).unwrap(),
296             i32::try_from(nqpts).unwrap(),
297         );
298         let ierr = unsafe {
299             bind_ceed::CeedBasisCreateHcurl(
300                 ceed.ptr,
301                 topo,
302                 ncomp,
303                 nnodes,
304                 nqpts,
305                 interp.as_ptr(),
306                 curl.as_ptr(),
307                 qref.as_ptr(),
308                 qweight.as_ptr(),
309                 &mut ptr,
310             )
311         };
312         ceed.check_error(ierr)?;
313         Ok(Self {
314             ptr,
315             _lifeline: PhantomData,
316         })
317     }
318 
319     // Error handling
320     #[doc(hidden)]
321     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
322         let mut ptr = std::ptr::null_mut();
323         unsafe {
324             bind_ceed::CeedBasisGetCeed(self.ptr, &mut ptr);
325         }
326         crate::check_error(ptr, ierr)
327     }
328 
329     /// Apply basis evaluation from nodes to quadrature points or vice versa
330     ///
331     /// * `nelem` - The number of elements to apply the basis evaluation to
332     /// * `tmode` - `TrasposeMode::NoTranspose` to evaluate from nodes to
333     ///               quadrature points, `TransposeMode::Transpose` to apply the
334     ///               transpose, mapping from quadrature points to nodes
335     /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
336     ///               to use interpolated values, `EvalMode::Grad` to use
337     ///               gradients, `EvalMode::Weight` to use quadrature weights
338     /// * `u`     - Input Vector
339     /// * `v`     - Output Vector
340     ///
341     /// ```
342     /// # use libceed::prelude::*;
343     /// # fn main() -> libceed::Result<()> {
344     /// # let ceed = libceed::Ceed::default_init();
345     /// const Q: usize = 6;
346     /// let bu = ceed.basis_tensor_H1_Lagrange(1, 1, Q, Q, QuadMode::GaussLobatto)?;
347     /// let bx = ceed.basis_tensor_H1_Lagrange(1, 1, 2, Q, QuadMode::Gauss)?;
348     ///
349     /// let x_corners = ceed.vector_from_slice(&[-1., 1.])?;
350     /// let mut x_qpts = ceed.vector(Q)?;
351     /// let mut x_nodes = ceed.vector(Q)?;
352     /// bx.apply(
353     ///     1,
354     ///     TransposeMode::NoTranspose,
355     ///     EvalMode::Interp,
356     ///     &x_corners,
357     ///     &mut x_nodes,
358     /// )?;
359     /// bu.apply(
360     ///     1,
361     ///     TransposeMode::NoTranspose,
362     ///     EvalMode::Interp,
363     ///     &x_nodes,
364     ///     &mut x_qpts,
365     /// )?;
366     ///
367     /// // Create function x^3 + 1 on Gauss Lobatto points
368     /// let mut u_arr = [0.; Q];
369     /// u_arr
370     ///     .iter_mut()
371     ///     .zip(x_nodes.view()?.iter())
372     ///     .for_each(|(u, x)| *u = x * x * x + 1.);
373     /// let u = ceed.vector_from_slice(&u_arr)?;
374     ///
375     /// // Map function to Gauss points
376     /// let mut v = ceed.vector(Q)?;
377     /// v.set_value(0.);
378     /// bu.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
379     ///
380     /// // Verify results
381     /// v.view()?
382     ///     .iter()
383     ///     .zip(x_qpts.view()?.iter())
384     ///     .for_each(|(v, x)| {
385     ///         let true_value = x * x * x + 1.;
386     ///         assert!(
387     ///             (*v - true_value).abs() < 10.0 * libceed::EPSILON,
388     ///             "Incorrect basis application"
389     ///         );
390     ///     });
391     /// # Ok(())
392     /// # }
393     /// ```
394     pub fn apply(
395         &self,
396         nelem: usize,
397         tmode: TransposeMode,
398         emode: EvalMode,
399         u: &Vector,
400         v: &mut Vector,
401     ) -> crate::Result<i32> {
402         let (nelem, tmode, emode) = (
403             i32::try_from(nelem).unwrap(),
404             tmode as bind_ceed::CeedTransposeMode,
405             emode as bind_ceed::CeedEvalMode,
406         );
407         let ierr =
408             unsafe { bind_ceed::CeedBasisApply(self.ptr, nelem, tmode, emode, u.ptr, v.ptr) };
409         self.check_error(ierr)
410     }
411 
412     /// Returns the dimension for given Basis
413     ///
414     /// ```
415     /// # use libceed::prelude::*;
416     /// # fn main() -> libceed::Result<()> {
417     /// # let ceed = libceed::Ceed::default_init();
418     /// let dim = 2;
419     /// let b = ceed.basis_tensor_H1_Lagrange(dim, 1, 3, 4, QuadMode::Gauss)?;
420     ///
421     /// let d = b.dimension();
422     /// assert_eq!(d, dim, "Incorrect dimension");
423     /// # Ok(())
424     /// # }
425     /// ```
426     pub fn dimension(&self) -> usize {
427         let mut dim = 0;
428         unsafe { bind_ceed::CeedBasisGetDimension(self.ptr, &mut dim) };
429         usize::try_from(dim).unwrap()
430     }
431 
432     /// Returns number of components for given Basis
433     ///
434     /// ```
435     /// # use libceed::prelude::*;
436     /// # fn main() -> libceed::Result<()> {
437     /// # let ceed = libceed::Ceed::default_init();
438     /// let ncomp = 2;
439     /// let b = ceed.basis_tensor_H1_Lagrange(1, ncomp, 3, 4, QuadMode::Gauss)?;
440     ///
441     /// let n = b.num_components();
442     /// assert_eq!(n, ncomp, "Incorrect number of components");
443     /// # Ok(())
444     /// # }
445     /// ```
446     pub fn num_components(&self) -> usize {
447         let mut ncomp = 0;
448         unsafe { bind_ceed::CeedBasisGetNumComponents(self.ptr, &mut ncomp) };
449         usize::try_from(ncomp).unwrap()
450     }
451 
452     /// Returns total number of nodes (in dim dimensions) of a Basis
453     ///
454     /// ```
455     /// # use libceed::prelude::*;
456     /// # fn main() -> libceed::Result<()> {
457     /// # let ceed = libceed::Ceed::default_init();
458     /// let p = 3;
459     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, p, 4, QuadMode::Gauss)?;
460     ///
461     /// let nnodes = b.num_nodes();
462     /// assert_eq!(nnodes, p * p, "Incorrect number of nodes");
463     /// # Ok(())
464     /// # }
465     /// ```
466     pub fn num_nodes(&self) -> usize {
467         let mut nnodes = 0;
468         unsafe { bind_ceed::CeedBasisGetNumNodes(self.ptr, &mut nnodes) };
469         usize::try_from(nnodes).unwrap()
470     }
471 
472     /// Returns total number of quadrature points (in dim dimensions) of a
473     /// Basis
474     ///
475     /// ```
476     /// # use libceed::prelude::*;
477     /// # fn main() -> libceed::Result<()> {
478     /// # let ceed = libceed::Ceed::default_init();
479     /// let q = 4;
480     /// let b = ceed.basis_tensor_H1_Lagrange(2, 1, 3, q, QuadMode::Gauss)?;
481     ///
482     /// let nqpts = b.num_quadrature_points();
483     /// assert_eq!(nqpts, q * q, "Incorrect number of quadrature points");
484     /// # Ok(())
485     /// # }
486     /// ```
487     pub fn num_quadrature_points(&self) -> usize {
488         let mut Q = 0;
489         unsafe {
490             bind_ceed::CeedBasisGetNumQuadraturePoints(self.ptr, &mut Q);
491         }
492         usize::try_from(Q).unwrap()
493     }
494 
495     /// Create projection from self to specified Basis.
496     ///
497     /// Both bases must have the same quadrature space. The input bases need not
498     /// be nested as function spaces; this interface solves a least squares
499     /// problem to find a representation in the `to` basis that agrees at
500     /// quadrature points with the origin basis. Since the bases need not be
501     /// Lagrange, the resulting projection "basis" will have empty quadrature
502     /// points and weights.
503     ///
504     /// ```
505     /// # use libceed::prelude::*;
506     /// # fn main() -> libceed::Result<()> {
507     /// # let ceed = libceed::Ceed::default_init();
508     /// let coarse = ceed.basis_tensor_H1_Lagrange(1, 1, 2, 3, QuadMode::Gauss)?;
509     /// let fine = ceed.basis_tensor_H1_Lagrange(1, 1, 3, 3, QuadMode::Gauss)?;
510     /// let proj = coarse.create_projection(&fine)?;
511     /// let u = ceed.vector_from_slice(&[1., 2.])?;
512     /// let mut v = ceed.vector(3)?;
513     /// proj.apply(1, TransposeMode::NoTranspose, EvalMode::Interp, &u, &mut v)?;
514     /// let expected = [1., 1.5, 2.];
515     /// for (a, b) in v.view()?.iter().zip(expected) {
516     ///     assert!(
517     ///         (a - b).abs() < 10.0 * libceed::EPSILON,
518     ///         "Incorrect projection of linear Lagrange to quadratic Lagrange"
519     ///     );
520     /// }
521     /// # Ok(())
522     /// # }
523     /// ```
524     pub fn create_projection(&self, to: &Self) -> crate::Result<Self> {
525         let mut ptr = std::ptr::null_mut();
526         let ierr = unsafe { bind_ceed::CeedBasisCreateProjection(self.ptr, to.ptr, &mut ptr) };
527         self.check_error(ierr)?;
528         Ok(Self {
529             ptr,
530             _lifeline: PhantomData,
531         })
532     }
533 }
534 
535 // -----------------------------------------------------------------------------
536