xref: /libCEED/rust/libceed/src/qfunction.rs (revision c68be7a2e45631197b626561fad53d5b146fcd59)
1 // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2 // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3 // reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative
16 
17 //! A Ceed QFunction represents the spatial terms of the point-wise functions
18 //! describing the physics at the quadrature points.
19 
20 use std::pin::Pin;
21 
22 use crate::prelude::*;
23 
/// Input field slices passed to a user QFunction closure, one slot per
/// possible field (`MAX_QFUNCTION_FIELDS` slots in total).
pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS];
/// Output field slices passed to a user QFunction closure, one slot per
/// possible field (`MAX_QFUNCTION_FIELDS` slots in total).
pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS];
26 
27 // -----------------------------------------------------------------------------
28 // CeedQFunction option
29 // -----------------------------------------------------------------------------
/// Optional reference to a QFunction, convertible to a raw C `CeedQFunction`
/// handle (with `None` mapping to `CEED_QFUNCTION_NONE`).
pub enum QFunctionOpt<'a> {
    /// A QFunction defined by a Rust closure
    SomeQFunction(&'a QFunction<'a>),
    /// A QFunction created by name (`CeedQFunctionCreateInteriorByName`)
    SomeQFunctionByName(&'a QFunctionByName<'a>),
    /// No QFunction
    None,
}
35 
36 /// Construct a QFunctionOpt reference from a QFunction reference
37 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
38     fn from(qfunc: &'a QFunction) -> Self {
39         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
40         Self::SomeQFunction(qfunc)
41     }
42 }
43 
44 /// Construct a QFunctionOpt reference from a QFunction by Name reference
45 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
46     fn from(qfunc: &'a QFunctionByName) -> Self {
47         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
48         Self::SomeQFunctionByName(qfunc)
49     }
50 }
51 
52 impl<'a> QFunctionOpt<'a> {
53     /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction
54     pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction {
55         match self {
56             Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr,
57             Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr,
58             Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE },
59         }
60     }
61 }
62 
63 // -----------------------------------------------------------------------------
64 // CeedQFunction context wrapper
65 // -----------------------------------------------------------------------------
/// Owning wrapper around a C `CeedQFunction` handle, shared by `QFunction`
/// and `QFunctionByName`; the handle is destroyed when this value is dropped.
#[derive(Debug)]
pub(crate) struct QFunctionCore<'a> {
    /// Ceed context the QFunction was created with; used for error checking
    ceed: &'a crate::Ceed,
    /// Raw libCEED QFunction handle
    ptr: bind_ceed::CeedQFunction,
}
71 
/// Data reached from C through the QFunction context pointer: the user
/// closure plus the per-field sizes `trampoline` needs to rebuild slices.
struct QFunctionTrampolineData {
    /// Number of input fields added with `QFunction::input`
    number_inputs: usize,
    /// Number of output fields added with `QFunction::output`
    number_outputs: usize,
    /// Size of each input field per quadrature point
    input_sizes: [usize; MAX_QFUNCTION_FIELDS],
    /// Size of each output field per quadrature point
    output_sizes: [usize; MAX_QFUNCTION_FIELDS],
    /// User-provided QFunction closure
    user_f: Box<QFunctionUserClosure>,
}
79 
/// A QFunction whose point-wise kernel is a Rust closure
///
/// Field order matters: fields drop in declaration order, so the C QFunction
/// (`qf_core`) is released before `trampoline_data`, which C points into.
pub struct QFunction<'a> {
    qf_core: QFunctionCore<'a>,
    /// C context whose data pointer refers to `trampoline_data`
    qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
    /// Heap-allocated and pinned so its address stays stable while C holds it
    trampoline_data: Pin<Box<QFunctionTrampolineData>>,
}
85 
/// A QFunction created by name via `CeedQFunctionCreateInteriorByName`
#[derive(Debug)]
pub struct QFunctionByName<'a> {
    qf_core: QFunctionCore<'a>,
}
90 
91 // -----------------------------------------------------------------------------
92 // Destructor
93 // -----------------------------------------------------------------------------
94 impl<'a> Drop for QFunctionCore<'a> {
95     fn drop(&mut self) {
96         unsafe {
97             if self.ptr != bind_ceed::CEED_QFUNCTION_NONE {
98                 bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
99             }
100         }
101     }
102 }
103 
104 impl<'a> Drop for QFunction<'a> {
105     fn drop(&mut self) {
106         unsafe {
107             bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr);
108         }
109     }
110 }
111 
112 // -----------------------------------------------------------------------------
113 // Display
114 // -----------------------------------------------------------------------------
115 impl<'a> fmt::Display for QFunctionCore<'a> {
116     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117         let mut ptr = std::ptr::null_mut();
118         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
119         let cstring = unsafe {
120             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
121             bind_ceed::CeedQFunctionView(self.ptr, file);
122             bind_ceed::fclose(file);
123             CString::from_raw(ptr)
124         };
125         cstring.to_string_lossy().fmt(f)
126     }
127 }
128 /// View a QFunction
129 ///
130 /// ```
131 /// # use libceed::prelude::*;
132 /// # fn main() -> Result<(), libceed::CeedError> {
133 /// # let ceed = libceed::Ceed::default_init();
134 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
135 ///     // Iterate over quadrature points
136 ///     v.iter_mut()
137 ///         .zip(u.iter().zip(weights.iter()))
138 ///         .for_each(|(v, (u, w))| *v = u * w);
139 ///
140 ///     // Return clean error code
141 ///     0
142 /// };
143 ///
144 /// let qf = ceed
145 ///     .q_function_interior(1, Box::new(user_f))?
146 ///     .input("u", 1, EvalMode::Interp)?
147 ///     .input("weights", 1, EvalMode::Weight)?
148 ///     .output("v", 1, EvalMode::Interp)?;
149 ///
150 /// println!("{}", qf);
151 /// # Ok(())
152 /// # }
153 /// ```
154 impl<'a> fmt::Display for QFunction<'a> {
155     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156         self.qf_core.fmt(f)
157     }
158 }
159 
160 /// View a QFunction by Name
161 ///
162 /// ```
163 /// # use libceed::prelude::*;
164 /// # fn main() -> Result<(), libceed::CeedError> {
165 /// # let ceed = libceed::Ceed::default_init();
166 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
167 /// println!("{}", qf);
168 /// # Ok(())
169 /// # }
170 /// ```
171 impl<'a> fmt::Display for QFunctionByName<'a> {
172     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
173         self.qf_core.fmt(f)
174     }
175 }
176 
177 // -----------------------------------------------------------------------------
178 // Core functionality
179 // -----------------------------------------------------------------------------
180 impl<'a> QFunctionCore<'a> {
181     // Common implementation
182     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
183         let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
184         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) {
185             u_c[i] = u[i].ptr;
186         }
187         let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
188         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) {
189             v_c[i] = v[i].ptr;
190         }
191         let Q = i32::try_from(Q).unwrap();
192         let ierr = unsafe {
193             bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
194         };
195         self.ceed.check_error(ierr)
196     }
197 }
198 
199 // -----------------------------------------------------------------------------
200 // User QFunction Closure
201 // -----------------------------------------------------------------------------
/// Signature of a user QFunction closure: it receives the input field slices
/// and the mutable output field slices, and returns an integer error code
/// (the examples return 0 for success).
pub type QFunctionUserClosure = dyn FnMut(
    [&[crate::Scalar]; MAX_QFUNCTION_FIELDS],
    [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS],
) -> i32;
206 
// Expand to an array literal containing `$e` sixteen times. It is used to
// build `[&mut [Scalar]; MAX_QFUNCTION_FIELDS]` values, because `&mut` slices
// are not `Copy` and so cannot use the `[expr; N]` repeat syntax.
// NOTE(review): the count is hard-coded to 16 and must equal MAX_QFUNCTION_FIELDS.
macro_rules! mut_max_fields {
    ($e:expr) => {
        [
            $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e,
        ]
    };
}
214 unsafe extern "C" fn trampoline(
215     ctx: *mut ::std::os::raw::c_void,
216     q: bind_ceed::CeedInt,
217     inputs: *const *const bind_ceed::CeedScalar,
218     outputs: *const *mut bind_ceed::CeedScalar,
219 ) -> ::std::os::raw::c_int {
220     let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx);
221 
222     // Inputs
223     let inputs_slice: &[*const bind_ceed::CeedScalar] =
224         std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS);
225     let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS];
226     inputs_slice
227         .iter()
228         .enumerate()
229         .map(|(i, &x)| {
230             std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize)
231                 as &[crate::Scalar]
232         })
233         .zip(inputs_array.iter_mut())
234         .for_each(|(x, a)| *a = x);
235 
236     // Outputs
237     let outputs_slice: &[*mut bind_ceed::CeedScalar] =
238         std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS);
239     let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] =
240         mut_max_fields!(&mut [0.0]);
241     outputs_slice
242         .iter()
243         .enumerate()
244         .map(|(i, &x)| {
245             std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize)
246                 as &mut [crate::Scalar]
247         })
248         .zip(outputs_array.iter_mut())
249         .for_each(|(x, a)| *a = x);
250 
251     // User closure
252     (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array)
253 }
254 
255 // -----------------------------------------------------------------------------
256 // QFunction
257 // -----------------------------------------------------------------------------
258 impl<'a> QFunction<'a> {
259     // Constructor
260     pub fn create(
261         ceed: &'a crate::Ceed,
262         vlength: usize,
263         user_f: Box<QFunctionUserClosure>,
264     ) -> crate::Result<Self> {
265         let source_c = CString::new("").expect("CString::new failed");
266         let mut ptr = std::ptr::null_mut();
267 
268         // Context for closure
269         let number_inputs = 0;
270         let number_outputs = 0;
271         let input_sizes = [0; MAX_QFUNCTION_FIELDS];
272         let output_sizes = [0; MAX_QFUNCTION_FIELDS];
273         let trampoline_data = unsafe {
274             Pin::new_unchecked(Box::new(QFunctionTrampolineData {
275                 number_inputs,
276                 number_outputs,
277                 input_sizes,
278                 output_sizes,
279                 user_f,
280             }))
281         };
282 
283         // Create QFunction
284         let vlength = i32::try_from(vlength).unwrap();
285         let mut ierr = unsafe {
286             bind_ceed::CeedQFunctionCreateInterior(
287                 ceed.ptr,
288                 vlength,
289                 Some(trampoline),
290                 source_c.as_ptr(),
291                 &mut ptr,
292             )
293         };
294         ceed.check_error(ierr)?;
295 
296         // Set closure
297         let mut qf_ctx_ptr = std::ptr::null_mut();
298         ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) };
299         ceed.check_error(ierr)?;
300         ierr = unsafe {
301             bind_ceed::CeedQFunctionContextSetData(
302                 qf_ctx_ptr,
303                 crate::MemType::Host as bind_ceed::CeedMemType,
304                 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
305                 std::mem::size_of::<QFunctionTrampolineData>() as u64,
306                 std::mem::transmute(trampoline_data.as_ref()),
307             )
308         };
309         ceed.check_error(ierr)?;
310         ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) };
311         ceed.check_error(ierr)?;
312         Ok(Self {
313             qf_core: QFunctionCore { ceed, ptr },
314             qf_ctx_ptr,
315             trampoline_data,
316         })
317     }
318 
319     /// Apply the action of a QFunction
320     ///
321     /// * `Q`      - The number of quadrature points
322     /// * `input`  - Array of input Vectors
323     /// * `output` - Array of output Vectors
324     ///
325     /// ```
326     /// # use libceed::prelude::*;
327     /// # fn main() -> Result<(), libceed::CeedError> {
328     /// # let ceed = libceed::Ceed::default_init();
329     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
330     ///     // Iterate over quadrature points
331     ///     v.iter_mut()
332     ///         .zip(u.iter().zip(weights.iter()))
333     ///         .for_each(|(v, (u, w))| *v = u * w);
334     ///
335     ///     // Return clean error code
336     ///     0
337     /// };
338     ///
339     /// let qf = ceed
340     ///     .q_function_interior(1, Box::new(user_f))?
341     ///     .input("u", 1, EvalMode::Interp)?
342     ///     .input("weights", 1, EvalMode::Weight)?
343     ///     .output("v", 1, EvalMode::Interp)?;
344     ///
345     /// const Q: usize = 8;
346     /// let mut w = [0.; Q];
347     /// let mut u = [0.; Q];
348     /// let mut v = [0.; Q];
349     ///
350     /// for i in 0..Q {
351     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
352     ///     u[i] = 2. + 3. * x + 5. * x * x;
353     ///     w[i] = 1. - x * x;
354     ///     v[i] = u[i] * w[i];
355     /// }
356     ///
357     /// let uu = ceed.vector_from_slice(&u)?;
358     /// let ww = ceed.vector_from_slice(&w)?;
359     /// let mut vv = ceed.vector(Q)?;
360     /// vv.set_value(0.0);
361     /// {
362     ///     let input = vec![uu, ww];
363     ///     let mut output = vec![vv];
364     ///     qf.apply(Q, &input, &output)?;
365     ///     vv = output.remove(0);
366     /// }
367     ///
368     /// vv.view()
369     ///     .iter()
370     ///     .zip(v.iter())
371     ///     .for_each(|(computed, actual)| {
372     ///         assert_eq!(
373     ///             *computed, *actual,
374     ///             "Incorrect value in QFunction application"
375     ///         );
376     ///     });
377     /// # Ok(())
378     /// # }
379     /// ```
380     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
381         self.qf_core.apply(Q, u, v)
382     }
383 
384     /// Add a QFunction input
385     ///
386     /// * `fieldname` - Name of QFunction field
387     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
388     ///                   `(ncomp * 1)` for `None`, `Interp`, and `Weight`
389     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
390     ///                   to use interpolated values, `EvalMode::Grad` to use
391     ///                   gradients, `EvalMode::Weight` to use quadrature weights
392     ///
393     /// ```
394     /// # use libceed::prelude::*;
395     /// # fn main() -> Result<(), libceed::CeedError> {
396     /// # let ceed = libceed::Ceed::default_init();
397     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
398     ///     // Iterate over quadrature points
399     ///     v.iter_mut()
400     ///         .zip(u.iter().zip(weights.iter()))
401     ///         .for_each(|(v, (u, w))| *v = u * w);
402     ///
403     ///     // Return clean error code
404     ///     0
405     /// };
406     ///
407     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?;
408     ///
409     /// qf = qf.input("u", 1, EvalMode::Interp)?;
410     /// qf = qf.input("weights", 1, EvalMode::Weight)?;
411     /// # Ok(())
412     /// # }
413     /// ```
414     pub fn input(
415         mut self,
416         fieldname: &str,
417         size: usize,
418         emode: crate::EvalMode,
419     ) -> crate::Result<Self> {
420         let name_c = CString::new(fieldname).expect("CString::new failed");
421         let idx = self.trampoline_data.number_inputs;
422         self.trampoline_data.input_sizes[idx] = size;
423         self.trampoline_data.number_inputs += 1;
424         let (size, emode) = (
425             i32::try_from(size).unwrap(),
426             emode as bind_ceed::CeedEvalMode,
427         );
428         let ierr = unsafe {
429             bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
430         };
431         self.qf_core.ceed.check_error(ierr)?;
432         Ok(self)
433     }
434 
435     /// Add a QFunction output
436     ///
437     /// * `fieldname` - Name of QFunction field
438     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
439     ///                   `(ncomp * 1)` for `None` and `Interp`
440     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
441     ///                   to use interpolated values, `EvalMode::Grad` to use
442     ///                   gradients
443     ///
444     /// ```
445     /// # use libceed::prelude::*;
446     /// # fn main() -> Result<(), libceed::CeedError> {
447     /// # let ceed = libceed::Ceed::default_init();
448     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
449     ///     // Iterate over quadrature points
450     ///     v.iter_mut()
451     ///         .zip(u.iter().zip(weights.iter()))
452     ///         .for_each(|(v, (u, w))| *v = u * w);
453     ///
454     ///     // Return clean error code
455     ///     0
456     /// };
457     ///
458     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?;
459     ///
460     /// qf.output("v", 1, EvalMode::Interp)?;
461     /// # Ok(())
462     /// # }
463     /// ```
464     pub fn output(
465         mut self,
466         fieldname: &str,
467         size: usize,
468         emode: crate::EvalMode,
469     ) -> crate::Result<Self> {
470         let name_c = CString::new(fieldname).expect("CString::new failed");
471         let idx = self.trampoline_data.number_outputs;
472         self.trampoline_data.output_sizes[idx] = size;
473         self.trampoline_data.number_outputs += 1;
474         let (size, emode) = (
475             i32::try_from(size).unwrap(),
476             emode as bind_ceed::CeedEvalMode,
477         );
478         let ierr = unsafe {
479             bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
480         };
481         self.qf_core.ceed.check_error(ierr)?;
482         Ok(self)
483     }
484 }
485 
486 // -----------------------------------------------------------------------------
487 // QFunction
488 // -----------------------------------------------------------------------------
489 impl<'a> QFunctionByName<'a> {
490     // Constructor
491     pub fn create(ceed: &'a crate::Ceed, name: &str) -> crate::Result<Self> {
492         let name_c = CString::new(name).expect("CString::new failed");
493         let mut ptr = std::ptr::null_mut();
494         let ierr = unsafe {
495             bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
496         };
497         ceed.check_error(ierr)?;
498         Ok(Self {
499             qf_core: QFunctionCore { ceed, ptr },
500         })
501     }
502 
503     /// Apply the action of a QFunction
504     ///
505     /// * `Q`      - The number of quadrature points
506     /// * `input`  - Array of input Vectors
507     /// * `output` - Array of output Vectors
508     ///
509     /// ```
510     /// # use libceed::prelude::*;
511     /// # fn main() -> Result<(), libceed::CeedError> {
512     /// # let ceed = libceed::Ceed::default_init();
513     /// const Q: usize = 8;
514     /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?;
515     /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?;
516     ///
517     /// let mut j = [0.; Q];
518     /// let mut w = [0.; Q];
519     /// let mut u = [0.; Q];
520     /// let mut v = [0.; Q];
521     ///
522     /// for i in 0..Q {
523     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
524     ///     j[i] = 1.;
525     ///     w[i] = 1. - x * x;
526     ///     u[i] = 2. + 3. * x + 5. * x * x;
527     ///     v[i] = w[i] * u[i];
528     /// }
529     ///
530     /// let jj = ceed.vector_from_slice(&j)?;
531     /// let ww = ceed.vector_from_slice(&w)?;
532     /// let uu = ceed.vector_from_slice(&u)?;
533     /// let mut vv = ceed.vector(Q)?;
534     /// vv.set_value(0.0);
535     /// let mut qdata = ceed.vector(Q)?;
536     /// qdata.set_value(0.0);
537     ///
538     /// {
539     ///     let mut input = vec![jj, ww];
540     ///     let mut output = vec![qdata];
541     ///     qf_build.apply(Q, &input, &output)?;
542     ///     qdata = output.remove(0);
543     /// }
544     ///
545     /// {
546     ///     let mut input = vec![qdata, uu];
547     ///     let mut output = vec![vv];
548     ///     qf_mass.apply(Q, &input, &output)?;
549     ///     vv = output.remove(0);
550     /// }
551     ///
552     /// vv.view()
553     ///     .iter()
554     ///     .zip(v.iter())
555     ///     .for_each(|(computed, actual)| {
556     ///         assert_eq!(
557     ///             *computed, *actual,
558     ///             "Incorrect value in QFunction application"
559     ///         );
560     ///     });
561     /// # Ok(())
562     /// # }
563     /// ```
564     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
565         self.qf_core.apply(Q, u, v)
566     }
567 }
568 
569 // -----------------------------------------------------------------------------
570