xref: /libCEED/rust/libceed/src/qfunction.rs (revision 80a9ef0545a39c00cdcaab1ca26f8053604f3120)
1 // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2 // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3 // reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative
16 
17 //! A Ceed QFunction represents the spatial terms of the point-wise functions
18 //! describing the physics at the quadrature points.
19 
20 use std::pin::Pin;
21 
22 use crate::prelude::*;
23 
/// Input field slices passed to a user QFunction closure, one entry per input
/// field in the order the fields were added; entries past the registered
/// fields carry no meaningful data
pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS];
/// Mutable output field slices passed to a user QFunction closure, one entry
/// per output field in the order the fields were added; entries past the
/// registered fields carry no meaningful data
pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS];
26 
27 // -----------------------------------------------------------------------------
28 // CeedQFunction option
29 // -----------------------------------------------------------------------------
/// Optional reference to a QFunction; `to_raw` lowers it to the C
/// `CeedQFunction` handle, using `CEED_QFUNCTION_NONE` for `None`
#[derive(Clone, Copy)]
pub enum QFunctionOpt<'a> {
    /// A QFunction backed by a Rust closure
    SomeQFunction(&'a QFunction<'a>),
    /// A QFunction created from a name registered with libCEED
    SomeQFunctionByName(&'a QFunctionByName<'a>),
    /// No QFunction
    None,
}
36 
37 /// Construct a QFunctionOpt reference from a QFunction reference
38 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
39     fn from(qfunc: &'a QFunction) -> Self {
40         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
41         Self::SomeQFunction(qfunc)
42     }
43 }
44 
45 /// Construct a QFunctionOpt reference from a QFunction by Name reference
46 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
47     fn from(qfunc: &'a QFunctionByName) -> Self {
48         debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
49         Self::SomeQFunctionByName(qfunc)
50     }
51 }
52 
53 impl<'a> QFunctionOpt<'a> {
54     /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction
55     pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction {
56         match self {
57             Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr,
58             Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr,
59             Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE },
60         }
61     }
62 }
63 
64 // -----------------------------------------------------------------------------
65 // CeedQFunction context wrapper
66 // -----------------------------------------------------------------------------
/// Owner of a raw libCEED `CeedQFunction` handle, shared by `QFunction` and
/// `QFunctionByName`; the handle is destroyed on drop unless it is
/// `CEED_QFUNCTION_NONE`
pub(crate) struct QFunctionCore<'a> {
    // Ceed context used to check the error codes of calls on this QFunction
    ceed: &'a crate::Ceed,
    // Raw C handle
    ptr: bind_ceed::CeedQFunction,
}
71 
/// State read by `trampoline` when libCEED invokes a closure-backed QFunction
struct QFunctionTrampolineData {
    // Number of input fields added through `QFunction::input`
    number_inputs: usize,
    // Number of output fields added through `QFunction::output`
    number_outputs: usize,
    // Per-quadrature-point size of each input field, in the order added
    // (`trampoline` builds slices of `size * q` scalars)
    input_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // Per-quadrature-point size of each output field, in the order added
    output_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // User closure evaluated at the quadrature points
    user_f: Box<QFunctionUserClosure>,
}
79 
/// A QFunction whose point-wise action is a user-provided Rust closure
pub struct QFunction<'a> {
    // Owned C QFunction handle
    qf_core: QFunctionCore<'a>,
    // Context object holding a pointer to `trampoline_data`; destroyed in `Drop`
    qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
    // Closure and field sizes. Pinned on the heap because libCEED is given a
    // raw pointer to it (registered with `CopyMode::UsePointer`).
    trampoline_data: Pin<Box<QFunctionTrampolineData>>,
}
85 
/// A QFunction created from a name registered with libCEED, e.g. `"Mass1DBuild"`
pub struct QFunctionByName<'a> {
    // Owned C QFunction handle
    qf_core: QFunctionCore<'a>,
}
89 
90 // -----------------------------------------------------------------------------
91 // Destructor
92 // -----------------------------------------------------------------------------
93 impl<'a> Drop for QFunctionCore<'a> {
94     fn drop(&mut self) {
95         unsafe {
96             if self.ptr != bind_ceed::CEED_QFUNCTION_NONE {
97                 bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
98             }
99         }
100     }
101 }
102 
103 impl<'a> Drop for QFunction<'a> {
104     fn drop(&mut self) {
105         unsafe {
106             bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr);
107         }
108     }
109 }
110 
111 // -----------------------------------------------------------------------------
112 // Display
113 // -----------------------------------------------------------------------------
114 impl<'a> fmt::Display for QFunctionCore<'a> {
115     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116         let mut ptr = std::ptr::null_mut();
117         let mut sizeloc = crate::MAX_BUFFER_LENGTH;
118         let cstring = unsafe {
119             let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
120             bind_ceed::CeedQFunctionView(self.ptr, file);
121             bind_ceed::fclose(file);
122             CString::from_raw(ptr)
123         };
124         cstring.to_string_lossy().fmt(f)
125     }
126 }
127 /// View a QFunction
128 ///
129 /// ```
130 /// # use libceed::prelude::*;
131 /// # let ceed = libceed::Ceed::default_init();
132 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
133 ///     // Iterate over quadrature points
134 ///     v.iter_mut()
135 ///         .zip(u.iter().zip(weights.iter()))
136 ///         .for_each(|(v, (u, w))| *v = u * w);
137 ///
138 ///     // Return clean error code
139 ///     0
140 /// };
141 ///
142 /// let qf = ceed
143 ///     .q_function_interior(1, Box::new(user_f))
144 ///     .unwrap()
145 ///     .input("u", 1, EvalMode::Interp)
146 ///     .unwrap()
147 ///     .input("weights", 1, EvalMode::Weight)
148 ///     .unwrap()
149 ///     .output("v", 1, EvalMode::Interp)
150 ///     .unwrap();
151 ///
152 /// println!("{}", qf);
153 /// ```
154 impl<'a> fmt::Display for QFunction<'a> {
155     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156         self.qf_core.fmt(f)
157     }
158 }
159 
160 /// View a QFunction by Name
161 ///
162 /// ```
163 /// # use libceed::prelude::*;
164 /// # let ceed = libceed::Ceed::default_init();
165 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild").unwrap();
166 /// println!("{}", qf);
167 /// ```
168 impl<'a> fmt::Display for QFunctionByName<'a> {
169     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
170         self.qf_core.fmt(f)
171     }
172 }
173 
174 // -----------------------------------------------------------------------------
175 // Core functionality
176 // -----------------------------------------------------------------------------
177 impl<'a> QFunctionCore<'a> {
178     // Common implementation
179     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
180         let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
181         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) {
182             u_c[i] = u[i].ptr;
183         }
184         let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
185         for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) {
186             v_c[i] = v[i].ptr;
187         }
188         let Q = i32::try_from(Q).unwrap();
189         let ierr = unsafe {
190             bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
191         };
192         self.ceed.check_error(ierr)
193     }
194 }
195 
196 // -----------------------------------------------------------------------------
197 // User QFunction Closure
198 // -----------------------------------------------------------------------------
199 pub type QFunctionUserClosure = dyn FnMut(
200     [&[crate::Scalar]; MAX_QFUNCTION_FIELDS],
201     [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS],
202 ) -> i32;
203 
// Build a 16-element array by writing the expression out once per slot.
// Array-repeat syntax `[expr; N]` needs a `Copy` (or constant) operand, and a
// `&mut` reference such as `&mut [0.0]` is neither, so each slot gets its own
// separately evaluated copy of `$e`. The count is hard-coded; if it ever
// disagrees with `MAX_QFUNCTION_FIELDS`, binding the result to a
// `[_; MAX_QFUNCTION_FIELDS]` (as `trampoline` does) fails to type-check.
macro_rules! mut_max_fields {
    ($e:expr) => {
        [
            $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e,
        ]
    };
}
211 unsafe extern "C" fn trampoline(
212     ctx: *mut ::std::os::raw::c_void,
213     q: bind_ceed::CeedInt,
214     inputs: *const *const bind_ceed::CeedScalar,
215     outputs: *const *mut bind_ceed::CeedScalar,
216 ) -> ::std::os::raw::c_int {
217     let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx);
218 
219     // Inputs
220     let inputs_slice: &[*const bind_ceed::CeedScalar] =
221         std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS);
222     let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS];
223     inputs_slice
224         .iter()
225         .enumerate()
226         .map(|(i, &x)| {
227             std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize)
228                 as &[crate::Scalar]
229         })
230         .zip(inputs_array.iter_mut())
231         .for_each(|(x, a)| *a = x);
232 
233     // Outputs
234     let outputs_slice: &[*mut bind_ceed::CeedScalar] =
235         std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS);
236     let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] =
237         mut_max_fields!(&mut [0.0]);
238     outputs_slice
239         .iter()
240         .enumerate()
241         .map(|(i, &x)| {
242             std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize)
243                 as &mut [crate::Scalar]
244         })
245         .zip(outputs_array.iter_mut())
246         .for_each(|(x, a)| *a = x);
247 
248     // User closure
249     (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array)
250 }
251 
252 // -----------------------------------------------------------------------------
253 // QFunction
254 // -----------------------------------------------------------------------------
255 impl<'a> QFunction<'a> {
256     // Constructor
257     pub fn create(
258         ceed: &'a crate::Ceed,
259         vlength: usize,
260         user_f: Box<QFunctionUserClosure>,
261     ) -> crate::Result<Self> {
262         let source_c = CString::new("").expect("CString::new failed");
263         let mut ptr = std::ptr::null_mut();
264 
265         // Context for closure
266         let number_inputs = 0;
267         let number_outputs = 0;
268         let input_sizes = [0; MAX_QFUNCTION_FIELDS];
269         let output_sizes = [0; MAX_QFUNCTION_FIELDS];
270         let trampoline_data = unsafe {
271             Pin::new_unchecked(Box::new(QFunctionTrampolineData {
272                 number_inputs,
273                 number_outputs,
274                 input_sizes,
275                 output_sizes,
276                 user_f,
277             }))
278         };
279 
280         // Create QFunction
281         let vlength = i32::try_from(vlength).unwrap();
282         let mut ierr = unsafe {
283             bind_ceed::CeedQFunctionCreateInterior(
284                 ceed.ptr,
285                 vlength,
286                 Some(trampoline),
287                 source_c.as_ptr(),
288                 &mut ptr,
289             )
290         };
291         ceed.check_error(ierr)?;
292 
293         // Set closure
294         let mut qf_ctx_ptr = std::ptr::null_mut();
295         ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) };
296         ceed.check_error(ierr)?;
297         ierr = unsafe {
298             bind_ceed::CeedQFunctionContextSetData(
299                 qf_ctx_ptr,
300                 crate::MemType::Host as bind_ceed::CeedMemType,
301                 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
302                 std::mem::size_of::<QFunctionTrampolineData>() as u64,
303                 std::mem::transmute(trampoline_data.as_ref()),
304             )
305         };
306         ceed.check_error(ierr)?;
307         ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) };
308         ceed.check_error(ierr)?;
309         Ok(Self {
310             qf_core: QFunctionCore { ceed, ptr },
311             qf_ctx_ptr,
312             trampoline_data,
313         })
314     }
315 
316     /// Apply the action of a QFunction
317     ///
318     /// * `Q`      - The number of quadrature points
319     /// * `input`  - Array of input Vectors
320     /// * `output` - Array of output Vectors
321     ///
322     /// ```
323     /// # use libceed::prelude::*;
324     /// # let ceed = libceed::Ceed::default_init();
325     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
326     ///     // Iterate over quadrature points
327     ///     v.iter_mut()
328     ///         .zip(u.iter().zip(weights.iter()))
329     ///         .for_each(|(v, (u, w))| *v = u * w);
330     ///
331     ///     // Return clean error code
332     ///     0
333     /// };
334     ///
335     /// let qf = ceed
336     ///     .q_function_interior(1, Box::new(user_f))
337     ///     .unwrap()
338     ///     .input("u", 1, EvalMode::Interp)
339     ///     .unwrap()
340     ///     .input("weights", 1, EvalMode::Weight)
341     ///     .unwrap()
342     ///     .output("v", 1, EvalMode::Interp)
343     ///     .unwrap();
344     ///
345     /// const Q: usize = 8;
346     /// let mut w = [0.; Q];
347     /// let mut u = [0.; Q];
348     /// let mut v = [0.; Q];
349     ///
350     /// for i in 0..Q {
351     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
352     ///     u[i] = 2. + 3. * x + 5. * x * x;
353     ///     w[i] = 1. - x * x;
354     ///     v[i] = u[i] * w[i];
355     /// }
356     ///
357     /// let uu = ceed.vector_from_slice(&u).unwrap();
358     /// let ww = ceed.vector_from_slice(&w).unwrap();
359     /// let mut vv = ceed.vector(Q).unwrap();
360     /// vv.set_value(0.0);
361     /// {
362     ///     let input = vec![uu, ww];
363     ///     let mut output = vec![vv];
364     ///     qf.apply(Q, &input, &output).unwrap();
365     ///     vv = output.remove(0);
366     /// }
367     ///
368     /// vv.view()
369     ///     .iter()
370     ///     .zip(v.iter())
371     ///     .for_each(|(computed, actual)| {
372     ///         assert_eq!(
373     ///             *computed, *actual,
374     ///             "Incorrect value in QFunction application"
375     ///         );
376     ///     });
377     /// ```
378     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
379         self.qf_core.apply(Q, u, v)
380     }
381 
382     /// Add a QFunction input
383     ///
384     /// * `fieldname` - Name of QFunction field
385     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
386     ///                   `(ncomp * 1)` for `None`, `Interp`, and `Weight`
387     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
388     ///                   to use interpolated values, `EvalMode::Grad` to use
389     ///                   gradients, `EvalMode::Weight` to use quadrature weights
390     ///
391     /// ```
392     /// # use libceed::prelude::*;
393     /// # let ceed = libceed::Ceed::default_init();
394     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
395     ///     // Iterate over quadrature points
396     ///     v.iter_mut()
397     ///         .zip(u.iter().zip(weights.iter()))
398     ///         .for_each(|(v, (u, w))| *v = u * w);
399     ///
400     ///     // Return clean error code
401     ///     0
402     /// };
403     ///
404     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap();
405     ///
406     /// qf = qf.input("u", 1, EvalMode::Interp).unwrap();
407     /// qf = qf.input("weights", 1, EvalMode::Weight).unwrap();
408     /// ```
409     pub fn input(
410         mut self,
411         fieldname: &str,
412         size: usize,
413         emode: crate::EvalMode,
414     ) -> crate::Result<Self> {
415         let name_c = CString::new(fieldname).expect("CString::new failed");
416         let idx = self.trampoline_data.number_inputs;
417         self.trampoline_data.input_sizes[idx] = size;
418         self.trampoline_data.number_inputs += 1;
419         let (size, emode) = (
420             i32::try_from(size).unwrap(),
421             emode as bind_ceed::CeedEvalMode,
422         );
423         let ierr = unsafe {
424             bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
425         };
426         self.qf_core.ceed.check_error(ierr)?;
427         Ok(self)
428     }
429 
430     /// Add a QFunction output
431     ///
432     /// * `fieldname` - Name of QFunction field
433     /// * `size`      - Size of QFunction field, `(ncomp * dim)` for `Grad` or
434     ///                   `(ncomp * 1)` for `None` and `Interp`
435     /// * `emode`     - `EvalMode::None` to use values directly, `EvalMode::Interp`
436     ///                   to use interpolated values, `EvalMode::Grad` to use
437     ///                   gradients
438     ///
439     /// ```
440     /// # use libceed::prelude::*;
441     /// # let ceed = libceed::Ceed::default_init();
442     /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
443     ///     // Iterate over quadrature points
444     ///     v.iter_mut()
445     ///         .zip(u.iter().zip(weights.iter()))
446     ///         .for_each(|(v, (u, w))| *v = u * w);
447     ///
448     ///     // Return clean error code
449     ///     0
450     /// };
451     ///
452     /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap();
453     ///
454     /// qf.output("v", 1, EvalMode::Interp).unwrap();
455     /// ```
456     pub fn output(
457         mut self,
458         fieldname: &str,
459         size: usize,
460         emode: crate::EvalMode,
461     ) -> crate::Result<Self> {
462         let name_c = CString::new(fieldname).expect("CString::new failed");
463         let idx = self.trampoline_data.number_outputs;
464         self.trampoline_data.output_sizes[idx] = size;
465         self.trampoline_data.number_outputs += 1;
466         let (size, emode) = (
467             i32::try_from(size).unwrap(),
468             emode as bind_ceed::CeedEvalMode,
469         );
470         let ierr = unsafe {
471             bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
472         };
473         self.qf_core.ceed.check_error(ierr)?;
474         Ok(self)
475     }
476 }
477 
478 // -----------------------------------------------------------------------------
// QFunction by Name
480 // -----------------------------------------------------------------------------
481 impl<'a> QFunctionByName<'a> {
482     // Constructor
483     pub fn create(ceed: &'a crate::Ceed, name: &str) -> crate::Result<Self> {
484         let name_c = CString::new(name).expect("CString::new failed");
485         let mut ptr = std::ptr::null_mut();
486         let ierr = unsafe {
487             bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
488         };
489         ceed.check_error(ierr)?;
490         Ok(Self {
491             qf_core: QFunctionCore { ceed, ptr },
492         })
493     }
494 
495     /// Apply the action of a QFunction
496     ///
497     /// * `Q`      - The number of quadrature points
498     /// * `input`  - Array of input Vectors
499     /// * `output` - Array of output Vectors
500     ///
501     /// ```
502     /// # use libceed::prelude::*;
503     /// # let ceed = libceed::Ceed::default_init();
504     /// const Q: usize = 8;
505     /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild").unwrap();
506     /// let qf_mass = ceed.q_function_interior_by_name("MassApply").unwrap();
507     ///
508     /// let mut j = [0.; Q];
509     /// let mut w = [0.; Q];
510     /// let mut u = [0.; Q];
511     /// let mut v = [0.; Q];
512     ///
513     /// for i in 0..Q {
514     ///     let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
515     ///     j[i] = 1.;
516     ///     w[i] = 1. - x * x;
517     ///     u[i] = 2. + 3. * x + 5. * x * x;
518     ///     v[i] = w[i] * u[i];
519     /// }
520     ///
521     /// let jj = ceed.vector_from_slice(&j).unwrap();
522     /// let ww = ceed.vector_from_slice(&w).unwrap();
523     /// let uu = ceed.vector_from_slice(&u).unwrap();
524     /// let mut vv = ceed.vector(Q).unwrap();
525     /// vv.set_value(0.0);
526     /// let mut qdata = ceed.vector(Q).unwrap();
527     /// qdata.set_value(0.0);
528     ///
529     /// {
530     ///     let mut input = vec![jj, ww];
531     ///     let mut output = vec![qdata];
532     ///     qf_build.apply(Q, &input, &output).unwrap();
533     ///     qdata = output.remove(0);
534     /// }
535     ///
536     /// {
537     ///     let mut input = vec![qdata, uu];
538     ///     let mut output = vec![vv];
539     ///     qf_mass.apply(Q, &input, &output).unwrap();
540     ///     vv = output.remove(0);
541     /// }
542     ///
543     /// vv.view()
544     ///     .iter()
545     ///     .zip(v.iter())
546     ///     .for_each(|(computed, actual)| {
547     ///         assert_eq!(
548     ///             *computed, *actual,
549     ///             "Incorrect value in QFunction application"
550     ///         );
551     ///     });
552     /// ```
553     pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
554         self.qf_core.apply(Q, u, v)
555     }
556 }
557 
558 // -----------------------------------------------------------------------------
559