xref: /libCEED/rust/libceed/src/qfunction.rs (revision 78a9fcb6bae7d7246469592ed2b1811dbe8e744f)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 //! A Ceed QFunction represents the spatial terms of the point-wise functions
9 //! describing the physics at the quadrature points.
10 
11 use std::pin::Pin;
12 
13 use crate::prelude::*;
14 
/// Read-only argument handed to a user QFunction closure: a fixed-length
/// array of `MAX_QFUNCTION_FIELDS` scalar slices, one slot per input field.
pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS];
/// Writable argument handed to a user QFunction closure: a fixed-length
/// array of `MAX_QFUNCTION_FIELDS` mutable scalar slices, one slot per
/// output field.
pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS];
17 
18 // -----------------------------------------------------------------------------
19 // QFunction Field context wrapper
20 // -----------------------------------------------------------------------------
/// Thin wrapper around a raw libCEED `CeedQFunctionField` handle.
///
/// The struct has no destructor in this file; the `'a` lifetime exists only
/// through the zero-sized `PhantomData` marker.
#[derive(Debug)]
pub struct QFunctionField<'a> {
    // Raw C handle queried by `name`, `size` and `eval_mode`
    ptr: bind_ceed::CeedQFunctionField,
    // Zero-sized marker carrying the `'a` lifetime
    _lifeline: PhantomData<&'a ()>,
}
26 
27 // -----------------------------------------------------------------------------
28 // Implementations
29 // -----------------------------------------------------------------------------
30 impl<'a> QFunctionField<'a> {
31     /// Get the name of a QFunctionField
32     ///
33     /// ```
34     /// # use libceed::prelude::*;
35     /// # fn main() -> libceed::Result<()> {
36     /// # let ceed = libceed::Ceed::default_init();
37     /// const Q: usize = 8;
38     /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?;
39     ///
40     /// let inputs = qf.inputs()?;
41     ///
42     /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name");
43     /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name");
44     /// # Ok(())
45     /// # }
46     /// ```
47     pub fn name(&self) -> &str {
48         let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut();
49         unsafe {
50             bind_ceed::CeedQFunctionFieldGetName(self.ptr, &mut name_ptr);
51         }
52         unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap()
53     }
54 
55     /// Get the size of a QFunctionField
56     ///
57     /// ```
58     /// # use libceed::prelude::*;
59     /// # fn main() -> libceed::Result<()> {
60     /// # let ceed = libceed::Ceed::default_init();
61     /// const Q: usize = 8;
62     /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?;
63     ///
64     /// let inputs = qf.inputs()?;
65     ///
66     /// assert_eq!(inputs[0].size(), 4, "Incorrect input size");
67     /// assert_eq!(inputs[1].size(), 1, "Incorrect input size");
68     /// # Ok(())
69     /// # }
70     /// ```
71     pub fn size(&self) -> usize {
72         let mut size = 0;
73         unsafe {
74             bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size);
75         }
76         usize::try_from(size).unwrap()
77     }
78 
79     /// Get the evaluation mode of a QFunctionField
80     ///
81     /// ```
82     /// # use libceed::prelude::*;
83     /// # fn main() -> libceed::Result<()> {
84     /// # let ceed = libceed::Ceed::default_init();
85     /// const Q: usize = 8;
86     /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?;
87     ///
88     /// let inputs = qf.inputs()?;
89     ///
90     /// assert_eq!(
91     ///     inputs[0].eval_mode(),
92     ///     EvalMode::Grad,
93     ///     "Incorrect input evaluation mode"
94     /// );
95     /// assert_eq!(
96     ///     inputs[1].eval_mode(),
97     ///     EvalMode::Weight,
98     ///     "Incorrect input evaluation mode"
99     /// );
100     /// # Ok(())
101     /// # }
102     /// ```
103     pub fn eval_mode(&self) -> crate::EvalMode {
104         let mut mode = 0;
105         unsafe {
106             bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode);
107         }
108         crate::EvalMode::from_u32(mode as u32)
109     }
110 }
111 
112 // -----------------------------------------------------------------------------
113 // QFunction option
114 // -----------------------------------------------------------------------------
/// Optional, borrowed QFunction argument.
///
/// `to_raw` turns each variant into the raw C handle: the two `Some*`
/// variants yield the wrapped object's handle, `None` yields
/// `CEED_QFUNCTION_NONE`.
pub enum QFunctionOpt<'a> {
    /// A borrowed QFunction built from a user closure
    SomeQFunction(&'a QFunction<'a>),
    /// A borrowed QFunction created by name
    SomeQFunctionByName(&'a QFunctionByName<'a>),
    /// No QFunction
    None,
}
120 
121 /// Construct a QFunctionOpt reference from a QFunction reference
122 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
123     fn from(qfunc: &'a QFunction) -> Self {
124         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
125         Self::SomeQFunction(qfunc)
126     }
127 }
128 
129 /// Construct a QFunctionOpt reference from a QFunction by Name reference
130 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
131     fn from(qfunc: &'a QFunctionByName) -> Self {
132         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
133         Self::SomeQFunctionByName(qfunc)
134     }
135 }
136 
137 impl<'a> QFunctionOpt<'a> {
138     /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction
139     pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction {
140         match self {
141             Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr,
142             Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr,
143             Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE },
144         }
145     }
146 
147     /// Check if a QFunctionOpt is Some
148     ///
149     /// ```
150     /// # use libceed::prelude::*;
151     /// # fn main() -> libceed::Result<()> {
152     /// # let ceed = libceed::Ceed::default_init();
153     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
154     ///     // Iterate over quadrature points
155     ///     v.iter_mut()
156     ///         .zip(u.iter().zip(weights.iter()))
157     ///         .for_each(|(v, (u, w))| *v = u * w);
158     ///
159     ///     // Return clean error code
160     ///     0
161     /// };
162     ///
163     /// let qf = ceed
164     ///     .q_function_interior(1, Box::new(user_f))?
165     ///     .input("u", 1, EvalMode::Interp)?
166     ///     .input("weights", 1, EvalMode::Weight)?
167     ///     .output("v", 1, EvalMode::Interp)?;
168     /// let qf_opt = QFunctionOpt::from(&qf);
169     /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt");
170     ///
171     /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
172     /// let qf_opt = QFunctionOpt::from(&qf);
173     /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt");
174     ///
175     /// let qf_opt = QFunctionOpt::None;
176     /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt");
177     /// # Ok(())
178     /// # }
179     /// ```
180     pub fn is_some(&self) -> bool {
181         match self {
182             Self::SomeQFunction(_) => true,
183             Self::SomeQFunctionByName(_) => true,
184             Self::None => false,
185         }
186     }
187 
188     /// Check if a QFunctionOpt is SomeQFunction
189     ///
190     /// ```
191     /// # use libceed::prelude::*;
192     /// # fn main() -> libceed::Result<()> {
193     /// # let ceed = libceed::Ceed::default_init();
194     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
195     ///     // Iterate over quadrature points
196     ///     v.iter_mut()
197     ///         .zip(u.iter().zip(weights.iter()))
198     ///         .for_each(|(v, (u, w))| *v = u * w);
199     ///
200     ///     // Return clean error code
201     ///     0
202     /// };
203     ///
204     /// let qf = ceed
205     ///     .q_function_interior(1, Box::new(user_f))?
206     ///     .input("u", 1, EvalMode::Interp)?
207     ///     .input("weights", 1, EvalMode::Weight)?
208     ///     .output("v", 1, EvalMode::Interp)?;
209     /// let qf_opt = QFunctionOpt::from(&qf);
210     /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt");
211     ///
212     /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
213     /// let qf_opt = QFunctionOpt::from(&qf);
214     /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt");
215     ///
216     /// let qf_opt = QFunctionOpt::None;
217     /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt");
218     /// # Ok(())
219     /// # }
220     /// ```
221     pub fn is_some_q_function(&self) -> bool {
222         match self {
223             Self::SomeQFunction(_) => true,
224             Self::SomeQFunctionByName(_) => false,
225             Self::None => false,
226         }
227     }
228 
229     /// Check if a QFunctionOpt is SomeQFunctionByName
230     ///
231     /// ```
232     /// # use libceed::prelude::*;
233     /// # fn main() -> libceed::Result<()> {
234     /// # let ceed = libceed::Ceed::default_init();
235     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
236     ///     // Iterate over quadrature points
237     ///     v.iter_mut()
238     ///         .zip(u.iter().zip(weights.iter()))
239     ///         .for_each(|(v, (u, w))| *v = u * w);
240     ///
241     ///     // Return clean error code
242     ///     0
243     /// };
244     ///
245     /// let qf = ceed
246     ///     .q_function_interior(1, Box::new(user_f))?
247     ///     .input("u", 1, EvalMode::Interp)?
248     ///     .input("weights", 1, EvalMode::Weight)?
249     ///     .output("v", 1, EvalMode::Interp)?;
250     /// let qf_opt = QFunctionOpt::from(&qf);
251     /// assert!(
252     ///     !qf_opt.is_some_q_function_by_name(),
253     ///     "Incorrect QFunctionOpt"
254     /// );
255     ///
256     /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
257     /// let qf_opt = QFunctionOpt::from(&qf);
258     /// assert!(
259     ///     qf_opt.is_some_q_function_by_name(),
260     ///     "Incorrect QFunctionOpt"
261     /// );
262     ///
263     /// let qf_opt = QFunctionOpt::None;
264     /// assert!(
265     ///     !qf_opt.is_some_q_function_by_name(),
266     ///     "Incorrect QFunctionOpt"
267     /// );
268     /// # Ok(())
269     /// # }
270     /// ```
271     pub fn is_some_q_function_by_name(&self) -> bool {
272         match self {
273             Self::SomeQFunction(_) => false,
274             Self::SomeQFunctionByName(_) => true,
275             Self::None => false,
276         }
277     }
278 
279     /// Check if a QFunctionOpt is None
280     ///
281     /// ```
282     /// # use libceed::prelude::*;
283     /// # fn main() -> libceed::Result<()> {
284     /// # let ceed = libceed::Ceed::default_init();
285     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
286     ///     // Iterate over quadrature points
287     ///     v.iter_mut()
288     ///         .zip(u.iter().zip(weights.iter()))
289     ///         .for_each(|(v, (u, w))| *v = u * w);
290     ///
291     ///     // Return clean error code
292     ///     0
293     /// };
294     ///
295     /// let qf = ceed
296     ///     .q_function_interior(1, Box::new(user_f))?
297     ///     .input("u", 1, EvalMode::Interp)?
298     ///     .input("weights", 1, EvalMode::Weight)?
299     ///     .output("v", 1, EvalMode::Interp)?;
300     /// let qf_opt = QFunctionOpt::from(&qf);
301     /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt");
302     ///
303     /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
304     /// let qf_opt = QFunctionOpt::from(&qf);
305     /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt");
306     ///
307     /// let qf_opt = QFunctionOpt::None;
308     /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt");
309     /// # Ok(())
310     /// # }
311     /// ```
312     pub fn is_none(&self) -> bool {
313         match self {
314             Self::SomeQFunction(_) => false,
315             Self::SomeQFunctionByName(_) => false,
316             Self::None => true,
317         }
318     }
319 }
320 
321 // -----------------------------------------------------------------------------
322 // QFunction context wrapper
323 // -----------------------------------------------------------------------------
// Owner of a raw `CeedQFunction` handle, shared by `QFunction` and
// `QFunctionByName`. Its `Drop` impl destroys the handle unless it equals
// `CEED_QFUNCTION_NONE`.
#[derive(Debug)]
pub(crate) struct QFunctionCore<'a> {
    // Raw C handle
    ptr: bind_ceed::CeedQFunction,
    // Zero-sized marker carrying the `'a` lifetime
    _lifeline: PhantomData<&'a ()>,
}
329 
// State reachable from the C side through the QFunction context pointer;
// the `trampoline` callback reinterprets its `ctx` argument as this struct.
struct QFunctionTrampolineData {
    // Number of inputs registered so far via `QFunction::input`
    number_inputs: usize,
    // Number of outputs registered so far via `QFunction::output`
    number_outputs: usize,
    // Size recorded for each input field, indexed in registration order
    input_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // Size recorded for each output field, indexed in registration order
    output_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // User closure invoked by `trampoline`
    user_f: Box<QFunctionUserClosure>,
}
337 
/// A QFunction whose point-wise action is a user-supplied Rust closure.
pub struct QFunction<'a> {
    // Owns the raw `CeedQFunction` handle
    qf_core: QFunctionCore<'a>,
    // Raw QFunction context handle; destroyed in this type's `Drop` impl
    qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
    // Pinned heap allocation whose address is registered as the context data
    // in `create`; pinning keeps that address stable while the C side holds it
    trampoline_data: Pin<Box<QFunctionTrampolineData>>,
}
343 
/// A QFunction created by name through
/// `CeedQFunctionCreateInteriorByName`; it carries no Rust-side closure.
#[derive(Debug)]
pub struct QFunctionByName<'a> {
    // Owns the raw `CeedQFunction` handle
    qf_core: QFunctionCore<'a>,
}
348 
349 // -----------------------------------------------------------------------------
350 // Destructor
351 // -----------------------------------------------------------------------------
352 impl<'a> Drop for QFunctionCore<'a> {
353     fn drop(&mut self) {
354         unsafe {
355             if self.ptr != bind_ceed::CEED_QFUNCTION_NONE {
356                 bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
357             }
358         }
359     }
360 }
361 
362 impl<'a> Drop for QFunction<'a> {
363     fn drop(&mut self) {
364         unsafe {
365             bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr);
366         }
367     }
368 }
369 
370 // -----------------------------------------------------------------------------
371 // Display
372 // -----------------------------------------------------------------------------
373 impl<'a> fmt::Display for QFunctionCore<'a> {
374     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
375         let mut ptr = std::ptr::null_mut();
376         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
377         let cstring = unsafe {
378             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
379             bind_ceed::CeedQFunctionView(self.ptr, file);
380             bind_ceed::fclose(file);
381             CString::from_raw(ptr)
382         };
383         cstring.to_string_lossy().fmt(f)
384     }
385 }
386 /// View a QFunction
387 ///
388 /// ```
389 /// # use libceed::prelude::*;
390 /// # fn main() -> libceed::Result<()> {
391 /// # let ceed = libceed::Ceed::default_init();
392 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
393 ///     // Iterate over quadrature points
394 ///     v.iter_mut()
395 ///         .zip(u.iter().zip(weights.iter()))
396 ///         .for_each(|(v, (u, w))| *v = u * w);
397 ///
398 ///     // Return clean error code
399 ///     0
400 /// };
401 ///
402 /// let qf = ceed
403 ///     .q_function_interior(1, Box::new(user_f))?
404 ///     .input("u", 1, EvalMode::Interp)?
405 ///     .input("weights", 1, EvalMode::Weight)?
406 ///     .output("v", 1, EvalMode::Interp)?;
407 ///
408 /// println!("{}", qf);
409 /// # Ok(())
410 /// # }
411 /// ```
412 impl<'a> fmt::Display for QFunction<'a> {
413     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
414         self.qf_core.fmt(f)
415     }
416 }
417 
418 /// View a QFunction by Name
419 ///
420 /// ```
421 /// # use libceed::prelude::*;
422 /// # fn main() -> libceed::Result<()> {
423 /// # let ceed = libceed::Ceed::default_init();
424 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
425 /// println!("{}", qf);
426 /// # Ok(())
427 /// # }
428 /// ```
429 impl<'a> fmt::Display for QFunctionByName<'a> {
430     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
431         self.qf_core.fmt(f)
432     }
433 }
434 
435 // -----------------------------------------------------------------------------
436 // Core functionality
437 // -----------------------------------------------------------------------------
438 impl<'a> QFunctionCore<'a> {
439     // Error handling
440     #[doc(hidden)]
441     fn check_error(&self, ierr: i32) -> crate::Result<i32> {
442         let mut ptr = std::ptr::null_mut();
443         unsafe {
444             bind_ceed::CeedQFunctionGetCeed(self.ptr, &mut ptr);
445         }
446         crate::check_error(ptr, ierr)
447     }
448 
449     // Common implementation
450     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
451         let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
452         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) {
453             u_c[i] = u[i].ptr;
454         }
455         let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
456         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) {
457             v_c[i] = v[i].ptr;
458         }
459         let Q = i32::try_from(Q).unwrap();
460         let ierr = unsafe {
461             bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
462         };
463         self.check_error(ierr)
464     }
465 
466     pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> {
467         // Get array of raw C pointers for inputs
468         let mut num_inputs = 0;
469         let mut inputs_ptr = std::ptr::null_mut();
470         let ierr = unsafe {
471             bind_ceed::CeedQFunctionGetFields(
472                 self.ptr,
473                 &mut num_inputs,
474                 &mut inputs_ptr,
475                 std::ptr::null_mut() as *mut bind_ceed::CeedInt,
476                 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField,
477             )
478         };
479         self.check_error(ierr)?;
480         // Convert raw C pointers to fixed length slice
481         let inputs_slice = unsafe {
482             std::slice::from_raw_parts(
483                 inputs_ptr as *const crate::QFunctionField,
484                 num_inputs as usize,
485             )
486         };
487         Ok(inputs_slice)
488     }
489 
490     pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> {
491         // Get array of raw C pointers for outputs
492         let mut num_outputs = 0;
493         let mut outputs_ptr = std::ptr::null_mut();
494         let ierr = unsafe {
495             bind_ceed::CeedQFunctionGetFields(
496                 self.ptr,
497                 std::ptr::null_mut() as *mut bind_ceed::CeedInt,
498                 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField,
499                 &mut num_outputs,
500                 &mut outputs_ptr,
501             )
502         };
503         self.check_error(ierr)?;
504         // Convert raw C pointers to fixed length slice
505         let outputs_slice = unsafe {
506             std::slice::from_raw_parts(
507                 outputs_ptr as *const crate::QFunctionField,
508                 num_outputs as usize,
509             )
510         };
511         Ok(outputs_slice)
512     }
513 }
514 
515 // -----------------------------------------------------------------------------
516 // User QFunction Closure
517 // -----------------------------------------------------------------------------
/// Signature of a user-supplied QFunction body.
///
/// The closure receives the input slices and the mutable output slices as
/// fixed-length arrays of `MAX_QFUNCTION_FIELDS` entries and returns an
/// `i32`, which `trampoline` hands back to the C caller unchanged.
pub type QFunctionUserClosure = dyn FnMut(
    [&[crate::Scalar]; MAX_QFUNCTION_FIELDS],
    [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS],
) -> i32;
522 
// Build a 16-element array literal by repeating `$e` textually, so every
// element is a separate evaluation of the expression. This is used where the
// `[expr; N]` repeat form is unavailable because `&mut` references are not
// `Copy`. The use site annotates the result as
// `[_; MAX_QFUNCTION_FIELDS]`, so the element count here has to equal that
// constant (presumably 16 — confirm against its definition).
macro_rules! mut_max_fields {
    ($e:expr) => {
        [
            $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e,
        ]
    };
}
530 unsafe extern "C" fn trampoline(
531     ctx: *mut ::std::os::raw::c_void,
532     q: bind_ceed::CeedInt,
533     inputs: *const *const bind_ceed::CeedScalar,
534     outputs: *const *mut bind_ceed::CeedScalar,
535 ) -> ::std::os::raw::c_int {
536     let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx);
537 
538     // Inputs
539     let inputs_slice: &[*const bind_ceed::CeedScalar] =
540         std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS);
541     let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS];
542     inputs_slice
543         .iter()
544         .enumerate()
545         .map(|(i, &x)| {
546             std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize)
547                 as &[crate::Scalar]
548         })
549         .zip(inputs_array.iter_mut())
550         .for_each(|(x, a)| *a = x);
551 
552     // Outputs
553     let outputs_slice: &[*mut bind_ceed::CeedScalar] =
554         std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS);
555     let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] =
556         mut_max_fields!(&mut [0.0]);
557     outputs_slice
558         .iter()
559         .enumerate()
560         .map(|(i, &x)| {
561             std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize)
562                 as &mut [crate::Scalar]
563         })
564         .zip(outputs_array.iter_mut())
565         .for_each(|(x, a)| *a = x);
566 
567     // User closure
568     (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array)
569 }
570 
571 // -----------------------------------------------------------------------------
572 // QFunction
573 // -----------------------------------------------------------------------------
574 impl<'a> QFunction<'a> {
575     // Constructor
576     pub fn create(
577         ceed: &crate::Ceed,
578         vlength: usize,
579         user_f: Box<QFunctionUserClosure>,
580     ) -> crate::Result<Self> {
581         let source_c = CString::new("").expect("CString::new failed");
582         let mut ptr = std::ptr::null_mut();
583 
584         // Context for closure
585         let number_inputs = 0;
586         let number_outputs = 0;
587         let input_sizes = [0; MAX_QFUNCTION_FIELDS];
588         let output_sizes = [0; MAX_QFUNCTION_FIELDS];
589         let trampoline_data = unsafe {
590             Pin::new_unchecked(Box::new(QFunctionTrampolineData {
591                 number_inputs,
592                 number_outputs,
593                 input_sizes,
594                 output_sizes,
595                 user_f,
596             }))
597         };
598 
599         // Create QFunction
600         let vlength = i32::try_from(vlength).unwrap();
601         let mut ierr = unsafe {
602             bind_ceed::CeedQFunctionCreateInterior(
603                 ceed.ptr,
604                 vlength,
605                 Some(trampoline),
606                 source_c.as_ptr(),
607                 &mut ptr,
608             )
609         };
610         ceed.check_error(ierr)?;
611 
612         // Set closure
613         let mut qf_ctx_ptr = std::ptr::null_mut();
614         ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) };
615         ceed.check_error(ierr)?;
616         ierr = unsafe {
617             bind_ceed::CeedQFunctionContextSetData(
618                 qf_ctx_ptr,
619                 crate::MemType::Host as bind_ceed::CeedMemType,
620                 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
621                 std::mem::size_of::<QFunctionTrampolineData>() as u64,
622                 std::mem::transmute(trampoline_data.as_ref()),
623             )
624         };
625         ceed.check_error(ierr)?;
626         ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) };
627         ceed.check_error(ierr)?;
628         Ok(Self {
629             qf_core: QFunctionCore {
630                 ptr,
631                 _lifeline: PhantomData,
632             },
633             qf_ctx_ptr,
634             trampoline_data,
635         })
636     }
637 
638     /// Apply the action of a QFunction
639     ///
640     /// * `Q`      - The number of quadrature points
641     /// * `input`  - Array of input Vectors
642     /// * `output` - Array of output Vectors
643     ///
644     /// ```
645     /// # use libceed::prelude::*;
646     /// # fn main() -> libceed::Result<()> {
647     /// # let ceed = libceed::Ceed::default_init();
648     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
649     ///     // Iterate over quadrature points
650     ///     v.iter_mut()
651     ///         .zip(u.iter().zip(weights.iter()))
652     ///         .for_each(|(v, (u, w))| *v = u * w);
653     ///
654     ///     // Return clean error code
655     ///     0
656     /// };
657     ///
658     /// let qf = ceed
659     ///     .q_function_interior(1, Box::new(user_f))?
660     ///     .input("u", 1, EvalMode::Interp)?
661     ///     .input("weights", 1, EvalMode::Weight)?
662     ///     .output("v", 1, EvalMode::Interp)?;
663     ///
664     /// const Q: usize = 8;
665     /// let mut w = [0.; Q];
666     /// let mut u = [0.; Q];
667     /// let mut v = [0.; Q];
668     ///
669     /// for i in 0..Q {
670     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
671     ///     u[i] = 2. + 3. * x + 5. * x * x;
672     ///     w[i] = 1. - x * x;
673     ///     v[i] = u[i] * w[i];
674     /// }
675     ///
676     /// let uu = ceed.vector_from_slice(&u)?;
677     /// let ww = ceed.vector_from_slice(&w)?;
678     /// let mut vv = ceed.vector(Q)?;
679     /// vv.set_value(0.0);
680     /// {
681     ///     let input = vec![uu, ww];
682     ///     let mut output = vec![vv];
683     ///     qf.apply(Q, &input, &output)?;
684     ///     vv = output.remove(0);
685     /// }
686     ///
687     /// vv.view()?
688     ///     .iter()
689     ///     .zip(v.iter())
690     ///     .for_each(|(computed, actual)| {
691     ///         assert_eq!(
692     ///             *computed, *actual,
693     ///             "Incorrect value in QFunction application"
694     ///         );
695     ///     });
696     /// # Ok(())
697     /// # }
698     /// ```
699     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
700         self.qf_core.apply(Q, u, v)
701     }
702 
703     /// Add a QFunction input
704     ///
705     /// * `fieldname` - Name of QFunction field
706     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
707     ///                   `(ncomp * 1)` for `None`, `Interp`, and `Weight`
708     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
709     ///                   to use interpolated values, `EvalMode::Grad` to use
710     ///                   gradients, `EvalMode::Weight` to use quadrature weights
711     ///
712     /// ```
713     /// # use libceed::prelude::*;
714     /// # fn main() -> libceed::Result<()> {
715     /// # let ceed = libceed::Ceed::default_init();
716     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
717     ///     // Iterate over quadrature points
718     ///     v.iter_mut()
719     ///         .zip(u.iter().zip(weights.iter()))
720     ///         .for_each(|(v, (u, w))| *v = u * w);
721     ///
722     ///     // Return clean error code
723     ///     0
724     /// };
725     ///
726     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?;
727     ///
728     /// qf = qf.input("u", 1, EvalMode::Interp)?;
729     /// qf = qf.input("weights", 1, EvalMode::Weight)?;
730     /// # Ok(())
731     /// # }
732     /// ```
733     pub fn input(
734         mut self,
735         fieldname: &str,
736         size: usize,
737         emode: crate::EvalMode,
738     ) -> crate::Result<Self> {
739         let name_c = CString::new(fieldname).expect("CString::new failed");
740         let idx = self.trampoline_data.number_inputs;
741         self.trampoline_data.input_sizes[idx] = size;
742         self.trampoline_data.number_inputs += 1;
743         let (size, emode) = (
744             i32::try_from(size).unwrap(),
745             emode as bind_ceed::CeedEvalMode,
746         );
747         let ierr = unsafe {
748             bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
749         };
750         self.qf_core.check_error(ierr)?;
751         Ok(self)
752     }
753 
754     /// Add a QFunction output
755     ///
756     /// * `fieldname` - Name of QFunction field
757     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
758     ///                   `(ncomp * 1)` for `None` and `Interp`
759     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
760     ///                   to use interpolated values, `EvalMode::Grad` to use
761     ///                   gradients
762     ///
763     /// ```
764     /// # use libceed::prelude::*;
765     /// # fn main() -> libceed::Result<()> {
766     /// # let ceed = libceed::Ceed::default_init();
767     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
768     ///     // Iterate over quadrature points
769     ///     v.iter_mut()
770     ///         .zip(u.iter().zip(weights.iter()))
771     ///         .for_each(|(v, (u, w))| *v = u * w);
772     ///
773     ///     // Return clean error code
774     ///     0
775     /// };
776     ///
777     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?;
778     ///
779     /// qf.output("v", 1, EvalMode::Interp)?;
780     /// # Ok(())
781     /// # }
782     /// ```
783     pub fn output(
784         mut self,
785         fieldname: &str,
786         size: usize,
787         emode: crate::EvalMode,
788     ) -> crate::Result<Self> {
789         let name_c = CString::new(fieldname).expect("CString::new failed");
790         let idx = self.trampoline_data.number_outputs;
791         self.trampoline_data.output_sizes[idx] = size;
792         self.trampoline_data.number_outputs += 1;
793         let (size, emode) = (
794             i32::try_from(size).unwrap(),
795             emode as bind_ceed::CeedEvalMode,
796         );
797         let ierr = unsafe {
798             bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
799         };
800         self.qf_core.check_error(ierr)?;
801         Ok(self)
802     }
803 
804     /// Get a slice of QFunction inputs
805     ///
806     /// ```
807     /// # use libceed::prelude::*;
808     /// # fn main() -> libceed::Result<()> {
809     /// # let ceed = libceed::Ceed::default_init();
810     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
811     ///     // Iterate over quadrature points
812     ///     v.iter_mut()
813     ///         .zip(u.iter().zip(weights.iter()))
814     ///         .for_each(|(v, (u, w))| *v = u * w);
815     ///
816     ///     // Return clean error code
817     ///     0
818     /// };
819     ///
820     /// let mut qf = ceed
821     ///     .q_function_interior(1, Box::new(user_f))?
822     ///     .input("u", 1, EvalMode::Interp)?
823     ///     .input("weights", 1, EvalMode::Weight)?;
824     ///
825     /// let inputs = qf.inputs()?;
826     ///
827     /// assert_eq!(inputs.len(), 2, "Incorrect inputs array");
828     /// # Ok(())
829     /// # }
830     /// ```
831     pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> {
832         self.qf_core.inputs()
833     }
834 
835     /// Get a slice of QFunction outputs
836     ///
837     /// ```
838     /// # use libceed::prelude::*;
839     /// # fn main() -> libceed::Result<()> {
840     /// # let ceed = libceed::Ceed::default_init();
841     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
842     ///     // Iterate over quadrature points
843     ///     v.iter_mut()
844     ///         .zip(u.iter().zip(weights.iter()))
845     ///         .for_each(|(v, (u, w))| *v = u * w);
846     ///
847     ///     // Return clean error code
848     ///     0
849     /// };
850     ///
851     /// let mut qf = ceed
852     ///     .q_function_interior(1, Box::new(user_f))?
853     ///     .output("v", 1, EvalMode::Interp)?;
854     ///
855     /// let outputs = qf.outputs()?;
856     ///
857     /// assert_eq!(outputs.len(), 1, "Incorrect outputs array");
858     /// # Ok(())
859     /// # }
860     /// ```
861     pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> {
862         self.qf_core.outputs()
863     }
864 }
865 
866 // -----------------------------------------------------------------------------
// QFunctionByName
868 // -----------------------------------------------------------------------------
869 impl<'a> QFunctionByName<'a> {
870     // Constructor
871     pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> {
872         let name_c = CString::new(name).expect("CString::new failed");
873         let mut ptr = std::ptr::null_mut();
874         let ierr = unsafe {
875             bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
876         };
877         ceed.check_error(ierr)?;
878         Ok(Self {
879             qf_core: QFunctionCore {
880                 ptr,
881                 _lifeline: PhantomData,
882             },
883         })
884     }
885 
886     /// Apply the action of a QFunction
887     ///
888     /// * `Q`      - The number of quadrature points
889     /// * `input`  - Array of input Vectors
890     /// * `output` - Array of output Vectors
891     ///
892     /// ```
893     /// # use libceed::prelude::*;
894     /// # fn main() -> libceed::Result<()> {
895     /// # let ceed = libceed::Ceed::default_init();
896     /// const Q: usize = 8;
897     /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?;
898     /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?;
899     ///
900     /// let mut j = [0.; Q];
901     /// let mut w = [0.; Q];
902     /// let mut u = [0.; Q];
903     /// let mut v = [0.; Q];
904     ///
905     /// for i in 0..Q {
906     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
907     ///     j[i] = 1.;
908     ///     w[i] = 1. - x * x;
909     ///     u[i] = 2. + 3. * x + 5. * x * x;
910     ///     v[i] = w[i] * u[i];
911     /// }
912     ///
913     /// let jj = ceed.vector_from_slice(&j)?;
914     /// let ww = ceed.vector_from_slice(&w)?;
915     /// let uu = ceed.vector_from_slice(&u)?;
916     /// let mut vv = ceed.vector(Q)?;
917     /// vv.set_value(0.0);
918     /// let mut qdata = ceed.vector(Q)?;
919     /// qdata.set_value(0.0);
920     ///
921     /// {
922     ///     let mut input = vec![jj, ww];
923     ///     let mut output = vec![qdata];
924     ///     qf_build.apply(Q, &input, &output)?;
925     ///     qdata = output.remove(0);
926     /// }
927     ///
928     /// {
929     ///     let mut input = vec![qdata, uu];
930     ///     let mut output = vec![vv];
931     ///     qf_mass.apply(Q, &input, &output)?;
932     ///     vv = output.remove(0);
933     /// }
934     ///
935     /// vv.view()?
936     ///     .iter()
937     ///     .zip(v.iter())
938     ///     .for_each(|(computed, actual)| {
939     ///         assert_eq!(
940     ///             *computed, *actual,
941     ///             "Incorrect value in QFunction application"
942     ///         );
943     ///     });
944     /// # Ok(())
945     /// # }
946     /// ```
947     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
948         self.qf_core.apply(Q, u, v)
949     }
950 
951     /// Get a slice of QFunction inputs
952     ///
953     /// ```
954     /// # use libceed::prelude::*;
955     /// # fn main() -> libceed::Result<()> {
956     /// # let ceed = libceed::Ceed::default_init();
957     /// const Q: usize = 8;
958     /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
959     ///
960     /// let inputs = qf.inputs()?;
961     ///
962     /// assert_eq!(inputs.len(), 2, "Incorrect inputs array");
963     /// # Ok(())
964     /// # }
965     /// ```
966     pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> {
967         self.qf_core.inputs()
968     }
969 
970     /// Get a slice of QFunction outputs
971     ///
972     /// ```
973     /// # use libceed::prelude::*;
974     /// # fn main() -> libceed::Result<()> {
975     /// # let ceed = libceed::Ceed::default_init();
976     /// const Q: usize = 8;
977     /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
978     ///
979     /// let outputs = qf.outputs()?;
980     ///
981     /// assert_eq!(outputs.len(), 1, "Incorrect outputs array");
982     /// # Ok(())
983     /// # }
984     /// ```
985     pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> {
986         self.qf_core.outputs()
987     }
988 }
989 
990 // -----------------------------------------------------------------------------
991