xref: /libCEED/rust/libceed/src/qfunction.rs (revision 9df49d7ef0a77c7a3baec2427d8a7274681409b6)
1*9df49d7eSJed Brown // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2*9df49d7eSJed Brown // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3*9df49d7eSJed Brown // reserved. See files LICENSE and NOTICE for details.
4*9df49d7eSJed Brown //
5*9df49d7eSJed Brown // This file is part of CEED, a collection of benchmarks, miniapps, software
6*9df49d7eSJed Brown // libraries and APIs for efficient high-order finite element and spectral
7*9df49d7eSJed Brown // element discretizations for exascale applications. For more information and
8*9df49d7eSJed Brown // source code availability see http://github.com/ceed.
9*9df49d7eSJed Brown //
10*9df49d7eSJed Brown // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11*9df49d7eSJed Brown // a collaborative effort of two U.S. Department of Energy organizations (Office
12*9df49d7eSJed Brown // of Science and the National Nuclear Security Administration) responsible for
13*9df49d7eSJed Brown // the planning and preparation of a capable exascale ecosystem, including
14*9df49d7eSJed Brown // software, applications, hardware, advanced system engineering and early
15*9df49d7eSJed Brown // testbed platforms, in support of the nation's exascale computing imperative
16*9df49d7eSJed Brown 
17*9df49d7eSJed Brown //! A Ceed QFunction represents the spatial terms of the point-wise functions
18*9df49d7eSJed Brown //! describing the physics at the quadrature points.
19*9df49d7eSJed Brown 
20*9df49d7eSJed Brown use std::pin::Pin;
21*9df49d7eSJed Brown 
22*9df49d7eSJed Brown use crate::prelude::*;
23*9df49d7eSJed Brown 
24*9df49d7eSJed Brown pub type QFunctionInputs<'a> = [&'a [f64]; MAX_QFUNCTION_FIELDS];
25*9df49d7eSJed Brown pub type QFunctionOutputs<'a> = [&'a mut [f64]; MAX_QFUNCTION_FIELDS];
26*9df49d7eSJed Brown 
27*9df49d7eSJed Brown // -----------------------------------------------------------------------------
28*9df49d7eSJed Brown // CeedQFunction option
29*9df49d7eSJed Brown // -----------------------------------------------------------------------------
30*9df49d7eSJed Brown #[derive(Clone, Copy)]
31*9df49d7eSJed Brown pub enum QFunctionOpt<'a> {
32*9df49d7eSJed Brown     SomeQFunction(&'a QFunction<'a>),
33*9df49d7eSJed Brown     SomeQFunctionByName(&'a QFunctionByName<'a>),
34*9df49d7eSJed Brown     None,
35*9df49d7eSJed Brown }
36*9df49d7eSJed Brown 
37*9df49d7eSJed Brown /// Construct a QFunctionOpt reference from a QFunction reference
38*9df49d7eSJed Brown impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
39*9df49d7eSJed Brown     fn from(qfunc: &'a QFunction) -> Self {
40*9df49d7eSJed Brown         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
41*9df49d7eSJed Brown         Self::SomeQFunction(qfunc)
42*9df49d7eSJed Brown     }
43*9df49d7eSJed Brown }
44*9df49d7eSJed Brown 
45*9df49d7eSJed Brown /// Construct a QFunctionOpt reference from a QFunction by Name reference
46*9df49d7eSJed Brown impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
47*9df49d7eSJed Brown     fn from(qfunc: &'a QFunctionByName) -> Self {
48*9df49d7eSJed Brown         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
49*9df49d7eSJed Brown         Self::SomeQFunctionByName(qfunc)
50*9df49d7eSJed Brown     }
51*9df49d7eSJed Brown }
52*9df49d7eSJed Brown 
53*9df49d7eSJed Brown impl<'a> QFunctionOpt<'a> {
54*9df49d7eSJed Brown     /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction
55*9df49d7eSJed Brown     pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction {
56*9df49d7eSJed Brown         match self {
57*9df49d7eSJed Brown             Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr,
58*9df49d7eSJed Brown             Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr,
59*9df49d7eSJed Brown             Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE },
60*9df49d7eSJed Brown         }
61*9df49d7eSJed Brown     }
62*9df49d7eSJed Brown }
63*9df49d7eSJed Brown 
64*9df49d7eSJed Brown // -----------------------------------------------------------------------------
65*9df49d7eSJed Brown // CeedQFunction context wrapper
66*9df49d7eSJed Brown // -----------------------------------------------------------------------------
67*9df49d7eSJed Brown pub(crate) struct QFunctionCore<'a> {
68*9df49d7eSJed Brown     ceed: &'a crate::Ceed,
69*9df49d7eSJed Brown     ptr: bind_ceed::CeedQFunction,
70*9df49d7eSJed Brown }
71*9df49d7eSJed Brown 
72*9df49d7eSJed Brown struct QFunctionTrampolineData {
73*9df49d7eSJed Brown     number_inputs: usize,
74*9df49d7eSJed Brown     number_outputs: usize,
75*9df49d7eSJed Brown     input_sizes: [usize; MAX_QFUNCTION_FIELDS],
76*9df49d7eSJed Brown     output_sizes: [usize; MAX_QFUNCTION_FIELDS],
77*9df49d7eSJed Brown     user_f: Box<QFunctionUserClosure>,
78*9df49d7eSJed Brown }
79*9df49d7eSJed Brown 
80*9df49d7eSJed Brown pub struct QFunction<'a> {
81*9df49d7eSJed Brown     qf_core: QFunctionCore<'a>,
82*9df49d7eSJed Brown     qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
83*9df49d7eSJed Brown     trampoline_data: Pin<Box<QFunctionTrampolineData>>,
84*9df49d7eSJed Brown }
85*9df49d7eSJed Brown 
86*9df49d7eSJed Brown pub struct QFunctionByName<'a> {
87*9df49d7eSJed Brown     qf_core: QFunctionCore<'a>,
88*9df49d7eSJed Brown }
89*9df49d7eSJed Brown 
90*9df49d7eSJed Brown // -----------------------------------------------------------------------------
91*9df49d7eSJed Brown // Destructor
92*9df49d7eSJed Brown // -----------------------------------------------------------------------------
93*9df49d7eSJed Brown impl<'a> Drop for QFunctionCore<'a> {
94*9df49d7eSJed Brown     fn drop(&mut self) {
95*9df49d7eSJed Brown         unsafe {
96*9df49d7eSJed Brown             if self.ptr != bind_ceed::CEED_QFUNCTION_NONE {
97*9df49d7eSJed Brown                 bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
98*9df49d7eSJed Brown             }
99*9df49d7eSJed Brown         }
100*9df49d7eSJed Brown     }
101*9df49d7eSJed Brown }
102*9df49d7eSJed Brown 
103*9df49d7eSJed Brown impl<'a> Drop for QFunction<'a> {
104*9df49d7eSJed Brown     fn drop(&mut self) {
105*9df49d7eSJed Brown         unsafe {
106*9df49d7eSJed Brown             bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr);
107*9df49d7eSJed Brown         }
108*9df49d7eSJed Brown     }
109*9df49d7eSJed Brown }
110*9df49d7eSJed Brown 
111*9df49d7eSJed Brown // -----------------------------------------------------------------------------
112*9df49d7eSJed Brown // Display
113*9df49d7eSJed Brown // -----------------------------------------------------------------------------
114*9df49d7eSJed Brown impl<'a> fmt::Display for QFunctionCore<'a> {
115*9df49d7eSJed Brown     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116*9df49d7eSJed Brown         let mut ptr = std::ptr::null_mut();
117*9df49d7eSJed Brown         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
118*9df49d7eSJed Brown         let cstring = unsafe {
119*9df49d7eSJed Brown             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
120*9df49d7eSJed Brown             bind_ceed::CeedQFunctionView(self.ptr, file);
121*9df49d7eSJed Brown             bind_ceed::fclose(file);
122*9df49d7eSJed Brown             CString::from_raw(ptr)
123*9df49d7eSJed Brown         };
124*9df49d7eSJed Brown         cstring.to_string_lossy().fmt(f)
125*9df49d7eSJed Brown     }
126*9df49d7eSJed Brown }
127*9df49d7eSJed Brown /// View a QFunction
128*9df49d7eSJed Brown ///
129*9df49d7eSJed Brown /// ```
130*9df49d7eSJed Brown /// # use libceed::prelude::*;
131*9df49d7eSJed Brown /// # let ceed = libceed::Ceed::default_init();
132*9df49d7eSJed Brown /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
133*9df49d7eSJed Brown ///     // Iterate over quadrature points
134*9df49d7eSJed Brown ///     v.iter_mut()
135*9df49d7eSJed Brown ///         .zip(u.iter().zip(weights.iter()))
136*9df49d7eSJed Brown ///         .for_each(|(v, (u, w))| *v = u * w);
137*9df49d7eSJed Brown ///
138*9df49d7eSJed Brown ///     // Return clean error code
139*9df49d7eSJed Brown ///     0
140*9df49d7eSJed Brown /// };
141*9df49d7eSJed Brown ///
142*9df49d7eSJed Brown /// let qf = ceed
143*9df49d7eSJed Brown ///     .q_function_interior(1, Box::new(user_f))
144*9df49d7eSJed Brown ///     .unwrap()
145*9df49d7eSJed Brown ///     .input("u", 1, EvalMode::Interp)
146*9df49d7eSJed Brown ///     .unwrap()
147*9df49d7eSJed Brown ///     .input("weights", 1, EvalMode::Weight)
148*9df49d7eSJed Brown ///     .unwrap()
149*9df49d7eSJed Brown ///     .output("v", 1, EvalMode::Interp)
150*9df49d7eSJed Brown ///     .unwrap();
151*9df49d7eSJed Brown ///
152*9df49d7eSJed Brown /// println!("{}", qf);
153*9df49d7eSJed Brown /// ```
154*9df49d7eSJed Brown impl<'a> fmt::Display for QFunction<'a> {
155*9df49d7eSJed Brown     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156*9df49d7eSJed Brown         self.qf_core.fmt(f)
157*9df49d7eSJed Brown     }
158*9df49d7eSJed Brown }
159*9df49d7eSJed Brown 
160*9df49d7eSJed Brown /// View a QFunction by Name
161*9df49d7eSJed Brown ///
162*9df49d7eSJed Brown /// ```
163*9df49d7eSJed Brown /// # use libceed::prelude::*;
164*9df49d7eSJed Brown /// # let ceed = libceed::Ceed::default_init();
165*9df49d7eSJed Brown /// let qf = ceed.q_function_interior_by_name("Mass1DBuild").unwrap();
166*9df49d7eSJed Brown /// println!("{}", qf);
167*9df49d7eSJed Brown /// ```
168*9df49d7eSJed Brown impl<'a> fmt::Display for QFunctionByName<'a> {
169*9df49d7eSJed Brown     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
170*9df49d7eSJed Brown         self.qf_core.fmt(f)
171*9df49d7eSJed Brown     }
172*9df49d7eSJed Brown }
173*9df49d7eSJed Brown 
174*9df49d7eSJed Brown // -----------------------------------------------------------------------------
175*9df49d7eSJed Brown // Core functionality
176*9df49d7eSJed Brown // -----------------------------------------------------------------------------
177*9df49d7eSJed Brown impl<'a> QFunctionCore<'a> {
178*9df49d7eSJed Brown     // Common implementation
179*9df49d7eSJed Brown     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
180*9df49d7eSJed Brown         let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
181*9df49d7eSJed Brown         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) {
182*9df49d7eSJed Brown             u_c[i] = u[i].ptr;
183*9df49d7eSJed Brown         }
184*9df49d7eSJed Brown         let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
185*9df49d7eSJed Brown         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) {
186*9df49d7eSJed Brown             v_c[i] = v[i].ptr;
187*9df49d7eSJed Brown         }
188*9df49d7eSJed Brown         let Q = i32::try_from(Q).unwrap();
189*9df49d7eSJed Brown         let ierr = unsafe {
190*9df49d7eSJed Brown             bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
191*9df49d7eSJed Brown         };
192*9df49d7eSJed Brown         self.ceed.check_error(ierr)
193*9df49d7eSJed Brown     }
194*9df49d7eSJed Brown }
195*9df49d7eSJed Brown 
196*9df49d7eSJed Brown // -----------------------------------------------------------------------------
197*9df49d7eSJed Brown // User QFunction Closure
198*9df49d7eSJed Brown // -----------------------------------------------------------------------------
199*9df49d7eSJed Brown pub type QFunctionUserClosure =
200*9df49d7eSJed Brown     dyn FnMut([&[f64]; MAX_QFUNCTION_FIELDS], [&mut [f64]; MAX_QFUNCTION_FIELDS]) -> i32;
201*9df49d7eSJed Brown 
202*9df49d7eSJed Brown macro_rules! mut_max_fields {
203*9df49d7eSJed Brown     ($e:expr) => {
204*9df49d7eSJed Brown         [
205*9df49d7eSJed Brown             $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e,
206*9df49d7eSJed Brown         ]
207*9df49d7eSJed Brown     };
208*9df49d7eSJed Brown }
209*9df49d7eSJed Brown unsafe extern "C" fn trampoline(
210*9df49d7eSJed Brown     ctx: *mut ::std::os::raw::c_void,
211*9df49d7eSJed Brown     q: bind_ceed::CeedInt,
212*9df49d7eSJed Brown     inputs: *const *const bind_ceed::CeedScalar,
213*9df49d7eSJed Brown     outputs: *const *mut bind_ceed::CeedScalar,
214*9df49d7eSJed Brown ) -> ::std::os::raw::c_int {
215*9df49d7eSJed Brown     let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx);
216*9df49d7eSJed Brown 
217*9df49d7eSJed Brown     // Inputs
218*9df49d7eSJed Brown     let inputs_slice: &[*const bind_ceed::CeedScalar] =
219*9df49d7eSJed Brown         std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS);
220*9df49d7eSJed Brown     let mut inputs_array: [&[f64]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS];
221*9df49d7eSJed Brown     inputs_slice
222*9df49d7eSJed Brown         .iter()
223*9df49d7eSJed Brown         .enumerate()
224*9df49d7eSJed Brown         .map(|(i, &x)| {
225*9df49d7eSJed Brown             std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) as &[f64]
226*9df49d7eSJed Brown         })
227*9df49d7eSJed Brown         .zip(inputs_array.iter_mut())
228*9df49d7eSJed Brown         .for_each(|(x, a)| *a = x);
229*9df49d7eSJed Brown 
230*9df49d7eSJed Brown     // Outputs
231*9df49d7eSJed Brown     let outputs_slice: &[*mut bind_ceed::CeedScalar] =
232*9df49d7eSJed Brown         std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS);
233*9df49d7eSJed Brown     let mut outputs_array: [&mut [f64]; MAX_QFUNCTION_FIELDS] = mut_max_fields!(&mut [0.0]);
234*9df49d7eSJed Brown     outputs_slice
235*9df49d7eSJed Brown         .iter()
236*9df49d7eSJed Brown         .enumerate()
237*9df49d7eSJed Brown         .map(|(i, &x)| {
238*9df49d7eSJed Brown             std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize)
239*9df49d7eSJed Brown                 as &mut [f64]
240*9df49d7eSJed Brown         })
241*9df49d7eSJed Brown         .zip(outputs_array.iter_mut())
242*9df49d7eSJed Brown         .for_each(|(x, a)| *a = x);
243*9df49d7eSJed Brown 
244*9df49d7eSJed Brown     // User closure
245*9df49d7eSJed Brown     (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array)
246*9df49d7eSJed Brown }
247*9df49d7eSJed Brown 
248*9df49d7eSJed Brown // -----------------------------------------------------------------------------
249*9df49d7eSJed Brown // QFunction
250*9df49d7eSJed Brown // -----------------------------------------------------------------------------
251*9df49d7eSJed Brown impl<'a> QFunction<'a> {
252*9df49d7eSJed Brown     // Constructor
253*9df49d7eSJed Brown     pub fn create(
254*9df49d7eSJed Brown         ceed: &'a crate::Ceed,
255*9df49d7eSJed Brown         vlength: usize,
256*9df49d7eSJed Brown         user_f: Box<QFunctionUserClosure>,
257*9df49d7eSJed Brown     ) -> crate::Result<Self> {
258*9df49d7eSJed Brown         let source_c = CString::new("").expect("CString::new failed");
259*9df49d7eSJed Brown         let mut ptr = std::ptr::null_mut();
260*9df49d7eSJed Brown 
261*9df49d7eSJed Brown         // Context for closure
262*9df49d7eSJed Brown         let number_inputs = 0;
263*9df49d7eSJed Brown         let number_outputs = 0;
264*9df49d7eSJed Brown         let input_sizes = [0; MAX_QFUNCTION_FIELDS];
265*9df49d7eSJed Brown         let output_sizes = [0; MAX_QFUNCTION_FIELDS];
266*9df49d7eSJed Brown         let trampoline_data = unsafe {
267*9df49d7eSJed Brown             Pin::new_unchecked(Box::new(QFunctionTrampolineData {
268*9df49d7eSJed Brown                 number_inputs,
269*9df49d7eSJed Brown                 number_outputs,
270*9df49d7eSJed Brown                 input_sizes,
271*9df49d7eSJed Brown                 output_sizes,
272*9df49d7eSJed Brown                 user_f,
273*9df49d7eSJed Brown             }))
274*9df49d7eSJed Brown         };
275*9df49d7eSJed Brown 
276*9df49d7eSJed Brown         // Create QFunction
277*9df49d7eSJed Brown         let vlength = i32::try_from(vlength).unwrap();
278*9df49d7eSJed Brown         let mut ierr = unsafe {
279*9df49d7eSJed Brown             bind_ceed::CeedQFunctionCreateInterior(
280*9df49d7eSJed Brown                 ceed.ptr,
281*9df49d7eSJed Brown                 vlength,
282*9df49d7eSJed Brown                 Some(trampoline),
283*9df49d7eSJed Brown                 source_c.as_ptr(),
284*9df49d7eSJed Brown                 &mut ptr,
285*9df49d7eSJed Brown             )
286*9df49d7eSJed Brown         };
287*9df49d7eSJed Brown         ceed.check_error(ierr)?;
288*9df49d7eSJed Brown 
289*9df49d7eSJed Brown         // Set closure
290*9df49d7eSJed Brown         let mut qf_ctx_ptr = std::ptr::null_mut();
291*9df49d7eSJed Brown         ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) };
292*9df49d7eSJed Brown         ceed.check_error(ierr)?;
293*9df49d7eSJed Brown         ierr = unsafe {
294*9df49d7eSJed Brown             bind_ceed::CeedQFunctionContextSetData(
295*9df49d7eSJed Brown                 qf_ctx_ptr,
296*9df49d7eSJed Brown                 crate::MemType::Host as bind_ceed::CeedMemType,
297*9df49d7eSJed Brown                 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
298*9df49d7eSJed Brown                 std::mem::size_of::<QFunctionTrampolineData>() as u64,
299*9df49d7eSJed Brown                 std::mem::transmute(trampoline_data.as_ref()),
300*9df49d7eSJed Brown             )
301*9df49d7eSJed Brown         };
302*9df49d7eSJed Brown         ceed.check_error(ierr)?;
303*9df49d7eSJed Brown         ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) };
304*9df49d7eSJed Brown         ceed.check_error(ierr)?;
305*9df49d7eSJed Brown         Ok(Self {
306*9df49d7eSJed Brown             qf_core: QFunctionCore { ceed, ptr },
307*9df49d7eSJed Brown             qf_ctx_ptr,
308*9df49d7eSJed Brown             trampoline_data,
309*9df49d7eSJed Brown         })
310*9df49d7eSJed Brown     }
311*9df49d7eSJed Brown 
312*9df49d7eSJed Brown     /// Apply the action of a QFunction
313*9df49d7eSJed Brown     ///
314*9df49d7eSJed Brown     /// * `Q`      - The number of quadrature points
315*9df49d7eSJed Brown     /// * `input`  - Array of input Vectors
316*9df49d7eSJed Brown     /// * `output` - Array of output Vectors
317*9df49d7eSJed Brown     ///
318*9df49d7eSJed Brown     /// ```
319*9df49d7eSJed Brown     /// # use libceed::prelude::*;
320*9df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
321*9df49d7eSJed Brown     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
322*9df49d7eSJed Brown     ///     // Iterate over quadrature points
323*9df49d7eSJed Brown     ///     v.iter_mut()
324*9df49d7eSJed Brown     ///         .zip(u.iter().zip(weights.iter()))
325*9df49d7eSJed Brown     ///         .for_each(|(v, (u, w))| *v = u * w);
326*9df49d7eSJed Brown     ///
327*9df49d7eSJed Brown     ///     // Return clean error code
328*9df49d7eSJed Brown     ///     0
329*9df49d7eSJed Brown     /// };
330*9df49d7eSJed Brown     ///
331*9df49d7eSJed Brown     /// let qf = ceed
332*9df49d7eSJed Brown     ///     .q_function_interior(1, Box::new(user_f))
333*9df49d7eSJed Brown     ///     .unwrap()
334*9df49d7eSJed Brown     ///     .input("u", 1, EvalMode::Interp)
335*9df49d7eSJed Brown     ///     .unwrap()
336*9df49d7eSJed Brown     ///     .input("weights", 1, EvalMode::Weight)
337*9df49d7eSJed Brown     ///     .unwrap()
338*9df49d7eSJed Brown     ///     .output("v", 1, EvalMode::Interp)
339*9df49d7eSJed Brown     ///     .unwrap();
340*9df49d7eSJed Brown     ///
341*9df49d7eSJed Brown     /// const Q: usize = 8;
342*9df49d7eSJed Brown     /// let mut w = [0.; Q];
343*9df49d7eSJed Brown     /// let mut u = [0.; Q];
344*9df49d7eSJed Brown     /// let mut v = [0.; Q];
345*9df49d7eSJed Brown     ///
346*9df49d7eSJed Brown     /// for i in 0..Q {
347*9df49d7eSJed Brown     ///     let x = 2. * (i as f64) / ((Q as f64) - 1.) - 1.;
348*9df49d7eSJed Brown     ///     u[i] = 2. + 3. * x + 5. * x * x;
349*9df49d7eSJed Brown     ///     w[i] = 1. - x * x;
350*9df49d7eSJed Brown     ///     v[i] = u[i] * w[i];
351*9df49d7eSJed Brown     /// }
352*9df49d7eSJed Brown     ///
353*9df49d7eSJed Brown     /// let uu = ceed.vector_from_slice(&u).unwrap();
354*9df49d7eSJed Brown     /// let ww = ceed.vector_from_slice(&w).unwrap();
355*9df49d7eSJed Brown     /// let mut vv = ceed.vector(Q).unwrap();
356*9df49d7eSJed Brown     /// vv.set_value(0.0);
357*9df49d7eSJed Brown     /// {
358*9df49d7eSJed Brown     ///     let input = vec![uu, ww];
359*9df49d7eSJed Brown     ///     let mut output = vec![vv];
360*9df49d7eSJed Brown     ///     qf.apply(Q, &input, &output).unwrap();
361*9df49d7eSJed Brown     ///     vv = output.remove(0);
362*9df49d7eSJed Brown     /// }
363*9df49d7eSJed Brown     ///
364*9df49d7eSJed Brown     /// vv.view()
365*9df49d7eSJed Brown     ///     .iter()
366*9df49d7eSJed Brown     ///     .zip(v.iter())
367*9df49d7eSJed Brown     ///     .for_each(|(computed, actual)| {
368*9df49d7eSJed Brown     ///         assert_eq!(
369*9df49d7eSJed Brown     ///             *computed, *actual,
370*9df49d7eSJed Brown     ///             "Incorrect value in QFunction application"
371*9df49d7eSJed Brown     ///         );
372*9df49d7eSJed Brown     ///     });
373*9df49d7eSJed Brown     /// ```
374*9df49d7eSJed Brown     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
375*9df49d7eSJed Brown         self.qf_core.apply(Q, u, v)
376*9df49d7eSJed Brown     }
377*9df49d7eSJed Brown 
378*9df49d7eSJed Brown     /// Add a QFunction input
379*9df49d7eSJed Brown     ///
380*9df49d7eSJed Brown     /// * `fieldname` - Name of QFunction field
381*9df49d7eSJed Brown     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
382*9df49d7eSJed Brown     ///                   `(ncomp * 1)` for `None`, `Interp`, and `Weight`
383*9df49d7eSJed Brown     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
384*9df49d7eSJed Brown     ///                   to use interpolated values, `EvalMode::Grad` to use
385*9df49d7eSJed Brown     ///                   gradients, `EvalMode::Weight` to use quadrature weights
386*9df49d7eSJed Brown     ///
387*9df49d7eSJed Brown     /// ```
388*9df49d7eSJed Brown     /// # use libceed::prelude::*;
389*9df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
390*9df49d7eSJed Brown     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
391*9df49d7eSJed Brown     ///     // Iterate over quadrature points
392*9df49d7eSJed Brown     ///     v.iter_mut()
393*9df49d7eSJed Brown     ///         .zip(u.iter().zip(weights.iter()))
394*9df49d7eSJed Brown     ///         .for_each(|(v, (u, w))| *v = u * w);
395*9df49d7eSJed Brown     ///
396*9df49d7eSJed Brown     ///     // Return clean error code
397*9df49d7eSJed Brown     ///     0
398*9df49d7eSJed Brown     /// };
399*9df49d7eSJed Brown     ///
400*9df49d7eSJed Brown     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap();
401*9df49d7eSJed Brown     ///
402*9df49d7eSJed Brown     /// qf = qf.input("u", 1, EvalMode::Interp).unwrap();
403*9df49d7eSJed Brown     /// qf = qf.input("weights", 1, EvalMode::Weight).unwrap();
404*9df49d7eSJed Brown     /// ```
405*9df49d7eSJed Brown     pub fn input(
406*9df49d7eSJed Brown         mut self,
407*9df49d7eSJed Brown         fieldname: &str,
408*9df49d7eSJed Brown         size: usize,
409*9df49d7eSJed Brown         emode: crate::EvalMode,
410*9df49d7eSJed Brown     ) -> crate::Result<Self> {
411*9df49d7eSJed Brown         let name_c = CString::new(fieldname).expect("CString::new failed");
412*9df49d7eSJed Brown         let idx = self.trampoline_data.number_inputs;
413*9df49d7eSJed Brown         self.trampoline_data.input_sizes[idx] = size;
414*9df49d7eSJed Brown         self.trampoline_data.number_inputs += 1;
415*9df49d7eSJed Brown         let (size, emode) = (
416*9df49d7eSJed Brown             i32::try_from(size).unwrap(),
417*9df49d7eSJed Brown             emode as bind_ceed::CeedEvalMode,
418*9df49d7eSJed Brown         );
419*9df49d7eSJed Brown         let ierr = unsafe {
420*9df49d7eSJed Brown             bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
421*9df49d7eSJed Brown         };
422*9df49d7eSJed Brown         self.qf_core.ceed.check_error(ierr)?;
423*9df49d7eSJed Brown         Ok(self)
424*9df49d7eSJed Brown     }
425*9df49d7eSJed Brown 
426*9df49d7eSJed Brown     /// Add a QFunction output
427*9df49d7eSJed Brown     ///
428*9df49d7eSJed Brown     /// * `fieldname` - Name of QFunction field
429*9df49d7eSJed Brown     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
430*9df49d7eSJed Brown     ///                   `(ncomp * 1)` for `None` and `Interp`
431*9df49d7eSJed Brown     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
432*9df49d7eSJed Brown     ///                   to use interpolated values, `EvalMode::Grad` to use
433*9df49d7eSJed Brown     ///                   gradients
434*9df49d7eSJed Brown     ///
435*9df49d7eSJed Brown     /// ```
436*9df49d7eSJed Brown     /// # use libceed::prelude::*;
437*9df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
438*9df49d7eSJed Brown     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
439*9df49d7eSJed Brown     ///     // Iterate over quadrature points
440*9df49d7eSJed Brown     ///     v.iter_mut()
441*9df49d7eSJed Brown     ///         .zip(u.iter().zip(weights.iter()))
442*9df49d7eSJed Brown     ///         .for_each(|(v, (u, w))| *v = u * w);
443*9df49d7eSJed Brown     ///
444*9df49d7eSJed Brown     ///     // Return clean error code
445*9df49d7eSJed Brown     ///     0
446*9df49d7eSJed Brown     /// };
447*9df49d7eSJed Brown     ///
448*9df49d7eSJed Brown     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap();
449*9df49d7eSJed Brown     ///
450*9df49d7eSJed Brown     /// qf.output("v", 1, EvalMode::Interp).unwrap();
451*9df49d7eSJed Brown     /// ```
452*9df49d7eSJed Brown     pub fn output(
453*9df49d7eSJed Brown         mut self,
454*9df49d7eSJed Brown         fieldname: &str,
455*9df49d7eSJed Brown         size: usize,
456*9df49d7eSJed Brown         emode: crate::EvalMode,
457*9df49d7eSJed Brown     ) -> crate::Result<Self> {
458*9df49d7eSJed Brown         let name_c = CString::new(fieldname).expect("CString::new failed");
459*9df49d7eSJed Brown         let idx = self.trampoline_data.number_outputs;
460*9df49d7eSJed Brown         self.trampoline_data.output_sizes[idx] = size;
461*9df49d7eSJed Brown         self.trampoline_data.number_outputs += 1;
462*9df49d7eSJed Brown         let (size, emode) = (
463*9df49d7eSJed Brown             i32::try_from(size).unwrap(),
464*9df49d7eSJed Brown             emode as bind_ceed::CeedEvalMode,
465*9df49d7eSJed Brown         );
466*9df49d7eSJed Brown         let ierr = unsafe {
467*9df49d7eSJed Brown             bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
468*9df49d7eSJed Brown         };
469*9df49d7eSJed Brown         self.qf_core.ceed.check_error(ierr)?;
470*9df49d7eSJed Brown         Ok(self)
471*9df49d7eSJed Brown     }
472*9df49d7eSJed Brown }
473*9df49d7eSJed Brown 
474*9df49d7eSJed Brown // -----------------------------------------------------------------------------
// QFunctionByName
476*9df49d7eSJed Brown // -----------------------------------------------------------------------------
477*9df49d7eSJed Brown impl<'a> QFunctionByName<'a> {
478*9df49d7eSJed Brown     // Constructor
479*9df49d7eSJed Brown     pub fn create(ceed: &'a crate::Ceed, name: &str) -> crate::Result<Self> {
480*9df49d7eSJed Brown         let name_c = CString::new(name).expect("CString::new failed");
481*9df49d7eSJed Brown         let mut ptr = std::ptr::null_mut();
482*9df49d7eSJed Brown         let ierr = unsafe {
483*9df49d7eSJed Brown             bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
484*9df49d7eSJed Brown         };
485*9df49d7eSJed Brown         ceed.check_error(ierr)?;
486*9df49d7eSJed Brown         Ok(Self {
487*9df49d7eSJed Brown             qf_core: QFunctionCore { ceed, ptr },
488*9df49d7eSJed Brown         })
489*9df49d7eSJed Brown     }
490*9df49d7eSJed Brown 
491*9df49d7eSJed Brown     /// Apply the action of a QFunction
492*9df49d7eSJed Brown     ///
493*9df49d7eSJed Brown     /// * `Q`      - The number of quadrature points
494*9df49d7eSJed Brown     /// * `input`  - Array of input Vectors
495*9df49d7eSJed Brown     /// * `output` - Array of output Vectors
496*9df49d7eSJed Brown     ///
497*9df49d7eSJed Brown     /// ```
498*9df49d7eSJed Brown     /// # use libceed::prelude::*;
499*9df49d7eSJed Brown     /// # let ceed = libceed::Ceed::default_init();
500*9df49d7eSJed Brown     /// const Q: usize = 8;
501*9df49d7eSJed Brown     /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild").unwrap();
502*9df49d7eSJed Brown     /// let qf_mass = ceed.q_function_interior_by_name("MassApply").unwrap();
503*9df49d7eSJed Brown     ///
504*9df49d7eSJed Brown     /// let mut j = [0.; Q];
505*9df49d7eSJed Brown     /// let mut w = [0.; Q];
506*9df49d7eSJed Brown     /// let mut u = [0.; Q];
507*9df49d7eSJed Brown     /// let mut v = [0.; Q];
508*9df49d7eSJed Brown     ///
509*9df49d7eSJed Brown     /// for i in 0..Q {
510*9df49d7eSJed Brown     ///     let x = 2. * (i as f64) / ((Q as f64) - 1.) - 1.;
511*9df49d7eSJed Brown     ///     j[i] = 1.;
512*9df49d7eSJed Brown     ///     w[i] = 1. - x * x;
513*9df49d7eSJed Brown     ///     u[i] = 2. + 3. * x + 5. * x * x;
514*9df49d7eSJed Brown     ///     v[i] = w[i] * u[i];
515*9df49d7eSJed Brown     /// }
516*9df49d7eSJed Brown     ///
517*9df49d7eSJed Brown     /// let jj = ceed.vector_from_slice(&j).unwrap();
518*9df49d7eSJed Brown     /// let ww = ceed.vector_from_slice(&w).unwrap();
519*9df49d7eSJed Brown     /// let uu = ceed.vector_from_slice(&u).unwrap();
520*9df49d7eSJed Brown     /// let mut vv = ceed.vector(Q).unwrap();
521*9df49d7eSJed Brown     /// vv.set_value(0.0);
522*9df49d7eSJed Brown     /// let mut qdata = ceed.vector(Q).unwrap();
523*9df49d7eSJed Brown     /// qdata.set_value(0.0);
524*9df49d7eSJed Brown     ///
525*9df49d7eSJed Brown     /// {
526*9df49d7eSJed Brown     ///     let mut input = vec![jj, ww];
527*9df49d7eSJed Brown     ///     let mut output = vec![qdata];
528*9df49d7eSJed Brown     ///     qf_build.apply(Q, &input, &output).unwrap();
529*9df49d7eSJed Brown     ///     qdata = output.remove(0);
530*9df49d7eSJed Brown     /// }
531*9df49d7eSJed Brown     ///
532*9df49d7eSJed Brown     /// {
533*9df49d7eSJed Brown     ///     let mut input = vec![qdata, uu];
534*9df49d7eSJed Brown     ///     let mut output = vec![vv];
535*9df49d7eSJed Brown     ///     qf_mass.apply(Q, &input, &output).unwrap();
536*9df49d7eSJed Brown     ///     vv = output.remove(0);
537*9df49d7eSJed Brown     /// }
538*9df49d7eSJed Brown     ///
539*9df49d7eSJed Brown     /// vv.view()
540*9df49d7eSJed Brown     ///     .iter()
541*9df49d7eSJed Brown     ///     .zip(v.iter())
542*9df49d7eSJed Brown     ///     .for_each(|(computed, actual)| {
543*9df49d7eSJed Brown     ///         assert_eq!(
544*9df49d7eSJed Brown     ///             *computed, *actual,
545*9df49d7eSJed Brown     ///             "Incorrect value in QFunction application"
546*9df49d7eSJed Brown     ///         );
547*9df49d7eSJed Brown     ///     });
548*9df49d7eSJed Brown     /// ```
549*9df49d7eSJed Brown     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
550*9df49d7eSJed Brown         self.qf_core.apply(Q, u, v)
551*9df49d7eSJed Brown     }
552*9df49d7eSJed Brown }
553*9df49d7eSJed Brown 
554*9df49d7eSJed Brown // -----------------------------------------------------------------------------
555