xref: /libCEED/rust/libceed/src/qfunction.rs (revision 80a9ef0545a39c00cdcaab1ca26f8053604f3120)
19df49d7eSJed Brown // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
29df49d7eSJed Brown // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
39df49d7eSJed Brown // reserved. See files LICENSE and NOTICE for details.
49df49d7eSJed Brown //
59df49d7eSJed Brown // This file is part of CEED, a collection of benchmarks, miniapps, software
69df49d7eSJed Brown // libraries and APIs for efficient high-order finite element and spectral
79df49d7eSJed Brown // element discretizations for exascale applications. For more information and
89df49d7eSJed Brown // source code availability see http://github.com/ceed.
99df49d7eSJed Brown //
109df49d7eSJed Brown // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
119df49d7eSJed Brown // a collaborative effort of two U.S. Department of Energy organizations (Office
129df49d7eSJed Brown // of Science and the National Nuclear Security Administration) responsible for
139df49d7eSJed Brown // the planning and preparation of a capable exascale ecosystem, including
149df49d7eSJed Brown // software, applications, hardware, advanced system engineering and early
159df49d7eSJed Brown // testbed platforms, in support of the nation's exascale computing imperative
169df49d7eSJed Brown 
179df49d7eSJed Brown //! A Ceed QFunction represents the spatial terms of the point-wise functions
189df49d7eSJed Brown //! describing the physics at the quadrature points.
199df49d7eSJed Brown 
209df49d7eSJed Brown use std::pin::Pin;
219df49d7eSJed Brown 
229df49d7eSJed Brown use crate::prelude::*;
239df49d7eSJed Brown 
24*80a9ef05SNatalie Beams pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS];
25*80a9ef05SNatalie Beams pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS];
269df49d7eSJed Brown 
279df49d7eSJed Brown // -----------------------------------------------------------------------------
289df49d7eSJed Brown // CeedQFunction option
299df49d7eSJed Brown // -----------------------------------------------------------------------------
309df49d7eSJed Brown #[derive(Clone, Copy)]
319df49d7eSJed Brown pub enum QFunctionOpt<'a> {
329df49d7eSJed Brown     SomeQFunction(&'a QFunction<'a>),
339df49d7eSJed Brown     SomeQFunctionByName(&'a QFunctionByName<'a>),
349df49d7eSJed Brown     None,
359df49d7eSJed Brown }
369df49d7eSJed Brown 
379df49d7eSJed Brown /// Construct a QFunctionOpt reference from a QFunction reference
389df49d7eSJed Brown impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
399df49d7eSJed Brown     fn from(qfunc: &'a QFunction) -> Self {
409df49d7eSJed Brown         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
419df49d7eSJed Brown         Self::SomeQFunction(qfunc)
429df49d7eSJed Brown     }
439df49d7eSJed Brown }
449df49d7eSJed Brown 
459df49d7eSJed Brown /// Construct a QFunctionOpt reference from a QFunction by Name reference
469df49d7eSJed Brown impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
479df49d7eSJed Brown     fn from(qfunc: &'a QFunctionByName) -> Self {
489df49d7eSJed Brown         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
499df49d7eSJed Brown         Self::SomeQFunctionByName(qfunc)
509df49d7eSJed Brown     }
519df49d7eSJed Brown }
529df49d7eSJed Brown 
539df49d7eSJed Brown impl<'a> QFunctionOpt<'a> {
549df49d7eSJed Brown     /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction
559df49d7eSJed Brown     pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction {
569df49d7eSJed Brown         match self {
579df49d7eSJed Brown             Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr,
589df49d7eSJed Brown             Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr,
599df49d7eSJed Brown             Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE },
609df49d7eSJed Brown         }
619df49d7eSJed Brown     }
629df49d7eSJed Brown }
639df49d7eSJed Brown 
649df49d7eSJed Brown // -----------------------------------------------------------------------------
659df49d7eSJed Brown // CeedQFunction context wrapper
669df49d7eSJed Brown // -----------------------------------------------------------------------------
679df49d7eSJed Brown pub(crate) struct QFunctionCore<'a> {
689df49d7eSJed Brown     ceed: &'a crate::Ceed,
699df49d7eSJed Brown     ptr: bind_ceed::CeedQFunction,
709df49d7eSJed Brown }
719df49d7eSJed Brown 
729df49d7eSJed Brown struct QFunctionTrampolineData {
739df49d7eSJed Brown     number_inputs: usize,
749df49d7eSJed Brown     number_outputs: usize,
759df49d7eSJed Brown     input_sizes: [usize; MAX_QFUNCTION_FIELDS],
769df49d7eSJed Brown     output_sizes: [usize; MAX_QFUNCTION_FIELDS],
779df49d7eSJed Brown     user_f: Box<QFunctionUserClosure>,
789df49d7eSJed Brown }
799df49d7eSJed Brown 
809df49d7eSJed Brown pub struct QFunction<'a> {
819df49d7eSJed Brown     qf_core: QFunctionCore<'a>,
829df49d7eSJed Brown     qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
839df49d7eSJed Brown     trampoline_data: Pin<Box<QFunctionTrampolineData>>,
849df49d7eSJed Brown }
859df49d7eSJed Brown 
869df49d7eSJed Brown pub struct QFunctionByName<'a> {
879df49d7eSJed Brown     qf_core: QFunctionCore<'a>,
889df49d7eSJed Brown }
899df49d7eSJed Brown 
909df49d7eSJed Brown // -----------------------------------------------------------------------------
919df49d7eSJed Brown // Destructor
929df49d7eSJed Brown // -----------------------------------------------------------------------------
939df49d7eSJed Brown impl<'a> Drop for QFunctionCore<'a> {
949df49d7eSJed Brown     fn drop(&mut self) {
959df49d7eSJed Brown         unsafe {
969df49d7eSJed Brown             if self.ptr != bind_ceed::CEED_QFUNCTION_NONE {
979df49d7eSJed Brown                 bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
989df49d7eSJed Brown             }
999df49d7eSJed Brown         }
1009df49d7eSJed Brown     }
1019df49d7eSJed Brown }
1029df49d7eSJed Brown 
1039df49d7eSJed Brown impl<'a> Drop for QFunction<'a> {
1049df49d7eSJed Brown     fn drop(&mut self) {
1059df49d7eSJed Brown         unsafe {
1069df49d7eSJed Brown             bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr);
1079df49d7eSJed Brown         }
1089df49d7eSJed Brown     }
1099df49d7eSJed Brown }
1109df49d7eSJed Brown 
1119df49d7eSJed Brown // -----------------------------------------------------------------------------
1129df49d7eSJed Brown // Display
1139df49d7eSJed Brown // -----------------------------------------------------------------------------
1149df49d7eSJed Brown impl<'a> fmt::Display for QFunctionCore<'a> {
1159df49d7eSJed Brown     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1169df49d7eSJed Brown         let mut ptr = std::ptr::null_mut();
1179df49d7eSJed Brown         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
1189df49d7eSJed Brown         let cstring = unsafe {
1199df49d7eSJed Brown             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
1209df49d7eSJed Brown             bind_ceed::CeedQFunctionView(self.ptr, file);
1219df49d7eSJed Brown             bind_ceed::fclose(file);
1229df49d7eSJed Brown             CString::from_raw(ptr)
1239df49d7eSJed Brown         };
1249df49d7eSJed Brown         cstring.to_string_lossy().fmt(f)
1259df49d7eSJed Brown     }
1269df49d7eSJed Brown }
1279df49d7eSJed Brown /// View a QFunction
1289df49d7eSJed Brown ///
1299df49d7eSJed Brown /// ```
1309df49d7eSJed Brown /// # use libceed::prelude::*;
1319df49d7eSJed Brown /// # let ceed = libceed::Ceed::default_init();
1329df49d7eSJed Brown /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
1339df49d7eSJed Brown ///     // Iterate over quadrature points
1349df49d7eSJed Brown ///     v.iter_mut()
1359df49d7eSJed Brown ///         .zip(u.iter().zip(weights.iter()))
1369df49d7eSJed Brown ///         .for_each(|(v, (u, w))| *v = u * w);
1379df49d7eSJed Brown ///
1389df49d7eSJed Brown ///     // Return clean error code
1399df49d7eSJed Brown ///     0
1409df49d7eSJed Brown /// };
1419df49d7eSJed Brown ///
1429df49d7eSJed Brown /// let qf = ceed
1439df49d7eSJed Brown ///     .q_function_interior(1, Box::new(user_f))
1449df49d7eSJed Brown ///     .unwrap()
1459df49d7eSJed Brown ///     .input("u", 1, EvalMode::Interp)
1469df49d7eSJed Brown ///     .unwrap()
1479df49d7eSJed Brown ///     .input("weights", 1, EvalMode::Weight)
1489df49d7eSJed Brown ///     .unwrap()
1499df49d7eSJed Brown ///     .output("v", 1, EvalMode::Interp)
1509df49d7eSJed Brown ///     .unwrap();
1519df49d7eSJed Brown ///
1529df49d7eSJed Brown /// println!("{}", qf);
1539df49d7eSJed Brown /// ```
1549df49d7eSJed Brown impl<'a> fmt::Display for QFunction<'a> {
1559df49d7eSJed Brown     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1569df49d7eSJed Brown         self.qf_core.fmt(f)
1579df49d7eSJed Brown     }
1589df49d7eSJed Brown }
1599df49d7eSJed Brown 
1609df49d7eSJed Brown /// View a QFunction by Name
1619df49d7eSJed Brown ///
1629df49d7eSJed Brown /// ```
1639df49d7eSJed Brown /// # use libceed::prelude::*;
1649df49d7eSJed Brown /// # let ceed = libceed::Ceed::default_init();
1659df49d7eSJed Brown /// let qf = ceed.q_function_interior_by_name("Mass1DBuild").unwrap();
1669df49d7eSJed Brown /// println!("{}", qf);
1679df49d7eSJed Brown /// ```
1689df49d7eSJed Brown impl<'a> fmt::Display for QFunctionByName<'a> {
1699df49d7eSJed Brown     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1709df49d7eSJed Brown         self.qf_core.fmt(f)
1719df49d7eSJed Brown     }
1729df49d7eSJed Brown }
1739df49d7eSJed Brown 
1749df49d7eSJed Brown // -----------------------------------------------------------------------------
1759df49d7eSJed Brown // Core functionality
1769df49d7eSJed Brown // -----------------------------------------------------------------------------
1779df49d7eSJed Brown impl<'a> QFunctionCore<'a> {
1789df49d7eSJed Brown     // Common implementation
1799df49d7eSJed Brown     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
1809df49d7eSJed Brown         let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
1819df49d7eSJed Brown         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) {
1829df49d7eSJed Brown             u_c[i] = u[i].ptr;
1839df49d7eSJed Brown         }
1849df49d7eSJed Brown         let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
1859df49d7eSJed Brown         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) {
1869df49d7eSJed Brown             v_c[i] = v[i].ptr;
1879df49d7eSJed Brown         }
1889df49d7eSJed Brown         let Q = i32::try_from(Q).unwrap();
1899df49d7eSJed Brown         let ierr = unsafe {
1909df49d7eSJed Brown             bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
1919df49d7eSJed Brown         };
1929df49d7eSJed Brown         self.ceed.check_error(ierr)
1939df49d7eSJed Brown     }
1949df49d7eSJed Brown }
1959df49d7eSJed Brown 
1969df49d7eSJed Brown // -----------------------------------------------------------------------------
1979df49d7eSJed Brown // User QFunction Closure
1989df49d7eSJed Brown // -----------------------------------------------------------------------------
199*80a9ef05SNatalie Beams pub type QFunctionUserClosure = dyn FnMut(
200*80a9ef05SNatalie Beams     [&[crate::Scalar]; MAX_QFUNCTION_FIELDS],
201*80a9ef05SNatalie Beams     [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS],
202*80a9ef05SNatalie Beams ) -> i32;
2039df49d7eSJed Brown 
// Build a 16-element array by evaluating `$e` once per element.
//
// `[expr; N]` only works for `Copy` (or const) operands, so it cannot repeat a
// `&mut` expression; this macro spells out every element instead. The element
// count must equal `MAX_QFUNCTION_FIELDS`. The `outputs_array` binding in
// `trampoline` is annotated with that length, so a mismatch fails to compile.
macro_rules! mut_max_fields {
    ($e:expr) => {
        [
            $e, $e, $e, $e, //
            $e, $e, $e, $e, //
            $e, $e, $e, $e, //
            $e, $e, $e, $e, //
        ]
    };
}
2119df49d7eSJed Brown unsafe extern "C" fn trampoline(
2129df49d7eSJed Brown     ctx: *mut ::std::os::raw::c_void,
2139df49d7eSJed Brown     q: bind_ceed::CeedInt,
2149df49d7eSJed Brown     inputs: *const *const bind_ceed::CeedScalar,
2159df49d7eSJed Brown     outputs: *const *mut bind_ceed::CeedScalar,
2169df49d7eSJed Brown ) -> ::std::os::raw::c_int {
2179df49d7eSJed Brown     let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx);
2189df49d7eSJed Brown 
2199df49d7eSJed Brown     // Inputs
2209df49d7eSJed Brown     let inputs_slice: &[*const bind_ceed::CeedScalar] =
2219df49d7eSJed Brown         std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS);
222*80a9ef05SNatalie Beams     let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS];
2239df49d7eSJed Brown     inputs_slice
2249df49d7eSJed Brown         .iter()
2259df49d7eSJed Brown         .enumerate()
2269df49d7eSJed Brown         .map(|(i, &x)| {
227*80a9ef05SNatalie Beams             std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize)
228*80a9ef05SNatalie Beams                 as &[crate::Scalar]
2299df49d7eSJed Brown         })
2309df49d7eSJed Brown         .zip(inputs_array.iter_mut())
2319df49d7eSJed Brown         .for_each(|(x, a)| *a = x);
2329df49d7eSJed Brown 
2339df49d7eSJed Brown     // Outputs
2349df49d7eSJed Brown     let outputs_slice: &[*mut bind_ceed::CeedScalar] =
2359df49d7eSJed Brown         std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS);
236*80a9ef05SNatalie Beams     let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] =
237*80a9ef05SNatalie Beams         mut_max_fields!(&mut [0.0]);
2389df49d7eSJed Brown     outputs_slice
2399df49d7eSJed Brown         .iter()
2409df49d7eSJed Brown         .enumerate()
2419df49d7eSJed Brown         .map(|(i, &x)| {
2429df49d7eSJed Brown             std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize)
243*80a9ef05SNatalie Beams                 as &mut [crate::Scalar]
2449df49d7eSJed Brown         })
2459df49d7eSJed Brown         .zip(outputs_array.iter_mut())
2469df49d7eSJed Brown         .for_each(|(x, a)| *a = x);
2479df49d7eSJed Brown 
2489df49d7eSJed Brown     // User closure
2499df49d7eSJed Brown     (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array)
2509df49d7eSJed Brown }
2519df49d7eSJed Brown 
2529df49d7eSJed Brown // -----------------------------------------------------------------------------
2539df49d7eSJed Brown // QFunction
2549df49d7eSJed Brown // -----------------------------------------------------------------------------
2559df49d7eSJed Brown impl<'a> QFunction<'a> {
2569df49d7eSJed Brown     // Constructor
2579df49d7eSJed Brown     pub fn create(
2589df49d7eSJed Brown         ceed: &'a crate::Ceed,
2599df49d7eSJed Brown         vlength: usize,
2609df49d7eSJed Brown         user_f: Box<QFunctionUserClosure>,
2619df49d7eSJed Brown     ) -> crate::Result<Self> {
2629df49d7eSJed Brown         let source_c = CString::new("").expect("CString::new failed");
2639df49d7eSJed Brown         let mut ptr = std::ptr::null_mut();
2649df49d7eSJed Brown 
2659df49d7eSJed Brown         // Context for closure
2669df49d7eSJed Brown         let number_inputs = 0;
2679df49d7eSJed Brown         let number_outputs = 0;
2689df49d7eSJed Brown         let input_sizes = [0; MAX_QFUNCTION_FIELDS];
2699df49d7eSJed Brown         let output_sizes = [0; MAX_QFUNCTION_FIELDS];
2709df49d7eSJed Brown         let trampoline_data = unsafe {
2719df49d7eSJed Brown             Pin::new_unchecked(Box::new(QFunctionTrampolineData {
2729df49d7eSJed Brown                 number_inputs,
2739df49d7eSJed Brown                 number_outputs,
2749df49d7eSJed Brown                 input_sizes,
2759df49d7eSJed Brown                 output_sizes,
2769df49d7eSJed Brown                 user_f,
2779df49d7eSJed Brown             }))
2789df49d7eSJed Brown         };
2799df49d7eSJed Brown 
2809df49d7eSJed Brown         // Create QFunction
2819df49d7eSJed Brown         let vlength = i32::try_from(vlength).unwrap();
2829df49d7eSJed Brown         let mut ierr = unsafe {
2839df49d7eSJed Brown             bind_ceed::CeedQFunctionCreateInterior(
2849df49d7eSJed Brown                 ceed.ptr,
2859df49d7eSJed Brown                 vlength,
2869df49d7eSJed Brown                 Some(trampoline),
2879df49d7eSJed Brown                 source_c.as_ptr(),
2889df49d7eSJed Brown                 &mut ptr,
2899df49d7eSJed Brown             )
2909df49d7eSJed Brown         };
2919df49d7eSJed Brown         ceed.check_error(ierr)?;
2929df49d7eSJed Brown 
2939df49d7eSJed Brown         // Set closure
2949df49d7eSJed Brown         let mut qf_ctx_ptr = std::ptr::null_mut();
2959df49d7eSJed Brown         ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) };
2969df49d7eSJed Brown         ceed.check_error(ierr)?;
2979df49d7eSJed Brown         ierr = unsafe {
2989df49d7eSJed Brown             bind_ceed::CeedQFunctionContextSetData(
2999df49d7eSJed Brown                 qf_ctx_ptr,
3009df49d7eSJed Brown                 crate::MemType::Host as bind_ceed::CeedMemType,
3019df49d7eSJed Brown                 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
3029df49d7eSJed Brown                 std::mem::size_of::<QFunctionTrampolineData>() as u64,
3039df49d7eSJed Brown                 std::mem::transmute(trampoline_data.as_ref()),
3049df49d7eSJed Brown             )
3059df49d7eSJed Brown         };
3069df49d7eSJed Brown         ceed.check_error(ierr)?;
3079df49d7eSJed Brown         ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) };
3089df49d7eSJed Brown         ceed.check_error(ierr)?;
3099df49d7eSJed Brown         Ok(Self {
3109df49d7eSJed Brown             qf_core: QFunctionCore { ceed, ptr },
3119df49d7eSJed Brown             qf_ctx_ptr,
3129df49d7eSJed Brown             trampoline_data,
3139df49d7eSJed Brown         })
3149df49d7eSJed Brown     }
3159df49d7eSJed Brown 
3169df49d7eSJed Brown     /// Apply the action of a QFunction
3179df49d7eSJed Brown     ///
3189df49d7eSJed Brown     /// * `Q`      - The number of quadrature points
3199df49d7eSJed Brown     /// * `input`  - Array of input Vectors
3209df49d7eSJed Brown     /// * `output` - Array of output Vectors
3219df49d7eSJed Brown     ///
3229df49d7eSJed Brown     /// ```
3239df49d7eSJed Brown     /// # use libceed::prelude::*;
3249df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
3259df49d7eSJed Brown     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
3269df49d7eSJed Brown     ///     // Iterate over quadrature points
3279df49d7eSJed Brown     ///     v.iter_mut()
3289df49d7eSJed Brown     ///         .zip(u.iter().zip(weights.iter()))
3299df49d7eSJed Brown     ///         .for_each(|(v, (u, w))| *v = u * w);
3309df49d7eSJed Brown     ///
3319df49d7eSJed Brown     ///     // Return clean error code
3329df49d7eSJed Brown     ///     0
3339df49d7eSJed Brown     /// };
3349df49d7eSJed Brown     ///
3359df49d7eSJed Brown     /// let qf = ceed
3369df49d7eSJed Brown     ///     .q_function_interior(1, Box::new(user_f))
3379df49d7eSJed Brown     ///     .unwrap()
3389df49d7eSJed Brown     ///     .input("u", 1, EvalMode::Interp)
3399df49d7eSJed Brown     ///     .unwrap()
3409df49d7eSJed Brown     ///     .input("weights", 1, EvalMode::Weight)
3419df49d7eSJed Brown     ///     .unwrap()
3429df49d7eSJed Brown     ///     .output("v", 1, EvalMode::Interp)
3439df49d7eSJed Brown     ///     .unwrap();
3449df49d7eSJed Brown     ///
3459df49d7eSJed Brown     /// const Q: usize = 8;
3469df49d7eSJed Brown     /// let mut w = [0.; Q];
3479df49d7eSJed Brown     /// let mut u = [0.; Q];
3489df49d7eSJed Brown     /// let mut v = [0.; Q];
3499df49d7eSJed Brown     ///
3509df49d7eSJed Brown     /// for i in 0..Q {
351*80a9ef05SNatalie Beams     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
3529df49d7eSJed Brown     ///     u[i] = 2. + 3. * x + 5. * x * x;
3539df49d7eSJed Brown     ///     w[i] = 1. - x * x;
3549df49d7eSJed Brown     ///     v[i] = u[i] * w[i];
3559df49d7eSJed Brown     /// }
3569df49d7eSJed Brown     ///
3579df49d7eSJed Brown     /// let uu = ceed.vector_from_slice(&u).unwrap();
3589df49d7eSJed Brown     /// let ww = ceed.vector_from_slice(&w).unwrap();
3599df49d7eSJed Brown     /// let mut vv = ceed.vector(Q).unwrap();
3609df49d7eSJed Brown     /// vv.set_value(0.0);
3619df49d7eSJed Brown     /// {
3629df49d7eSJed Brown     ///     let input = vec![uu, ww];
3639df49d7eSJed Brown     ///     let mut output = vec![vv];
3649df49d7eSJed Brown     ///     qf.apply(Q, &input, &output).unwrap();
3659df49d7eSJed Brown     ///     vv = output.remove(0);
3669df49d7eSJed Brown     /// }
3679df49d7eSJed Brown     ///
3689df49d7eSJed Brown     /// vv.view()
3699df49d7eSJed Brown     ///     .iter()
3709df49d7eSJed Brown     ///     .zip(v.iter())
3719df49d7eSJed Brown     ///     .for_each(|(computed, actual)| {
3729df49d7eSJed Brown     ///         assert_eq!(
3739df49d7eSJed Brown     ///             *computed, *actual,
3749df49d7eSJed Brown     ///             "Incorrect value in QFunction application"
3759df49d7eSJed Brown     ///         );
3769df49d7eSJed Brown     ///     });
3779df49d7eSJed Brown     /// ```
3789df49d7eSJed Brown     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
3799df49d7eSJed Brown         self.qf_core.apply(Q, u, v)
3809df49d7eSJed Brown     }
3819df49d7eSJed Brown 
3829df49d7eSJed Brown     /// Add a QFunction input
3839df49d7eSJed Brown     ///
3849df49d7eSJed Brown     /// * `fieldname` - Name of QFunction field
3859df49d7eSJed Brown     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
3869df49d7eSJed Brown     ///                   `(ncomp * 1)` for `None`, `Interp`, and `Weight`
3879df49d7eSJed Brown     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
3889df49d7eSJed Brown     ///                   to use interpolated values, `EvalMode::Grad` to use
3899df49d7eSJed Brown     ///                   gradients, `EvalMode::Weight` to use quadrature weights
3909df49d7eSJed Brown     ///
3919df49d7eSJed Brown     /// ```
3929df49d7eSJed Brown     /// # use libceed::prelude::*;
3939df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
3949df49d7eSJed Brown     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
3959df49d7eSJed Brown     ///     // Iterate over quadrature points
3969df49d7eSJed Brown     ///     v.iter_mut()
3979df49d7eSJed Brown     ///         .zip(u.iter().zip(weights.iter()))
3989df49d7eSJed Brown     ///         .for_each(|(v, (u, w))| *v = u * w);
3999df49d7eSJed Brown     ///
4009df49d7eSJed Brown     ///     // Return clean error code
4019df49d7eSJed Brown     ///     0
4029df49d7eSJed Brown     /// };
4039df49d7eSJed Brown     ///
4049df49d7eSJed Brown     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap();
4059df49d7eSJed Brown     ///
4069df49d7eSJed Brown     /// qf = qf.input("u", 1, EvalMode::Interp).unwrap();
4079df49d7eSJed Brown     /// qf = qf.input("weights", 1, EvalMode::Weight).unwrap();
4089df49d7eSJed Brown     /// ```
4099df49d7eSJed Brown     pub fn input(
4109df49d7eSJed Brown         mut self,
4119df49d7eSJed Brown         fieldname: &str,
4129df49d7eSJed Brown         size: usize,
4139df49d7eSJed Brown         emode: crate::EvalMode,
4149df49d7eSJed Brown     ) -> crate::Result<Self> {
4159df49d7eSJed Brown         let name_c = CString::new(fieldname).expect("CString::new failed");
4169df49d7eSJed Brown         let idx = self.trampoline_data.number_inputs;
4179df49d7eSJed Brown         self.trampoline_data.input_sizes[idx] = size;
4189df49d7eSJed Brown         self.trampoline_data.number_inputs += 1;
4199df49d7eSJed Brown         let (size, emode) = (
4209df49d7eSJed Brown             i32::try_from(size).unwrap(),
4219df49d7eSJed Brown             emode as bind_ceed::CeedEvalMode,
4229df49d7eSJed Brown         );
4239df49d7eSJed Brown         let ierr = unsafe {
4249df49d7eSJed Brown             bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
4259df49d7eSJed Brown         };
4269df49d7eSJed Brown         self.qf_core.ceed.check_error(ierr)?;
4279df49d7eSJed Brown         Ok(self)
4289df49d7eSJed Brown     }
4299df49d7eSJed Brown 
4309df49d7eSJed Brown     /// Add a QFunction output
4319df49d7eSJed Brown     ///
4329df49d7eSJed Brown     /// * `fieldname` - Name of QFunction field
4339df49d7eSJed Brown     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
4349df49d7eSJed Brown     ///                   `(ncomp * 1)` for `None` and `Interp`
4359df49d7eSJed Brown     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
4369df49d7eSJed Brown     ///                   to use interpolated values, `EvalMode::Grad` to use
4379df49d7eSJed Brown     ///                   gradients
4389df49d7eSJed Brown     ///
4399df49d7eSJed Brown     /// ```
4409df49d7eSJed Brown     /// # use libceed::prelude::*;
4419df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
4429df49d7eSJed Brown     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
4439df49d7eSJed Brown     ///     // Iterate over quadrature points
4449df49d7eSJed Brown     ///     v.iter_mut()
4459df49d7eSJed Brown     ///         .zip(u.iter().zip(weights.iter()))
4469df49d7eSJed Brown     ///         .for_each(|(v, (u, w))| *v = u * w);
4479df49d7eSJed Brown     ///
4489df49d7eSJed Brown     ///     // Return clean error code
4499df49d7eSJed Brown     ///     0
4509df49d7eSJed Brown     /// };
4519df49d7eSJed Brown     ///
4529df49d7eSJed Brown     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap();
4539df49d7eSJed Brown     ///
4549df49d7eSJed Brown     /// qf.output("v", 1, EvalMode::Interp).unwrap();
4559df49d7eSJed Brown     /// ```
4569df49d7eSJed Brown     pub fn output(
4579df49d7eSJed Brown         mut self,
4589df49d7eSJed Brown         fieldname: &str,
4599df49d7eSJed Brown         size: usize,
4609df49d7eSJed Brown         emode: crate::EvalMode,
4619df49d7eSJed Brown     ) -> crate::Result<Self> {
4629df49d7eSJed Brown         let name_c = CString::new(fieldname).expect("CString::new failed");
4639df49d7eSJed Brown         let idx = self.trampoline_data.number_outputs;
4649df49d7eSJed Brown         self.trampoline_data.output_sizes[idx] = size;
4659df49d7eSJed Brown         self.trampoline_data.number_outputs += 1;
4669df49d7eSJed Brown         let (size, emode) = (
4679df49d7eSJed Brown             i32::try_from(size).unwrap(),
4689df49d7eSJed Brown             emode as bind_ceed::CeedEvalMode,
4699df49d7eSJed Brown         );
4709df49d7eSJed Brown         let ierr = unsafe {
4719df49d7eSJed Brown             bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
4729df49d7eSJed Brown         };
4739df49d7eSJed Brown         self.qf_core.ceed.check_error(ierr)?;
4749df49d7eSJed Brown         Ok(self)
4759df49d7eSJed Brown     }
4769df49d7eSJed Brown }
4779df49d7eSJed Brown 
4789df49d7eSJed Brown // -----------------------------------------------------------------------------
// QFunctionByName
4809df49d7eSJed Brown // -----------------------------------------------------------------------------
4819df49d7eSJed Brown impl<'a> QFunctionByName<'a> {
4829df49d7eSJed Brown     // Constructor
4839df49d7eSJed Brown     pub fn create(ceed: &'a crate::Ceed, name: &str) -> crate::Result<Self> {
4849df49d7eSJed Brown         let name_c = CString::new(name).expect("CString::new failed");
4859df49d7eSJed Brown         let mut ptr = std::ptr::null_mut();
4869df49d7eSJed Brown         let ierr = unsafe {
4879df49d7eSJed Brown             bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
4889df49d7eSJed Brown         };
4899df49d7eSJed Brown         ceed.check_error(ierr)?;
4909df49d7eSJed Brown         Ok(Self {
4919df49d7eSJed Brown             qf_core: QFunctionCore { ceed, ptr },
4929df49d7eSJed Brown         })
4939df49d7eSJed Brown     }
4949df49d7eSJed Brown 
4959df49d7eSJed Brown     /// Apply the action of a QFunction
4969df49d7eSJed Brown     ///
4979df49d7eSJed Brown     /// * `Q`      - The number of quadrature points
4989df49d7eSJed Brown     /// * `input`  - Array of input Vectors
4999df49d7eSJed Brown     /// * `output` - Array of output Vectors
5009df49d7eSJed Brown     ///
5019df49d7eSJed Brown     /// ```
5029df49d7eSJed Brown     /// # use libceed::prelude::*;
5039df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
5049df49d7eSJed Brown     /// const Q: usize = 8;
5059df49d7eSJed Brown     /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild").unwrap();
5069df49d7eSJed Brown     /// let qf_mass = ceed.q_function_interior_by_name("MassApply").unwrap();
5079df49d7eSJed Brown     ///
5089df49d7eSJed Brown     /// let mut j = [0.; Q];
5099df49d7eSJed Brown     /// let mut w = [0.; Q];
5109df49d7eSJed Brown     /// let mut u = [0.; Q];
5119df49d7eSJed Brown     /// let mut v = [0.; Q];
5129df49d7eSJed Brown     ///
5139df49d7eSJed Brown     /// for i in 0..Q {
514*80a9ef05SNatalie Beams     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
5159df49d7eSJed Brown     ///     j[i] = 1.;
5169df49d7eSJed Brown     ///     w[i] = 1. - x * x;
5179df49d7eSJed Brown     ///     u[i] = 2. + 3. * x + 5. * x * x;
5189df49d7eSJed Brown     ///     v[i] = w[i] * u[i];
5199df49d7eSJed Brown     /// }
5209df49d7eSJed Brown     ///
5219df49d7eSJed Brown     /// let jj = ceed.vector_from_slice(&j).unwrap();
5229df49d7eSJed Brown     /// let ww = ceed.vector_from_slice(&w).unwrap();
5239df49d7eSJed Brown     /// let uu = ceed.vector_from_slice(&u).unwrap();
5249df49d7eSJed Brown     /// let mut vv = ceed.vector(Q).unwrap();
5259df49d7eSJed Brown     /// vv.set_value(0.0);
5269df49d7eSJed Brown     /// let mut qdata = ceed.vector(Q).unwrap();
5279df49d7eSJed Brown     /// qdata.set_value(0.0);
5289df49d7eSJed Brown     ///
5299df49d7eSJed Brown     /// {
5309df49d7eSJed Brown     ///     let mut input = vec![jj, ww];
5319df49d7eSJed Brown     ///     let mut output = vec![qdata];
5329df49d7eSJed Brown     ///     qf_build.apply(Q, &input, &output).unwrap();
5339df49d7eSJed Brown     ///     qdata = output.remove(0);
5349df49d7eSJed Brown     /// }
5359df49d7eSJed Brown     ///
5369df49d7eSJed Brown     /// {
5379df49d7eSJed Brown     ///     let mut input = vec![qdata, uu];
5389df49d7eSJed Brown     ///     let mut output = vec![vv];
5399df49d7eSJed Brown     ///     qf_mass.apply(Q, &input, &output).unwrap();
5409df49d7eSJed Brown     ///     vv = output.remove(0);
5419df49d7eSJed Brown     /// }
5429df49d7eSJed Brown     ///
5439df49d7eSJed Brown     /// vv.view()
5449df49d7eSJed Brown     ///     .iter()
5459df49d7eSJed Brown     ///     .zip(v.iter())
5469df49d7eSJed Brown     ///     .for_each(|(computed, actual)| {
5479df49d7eSJed Brown     ///         assert_eq!(
5489df49d7eSJed Brown     ///             *computed, *actual,
5499df49d7eSJed Brown     ///             "Incorrect value in QFunction application"
5509df49d7eSJed Brown     ///         );
5519df49d7eSJed Brown     ///     });
5529df49d7eSJed Brown     /// ```
5539df49d7eSJed Brown     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
5549df49d7eSJed Brown         self.qf_core.apply(Q, u, v)
5559df49d7eSJed Brown     }
5569df49d7eSJed Brown }
5579df49d7eSJed Brown 
5589df49d7eSJed Brown // -----------------------------------------------------------------------------
559