1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 //! A Ceed QFunction represents the spatial terms of the point-wise functions
9 //! describing the physics at the quadrature points.
10
11 use std::pin::Pin;
12
13 use crate::{prelude::*, vector::Vector, MAX_QFUNCTION_FIELDS};
14
/// Input slices handed to a user QFunction closure, one entry per possible field.
///
/// Only the first `n` entries (for `n` registered inputs) reference QFunction data; the
/// remaining entries are placeholders filled in by the C trampoline.
pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS];
/// Output slices handed to a user QFunction closure, one entry per possible field.
///
/// Only the first `n` entries (for `n` registered outputs) reference QFunction data; the
/// remaining entries are placeholders filled in by the C trampoline.
pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS];
17
18 // -----------------------------------------------------------------------------
19 // QFunction Field context wrapper
20 // -----------------------------------------------------------------------------
21 #[derive(Debug)]
22 pub struct QFunctionField<'a> {
23 ptr: bind_ceed::CeedQFunctionField,
24 _lifeline: PhantomData<&'a ()>,
25 }
26
27 // -----------------------------------------------------------------------------
28 // Implementations
29 // -----------------------------------------------------------------------------
30 impl<'a> QFunctionField<'a> {
31 /// Get the name of a QFunctionField
32 ///
33 /// ```
34 /// # use libceed::prelude::*;
35 /// # fn main() -> libceed::Result<()> {
36 /// # let ceed = libceed::Ceed::default_init();
37 /// const Q: usize = 8;
38 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?;
39 ///
40 /// let inputs = qf.inputs()?;
41 ///
42 /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name");
43 /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name");
44 /// # Ok(())
45 /// # }
46 /// ```
name(&self) -> &str47 pub fn name(&self) -> &str {
48 let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut();
49 unsafe {
50 bind_ceed::CeedQFunctionFieldGetName(
51 self.ptr,
52 &mut name_ptr as *const _ as *mut *const _,
53 );
54 }
55 unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap()
56 }
57
58 /// Get the size of a QFunctionField
59 ///
60 /// ```
61 /// # use libceed::prelude::*;
62 /// # fn main() -> libceed::Result<()> {
63 /// # let ceed = libceed::Ceed::default_init();
64 /// const Q: usize = 8;
65 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?;
66 ///
67 /// let inputs = qf.inputs()?;
68 ///
69 /// assert_eq!(inputs[0].size(), 4, "Incorrect input size");
70 /// assert_eq!(inputs[1].size(), 1, "Incorrect input size");
71 /// # Ok(())
72 /// # }
73 /// ```
size(&self) -> usize74 pub fn size(&self) -> usize {
75 let mut size = 0;
76 unsafe {
77 bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size);
78 }
79 usize::try_from(size).unwrap()
80 }
81
82 /// Get the evaluation mode of a QFunctionField
83 ///
84 /// ```
85 /// # use libceed::{prelude::*, EvalMode};
86 /// # fn main() -> libceed::Result<()> {
87 /// # let ceed = libceed::Ceed::default_init();
88 /// const Q: usize = 8;
89 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?;
90 ///
91 /// let inputs = qf.inputs()?;
92 ///
93 /// assert_eq!(
94 /// inputs[0].eval_mode(),
95 /// EvalMode::Grad,
96 /// "Incorrect input evaluation mode"
97 /// );
98 /// assert_eq!(
99 /// inputs[1].eval_mode(),
100 /// EvalMode::Weight,
101 /// "Incorrect input evaluation mode"
102 /// );
103 /// # Ok(())
104 /// # }
105 /// ```
eval_mode(&self) -> crate::EvalMode106 pub fn eval_mode(&self) -> crate::EvalMode {
107 let mut mode = 0;
108 unsafe {
109 bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode);
110 }
111 crate::EvalMode::from_u32(mode)
112 }
113 }
114
115 // -----------------------------------------------------------------------------
116 // QFunction option
117 // -----------------------------------------------------------------------------
118 pub enum QFunctionOpt<'a> {
119 SomeQFunction(&'a QFunction<'a>),
120 SomeQFunctionByName(&'a QFunctionByName<'a>),
121 None,
122 }
123
124 /// Construct a QFunctionOpt reference from a QFunction reference
125 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
from(qfunc: &'a QFunction) -> Self126 fn from(qfunc: &'a QFunction) -> Self {
127 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
128 Self::SomeQFunction(qfunc)
129 }
130 }
131
132 /// Construct a QFunctionOpt reference from a QFunction by Name reference
133 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
from(qfunc: &'a QFunctionByName) -> Self134 fn from(qfunc: &'a QFunctionByName) -> Self {
135 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
136 Self::SomeQFunctionByName(qfunc)
137 }
138 }
139
140 impl<'a> QFunctionOpt<'a> {
141 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction
to_raw(&self) -> bind_ceed::CeedQFunction142 pub(crate) fn to_raw(&self) -> bind_ceed::CeedQFunction {
143 match self {
144 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr,
145 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr,
146 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE },
147 }
148 }
149
150 /// Check if a QFunctionOpt is Some
151 ///
152 /// ```
153 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs};
154 /// # fn main() -> libceed::Result<()> {
155 /// # let ceed = libceed::Ceed::default_init();
156 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
157 /// // Iterate over quadrature points
158 /// v.iter_mut()
159 /// .zip(u.iter().zip(weights.iter()))
160 /// .for_each(|(v, (u, w))| *v = u * w);
161 ///
162 /// // Return clean error code
163 /// 0
164 /// };
165 ///
166 /// let qf = ceed
167 /// .q_function_interior(1, Box::new(user_f))?
168 /// .input("u", 1, EvalMode::Interp)?
169 /// .input("weights", 1, EvalMode::Weight)?
170 /// .output("v", 1, EvalMode::Interp)?;
171 /// let qf_opt = QFunctionOpt::from(&qf);
172 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt");
173 ///
174 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
175 /// let qf_opt = QFunctionOpt::from(&qf);
176 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt");
177 ///
178 /// let qf_opt = QFunctionOpt::None;
179 /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt");
180 /// # Ok(())
181 /// # }
182 /// ```
is_some(&self) -> bool183 pub fn is_some(&self) -> bool {
184 match self {
185 Self::SomeQFunction(_) => true,
186 Self::SomeQFunctionByName(_) => true,
187 Self::None => false,
188 }
189 }
190
191 /// Check if a QFunctionOpt is SomeQFunction
192 ///
193 /// ```
194 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs};
195 /// # fn main() -> libceed::Result<()> {
196 /// # let ceed = libceed::Ceed::default_init();
197 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
198 /// // Iterate over quadrature points
199 /// v.iter_mut()
200 /// .zip(u.iter().zip(weights.iter()))
201 /// .for_each(|(v, (u, w))| *v = u * w);
202 ///
203 /// // Return clean error code
204 /// 0
205 /// };
206 ///
207 /// let qf = ceed
208 /// .q_function_interior(1, Box::new(user_f))?
209 /// .input("u", 1, EvalMode::Interp)?
210 /// .input("weights", 1, EvalMode::Weight)?
211 /// .output("v", 1, EvalMode::Interp)?;
212 /// let qf_opt = QFunctionOpt::from(&qf);
213 /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt");
214 ///
215 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
216 /// let qf_opt = QFunctionOpt::from(&qf);
217 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt");
218 ///
219 /// let qf_opt = QFunctionOpt::None;
220 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt");
221 /// # Ok(())
222 /// # }
223 /// ```
is_some_q_function(&self) -> bool224 pub fn is_some_q_function(&self) -> bool {
225 match self {
226 Self::SomeQFunction(_) => true,
227 Self::SomeQFunctionByName(_) => false,
228 Self::None => false,
229 }
230 }
231
232 /// Check if a QFunctionOpt is SomeQFunctionByName
233 ///
234 /// ```
235 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs};
236 /// # fn main() -> libceed::Result<()> {
237 /// # let ceed = libceed::Ceed::default_init();
238 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
239 /// // Iterate over quadrature points
240 /// v.iter_mut()
241 /// .zip(u.iter().zip(weights.iter()))
242 /// .for_each(|(v, (u, w))| *v = u * w);
243 ///
244 /// // Return clean error code
245 /// 0
246 /// };
247 ///
248 /// let qf = ceed
249 /// .q_function_interior(1, Box::new(user_f))?
250 /// .input("u", 1, EvalMode::Interp)?
251 /// .input("weights", 1, EvalMode::Weight)?
252 /// .output("v", 1, EvalMode::Interp)?;
253 /// let qf_opt = QFunctionOpt::from(&qf);
254 /// assert!(
255 /// !qf_opt.is_some_q_function_by_name(),
256 /// "Incorrect QFunctionOpt"
257 /// );
258 ///
259 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
260 /// let qf_opt = QFunctionOpt::from(&qf);
261 /// assert!(
262 /// qf_opt.is_some_q_function_by_name(),
263 /// "Incorrect QFunctionOpt"
264 /// );
265 ///
266 /// let qf_opt = QFunctionOpt::None;
267 /// assert!(
268 /// !qf_opt.is_some_q_function_by_name(),
269 /// "Incorrect QFunctionOpt"
270 /// );
271 /// # Ok(())
272 /// # }
273 /// ```
is_some_q_function_by_name(&self) -> bool274 pub fn is_some_q_function_by_name(&self) -> bool {
275 match self {
276 Self::SomeQFunction(_) => false,
277 Self::SomeQFunctionByName(_) => true,
278 Self::None => false,
279 }
280 }
281
282 /// Check if a QFunctionOpt is None
283 ///
284 /// ```
285 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs};
286 /// # fn main() -> libceed::Result<()> {
287 /// # let ceed = libceed::Ceed::default_init();
288 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
289 /// // Iterate over quadrature points
290 /// v.iter_mut()
291 /// .zip(u.iter().zip(weights.iter()))
292 /// .for_each(|(v, (u, w))| *v = u * w);
293 ///
294 /// // Return clean error code
295 /// 0
296 /// };
297 ///
298 /// let qf = ceed
299 /// .q_function_interior(1, Box::new(user_f))?
300 /// .input("u", 1, EvalMode::Interp)?
301 /// .input("weights", 1, EvalMode::Weight)?
302 /// .output("v", 1, EvalMode::Interp)?;
303 /// let qf_opt = QFunctionOpt::from(&qf);
304 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt");
305 ///
306 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
307 /// let qf_opt = QFunctionOpt::from(&qf);
308 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt");
309 ///
310 /// let qf_opt = QFunctionOpt::None;
311 /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt");
312 /// # Ok(())
313 /// # }
314 /// ```
is_none(&self) -> bool315 pub fn is_none(&self) -> bool {
316 match self {
317 Self::SomeQFunction(_) => false,
318 Self::SomeQFunctionByName(_) => false,
319 Self::None => true,
320 }
321 }
322 }
323
324 // -----------------------------------------------------------------------------
325 // QFunction context wrapper
326 // -----------------------------------------------------------------------------
/// Owning wrapper around a raw `CeedQFunction` handle, shared by `QFunction` and
/// `QFunctionByName`.
///
/// Its `Drop` impl calls `CeedQFunctionDestroy` unless the handle is `CEED_QFUNCTION_NONE`.
#[derive(Debug)]
pub(crate) struct QFunctionCore<'a> {
    ptr: bind_ceed::CeedQFunction,
    _lifeline: PhantomData<&'a ()>,
}

/// State read by `trampoline` on every application of a closure-backed QFunction.
///
/// Its address is registered with the C QFunctionContext in `QFunction::create`.
struct QFunctionTrampolineData {
    // Number of fields added through `QFunction::input`
    number_inputs: usize,
    // Number of fields added through `QFunction::output`
    number_outputs: usize,
    // Per-input `size` argument; `trampoline` builds slices of `size * Q` scalars
    input_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // Per-output `size` argument; `trampoline` builds slices of `size * Q` scalars
    output_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // The user closure that `trampoline` forwards the field slices to
    user_f: Box<QFunctionUserClosure>,
}

/// A QFunction whose point-wise physics is implemented by a Rust closure.
pub struct QFunction<'a> {
    // NOTE: field order matters. Rust drops fields in declaration order, so the C QFunction
    // in `qf_core` is destroyed before `trampoline_data` is freed. The C side only holds
    // the address of `trampoline_data`, registered with `CopyMode::UsePointer` in `create`.
    qf_core: QFunctionCore<'a>,
    qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
    trampoline_data: Pin<Box<QFunctionTrampolineData>>,
}

/// A QFunction taken from the libCEED gallery by name.
#[derive(Debug)]
pub struct QFunctionByName<'a> {
    qf_core: QFunctionCore<'a>,
}
351
352 // -----------------------------------------------------------------------------
353 // Destructor
354 // -----------------------------------------------------------------------------
355 impl<'a> Drop for QFunctionCore<'a> {
drop(&mut self)356 fn drop(&mut self) {
357 unsafe {
358 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE {
359 bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
360 }
361 }
362 }
363 }
364
365 impl<'a> Drop for QFunction<'a> {
drop(&mut self)366 fn drop(&mut self) {
367 unsafe {
368 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr);
369 }
370 }
371 }
372
373 // -----------------------------------------------------------------------------
374 // Display
375 // -----------------------------------------------------------------------------
376 impl<'a> fmt::Display for QFunctionCore<'a> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result377 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
378 let mut ptr = std::ptr::null_mut();
379 let mut sizeloc = crate::MAX_BUFFER_LENGTH;
380 let cstring = unsafe {
381 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
382 bind_ceed::CeedQFunctionView(self.ptr, file);
383 bind_ceed::fclose(file);
384 CString::from_raw(ptr)
385 };
386 cstring.to_string_lossy().fmt(f)
387 }
388 }
389 /// View a QFunction
390 ///
391 /// ```
392 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs};
393 /// # fn main() -> libceed::Result<()> {
394 /// # let ceed = libceed::Ceed::default_init();
395 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
396 /// // Iterate over quadrature points
397 /// v.iter_mut()
398 /// .zip(u.iter().zip(weights.iter()))
399 /// .for_each(|(v, (u, w))| *v = u * w);
400 ///
401 /// // Return clean error code
402 /// 0
403 /// };
404 ///
405 /// let qf = ceed
406 /// .q_function_interior(1, Box::new(user_f))?
407 /// .input("u", 1, EvalMode::Interp)?
408 /// .input("weights", 1, EvalMode::Weight)?
409 /// .output("v", 1, EvalMode::Interp)?;
410 ///
411 /// println!("{}", qf);
412 /// # Ok(())
413 /// # }
414 /// ```
415 impl<'a> fmt::Display for QFunction<'a> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result416 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
417 self.qf_core.fmt(f)
418 }
419 }
420
421 /// View a QFunction by Name
422 ///
423 /// ```
424 /// # use libceed::prelude::*;
425 /// # fn main() -> libceed::Result<()> {
426 /// # let ceed = libceed::Ceed::default_init();
427 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
428 /// println!("{}", qf);
429 /// # Ok(())
430 /// # }
431 /// ```
432 impl<'a> fmt::Display for QFunctionByName<'a> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result433 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
434 self.qf_core.fmt(f)
435 }
436 }
437
438 // -----------------------------------------------------------------------------
439 // Core functionality
440 // -----------------------------------------------------------------------------
441 impl<'a> QFunctionCore<'a> {
442 // Raw Ceed for error handling
443 #[doc(hidden)]
ceed(&self) -> bind_ceed::Ceed444 fn ceed(&self) -> bind_ceed::Ceed {
445 unsafe { bind_ceed::CeedQFunctionReturnCeed(self.ptr) }
446 }
447
448 // Error handling
449 #[doc(hidden)]
check_error(&self, ierr: i32) -> crate::Result<i32>450 fn check_error(&self, ierr: i32) -> crate::Result<i32> {
451 crate::check_error(|| self.ceed(), ierr)
452 }
453
454 // Common implementation
apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32>455 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
456 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
457 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) {
458 u_c[i] = u[i].ptr;
459 }
460 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
461 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) {
462 v_c[i] = v[i].ptr;
463 }
464 let Q = i32::try_from(Q).unwrap();
465 self.check_error(unsafe {
466 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
467 })
468 }
469
inputs(&self) -> crate::Result<&[QFunctionField]>470 pub fn inputs(&self) -> crate::Result<&[QFunctionField]> {
471 // Get array of raw C pointers for inputs
472 let mut num_inputs = 0;
473 let mut inputs_ptr = std::ptr::null_mut();
474 self.check_error(unsafe {
475 bind_ceed::CeedQFunctionGetFields(
476 self.ptr,
477 &mut num_inputs,
478 &mut inputs_ptr,
479 std::ptr::null_mut() as *mut bind_ceed::CeedInt,
480 std::ptr::null_mut(),
481 )
482 })?;
483 // Convert raw C pointers to fixed length slice
484 let inputs_slice = unsafe {
485 std::slice::from_raw_parts(inputs_ptr as *const QFunctionField, num_inputs as usize)
486 };
487 Ok(inputs_slice)
488 }
489
outputs(&self) -> crate::Result<&[QFunctionField]>490 pub fn outputs(&self) -> crate::Result<&[QFunctionField]> {
491 // Get array of raw C pointers for outputs
492 let mut num_outputs = 0;
493 let mut outputs_ptr = std::ptr::null_mut();
494 self.check_error(unsafe {
495 bind_ceed::CeedQFunctionGetFields(
496 self.ptr,
497 std::ptr::null_mut() as *mut bind_ceed::CeedInt,
498 std::ptr::null_mut(),
499 &mut num_outputs,
500 &mut outputs_ptr,
501 )
502 })?;
503 // Convert raw C pointers to fixed length slice
504 let outputs_slice = unsafe {
505 std::slice::from_raw_parts(outputs_ptr as *const QFunctionField, num_outputs as usize)
506 };
507 Ok(outputs_slice)
508 }
509 }
510
511 // -----------------------------------------------------------------------------
512 // User QFunction Closure
513 // -----------------------------------------------------------------------------
514 pub type QFunctionUserClosure = dyn FnMut(
515 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS],
516 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS],
517 ) -> i32;
518
519 macro_rules! mut_max_fields {
520 ($e:expr) => {
521 [
522 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e,
523 ]
524 };
525 }
trampoline( ctx: *mut ::std::os::raw::c_void, q: bind_ceed::CeedInt, inputs: *const *const bind_ceed::CeedScalar, outputs: *const *mut bind_ceed::CeedScalar, ) -> ::std::os::raw::c_int526 unsafe extern "C" fn trampoline(
527 ctx: *mut ::std::os::raw::c_void,
528 q: bind_ceed::CeedInt,
529 inputs: *const *const bind_ceed::CeedScalar,
530 outputs: *const *mut bind_ceed::CeedScalar,
531 ) -> ::std::os::raw::c_int {
532 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx);
533
534 // Inputs
535 let inputs_slice: &[*const bind_ceed::CeedScalar] =
536 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS);
537 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS];
538 inputs_slice
539 .iter()
540 .take(trampoline_data.number_inputs)
541 .enumerate()
542 .map(|(i, &x)| {
543 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize)
544 as &[crate::Scalar]
545 })
546 .zip(inputs_array.iter_mut())
547 .for_each(|(x, a)| *a = x);
548
549 // Outputs
550 let outputs_slice: &[*mut bind_ceed::CeedScalar] =
551 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS);
552 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] =
553 mut_max_fields!(&mut [0.0]);
554 outputs_slice
555 .iter()
556 .take(trampoline_data.number_outputs)
557 .enumerate()
558 .map(|(i, &x)| {
559 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize)
560 as &mut [crate::Scalar]
561 })
562 .zip(outputs_array.iter_mut())
563 .for_each(|(x, a)| *a = x);
564
565 // User closure
566 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array)
567 }
568
569 // -----------------------------------------------------------------------------
570 // QFunction
571 // -----------------------------------------------------------------------------
572 impl<'a> QFunction<'a> {
573 // Constructor
create( ceed: &crate::Ceed, vlength: usize, user_f: Box<QFunctionUserClosure>, ) -> crate::Result<Self>574 pub fn create(
575 ceed: &crate::Ceed,
576 vlength: usize,
577 user_f: Box<QFunctionUserClosure>,
578 ) -> crate::Result<Self> {
579 let source_c = CString::new("").expect("CString::new failed");
580 let mut ptr = std::ptr::null_mut();
581
582 // Context for closure
583 let number_inputs = 0;
584 let number_outputs = 0;
585 let input_sizes = [0; MAX_QFUNCTION_FIELDS];
586 let output_sizes = [0; MAX_QFUNCTION_FIELDS];
587 let trampoline_data = unsafe {
588 Pin::new_unchecked(Box::new(QFunctionTrampolineData {
589 number_inputs,
590 number_outputs,
591 input_sizes,
592 output_sizes,
593 user_f,
594 }))
595 };
596
597 // Create QFunction
598 let vlength = i32::try_from(vlength).unwrap();
599 ceed.check_error(unsafe {
600 bind_ceed::CeedQFunctionCreateInterior(
601 ceed.ptr,
602 vlength,
603 Some(trampoline),
604 source_c.as_ptr(),
605 &mut ptr,
606 )
607 })?;
608
609 // Set closure
610 let mut qf_ctx_ptr = std::ptr::null_mut();
611 ceed.check_error(unsafe {
612 bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr)
613 })?;
614 ceed.check_error(unsafe {
615 bind_ceed::CeedQFunctionContextSetData(
616 qf_ctx_ptr,
617 crate::MemType::Host as bind_ceed::CeedMemType,
618 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
619 std::mem::size_of::<QFunctionTrampolineData>(),
620 std::mem::transmute::<
621 std::pin::Pin<&QFunctionTrampolineData>,
622 *mut std::ffi::c_void,
623 >(trampoline_data.as_ref()),
624 )
625 })?;
626 ceed.check_error(unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) })?;
627 ceed.check_error(unsafe { bind_ceed::CeedQFunctionContextDestroy(&mut qf_ctx_ptr) })?;
628 Ok(Self {
629 qf_core: QFunctionCore {
630 ptr,
631 _lifeline: PhantomData,
632 },
633 qf_ctx_ptr,
634 trampoline_data,
635 })
636 }
637
638 /// Apply the action of a QFunction
639 ///
640 /// * `Q` - The number of quadrature points
641 /// * `input` - Array of input Vectors
642 /// * `output` - Array of output Vectors
643 ///
644 /// ```
645 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs, Scalar};
646 /// # fn main() -> libceed::Result<()> {
647 /// # let ceed = libceed::Ceed::default_init();
648 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
649 /// // Iterate over quadrature points
650 /// v.iter_mut()
651 /// .zip(u.iter().zip(weights.iter()))
652 /// .for_each(|(v, (u, w))| *v = u * w);
653 ///
654 /// // Return clean error code
655 /// 0
656 /// };
657 ///
658 /// let qf = ceed
659 /// .q_function_interior(1, Box::new(user_f))?
660 /// .input("u", 1, EvalMode::Interp)?
661 /// .input("weights", 1, EvalMode::Weight)?
662 /// .output("v", 1, EvalMode::Interp)?;
663 ///
664 /// const Q: usize = 8;
665 /// let mut w = [0.; Q];
666 /// let mut u = [0.; Q];
667 /// let mut v = [0.; Q];
668 ///
669 /// for i in 0..Q {
670 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
671 /// u[i] = 2. + 3. * x + 5. * x * x;
672 /// w[i] = 1. - x * x;
673 /// v[i] = u[i] * w[i];
674 /// }
675 ///
676 /// let uu = ceed.vector_from_slice(&u)?;
677 /// let ww = ceed.vector_from_slice(&w)?;
678 /// let mut vv = ceed.vector(Q)?;
679 /// vv.set_value(0.0);
680 /// {
681 /// let input = vec![uu, ww];
682 /// let mut output = vec![vv];
683 /// qf.apply(Q, &input, &output)?;
684 /// vv = output.remove(0);
685 /// }
686 ///
687 /// vv.view()?
688 /// .iter()
689 /// .zip(v.iter())
690 /// .for_each(|(computed, actual)| {
691 /// assert_eq!(
692 /// *computed, *actual,
693 /// "Incorrect value in QFunction application"
694 /// );
695 /// });
696 /// # Ok(())
697 /// # }
698 /// ```
apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32>699 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
700 self.qf_core.apply(Q, u, v)
701 }
702
703 /// Add a QFunction input
704 ///
705 /// * `fieldname` - Name of QFunction field
706 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or
707 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight`
708 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
709 /// to use interpolated values, `EvalMode::Grad` to use
710 /// gradients, `EvalMode::Weight` to use quadrature weights
711 ///
712 /// ```
713 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs};
714 /// # fn main() -> libceed::Result<()> {
715 /// # let ceed = libceed::Ceed::default_init();
716 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
717 /// // Iterate over quadrature points
718 /// v.iter_mut()
719 /// .zip(u.iter().zip(weights.iter()))
720 /// .for_each(|(v, (u, w))| *v = u * w);
721 ///
722 /// // Return clean error code
723 /// 0
724 /// };
725 ///
726 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?;
727 ///
728 /// qf = qf.input("u", 1, EvalMode::Interp)?;
729 /// qf = qf.input("weights", 1, EvalMode::Weight)?;
730 /// # Ok(())
731 /// # }
732 /// ```
input( mut self, fieldname: &str, size: usize, emode: crate::EvalMode, ) -> crate::Result<Self>733 pub fn input(
734 mut self,
735 fieldname: &str,
736 size: usize,
737 emode: crate::EvalMode,
738 ) -> crate::Result<Self> {
739 let name_c = CString::new(fieldname).expect("CString::new failed");
740 let idx = self.trampoline_data.number_inputs;
741 self.trampoline_data.input_sizes[idx] = size;
742 self.trampoline_data.number_inputs += 1;
743 let (size, emode) = (
744 i32::try_from(size).unwrap(),
745 emode as bind_ceed::CeedEvalMode,
746 );
747 self.qf_core.check_error(unsafe {
748 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
749 })?;
750 Ok(self)
751 }
752
753 /// Add a QFunction output
754 ///
755 /// * `fieldname` - Name of QFunction field
756 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or
757 /// `(ncomp * 1)` for `None` and `Interp`
758 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp`
759 /// to use interpolated values, `EvalMode::Grad` to use
760 /// gradients
761 ///
762 /// ```
763 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs};
764 /// # fn main() -> libceed::Result<()> {
765 /// # let ceed = libceed::Ceed::default_init();
766 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
767 /// // Iterate over quadrature points
768 /// v.iter_mut()
769 /// .zip(u.iter().zip(weights.iter()))
770 /// .for_each(|(v, (u, w))| *v = u * w);
771 ///
772 /// // Return clean error code
773 /// 0
774 /// };
775 ///
776 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?;
777 ///
778 /// qf.output("v", 1, EvalMode::Interp)?;
779 /// # Ok(())
780 /// # }
781 /// ```
output( mut self, fieldname: &str, size: usize, emode: crate::EvalMode, ) -> crate::Result<Self>782 pub fn output(
783 mut self,
784 fieldname: &str,
785 size: usize,
786 emode: crate::EvalMode,
787 ) -> crate::Result<Self> {
788 let name_c = CString::new(fieldname).expect("CString::new failed");
789 let idx = self.trampoline_data.number_outputs;
790 self.trampoline_data.output_sizes[idx] = size;
791 self.trampoline_data.number_outputs += 1;
792 let (size, emode) = (
793 i32::try_from(size).unwrap(),
794 emode as bind_ceed::CeedEvalMode,
795 );
796 self.qf_core.check_error(unsafe {
797 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
798 })?;
799 Ok(self)
800 }
801
802 /// Get a slice of QFunction inputs
803 ///
804 /// ```
805 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs};
806 /// # fn main() -> libceed::Result<()> {
807 /// # let ceed = libceed::Ceed::default_init();
808 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
809 /// // Iterate over quadrature points
810 /// v.iter_mut()
811 /// .zip(u.iter().zip(weights.iter()))
812 /// .for_each(|(v, (u, w))| *v = u * w);
813 ///
814 /// // Return clean error code
815 /// 0
816 /// };
817 ///
818 /// let mut qf = ceed
819 /// .q_function_interior(1, Box::new(user_f))?
820 /// .input("u", 1, EvalMode::Interp)?
821 /// .input("weights", 1, EvalMode::Weight)?;
822 ///
823 /// let inputs = qf.inputs()?;
824 ///
825 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array");
826 /// # Ok(())
827 /// # }
828 /// ```
inputs(&self) -> crate::Result<&[QFunctionField]>829 pub fn inputs(&self) -> crate::Result<&[QFunctionField]> {
830 self.qf_core.inputs()
831 }
832
833 /// Get a slice of QFunction outputs
834 ///
835 /// ```
836 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs};
837 /// # fn main() -> libceed::Result<()> {
838 /// # let ceed = libceed::Ceed::default_init();
839 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| {
840 /// // Iterate over quadrature points
841 /// v.iter_mut()
842 /// .zip(u.iter().zip(weights.iter()))
843 /// .for_each(|(v, (u, w))| *v = u * w);
844 ///
845 /// // Return clean error code
846 /// 0
847 /// };
848 ///
849 /// let mut qf = ceed
850 /// .q_function_interior(1, Box::new(user_f))?
851 /// .output("v", 1, EvalMode::Interp)?;
852 ///
853 /// let outputs = qf.outputs()?;
854 ///
855 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array");
856 /// # Ok(())
857 /// # }
858 /// ```
outputs(&self) -> crate::Result<&[QFunctionField]>859 pub fn outputs(&self) -> crate::Result<&[QFunctionField]> {
860 self.qf_core.outputs()
861 }
862 }
863
864 // -----------------------------------------------------------------------------
// QFunction by Name
866 // -----------------------------------------------------------------------------
867 impl<'a> QFunctionByName<'a> {
868 // Constructor
create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self>869 pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> {
870 let name_c = CString::new(name).expect("CString::new failed");
871 let mut ptr = std::ptr::null_mut();
872 ceed.check_error(unsafe {
873 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
874 })?;
875 Ok(Self {
876 qf_core: QFunctionCore {
877 ptr,
878 _lifeline: PhantomData,
879 },
880 })
881 }
882
883 /// Apply the action of a QFunction
884 ///
885 /// * `Q` - The number of quadrature points
886 /// * `input` - Array of input Vectors
887 /// * `output` - Array of output Vectors
888 ///
889 /// ```
890 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs, Scalar};
891 /// # fn main() -> libceed::Result<()> {
892 /// # let ceed = libceed::Ceed::default_init();
893 /// const Q: usize = 8;
894 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?;
895 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?;
896 ///
897 /// let mut j = [0.; Q];
898 /// let mut w = [0.; Q];
899 /// let mut u = [0.; Q];
900 /// let mut v = [0.; Q];
901 ///
902 /// for i in 0..Q {
903 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.;
904 /// j[i] = 1.;
905 /// w[i] = 1. - x * x;
906 /// u[i] = 2. + 3. * x + 5. * x * x;
907 /// v[i] = w[i] * u[i];
908 /// }
909 ///
910 /// let jj = ceed.vector_from_slice(&j)?;
911 /// let ww = ceed.vector_from_slice(&w)?;
912 /// let uu = ceed.vector_from_slice(&u)?;
913 /// let mut vv = ceed.vector(Q)?;
914 /// vv.set_value(0.0);
915 /// let mut qdata = ceed.vector(Q)?;
916 /// qdata.set_value(0.0);
917 ///
918 /// {
919 /// let mut input = vec![jj, ww];
920 /// let mut output = vec![qdata];
921 /// qf_build.apply(Q, &input, &output)?;
922 /// qdata = output.remove(0);
923 /// }
924 ///
925 /// {
926 /// let mut input = vec![qdata, uu];
927 /// let mut output = vec![vv];
928 /// qf_mass.apply(Q, &input, &output)?;
929 /// vv = output.remove(0);
930 /// }
931 ///
932 /// vv.view()?
933 /// .iter()
934 /// .zip(v.iter())
935 /// .for_each(|(computed, actual)| {
936 /// assert_eq!(
937 /// *computed, *actual,
938 /// "Incorrect value in QFunction application"
939 /// );
940 /// });
941 /// # Ok(())
942 /// # }
943 /// ```
apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32>944 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
945 self.qf_core.apply(Q, u, v)
946 }
947
948 /// Get a slice of QFunction inputs
949 ///
950 /// ```
951 /// # use libceed::prelude::*;
952 /// # fn main() -> libceed::Result<()> {
953 /// # let ceed = libceed::Ceed::default_init();
954 /// const Q: usize = 8;
955 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
956 ///
957 /// let inputs = qf.inputs()?;
958 ///
959 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array");
960 /// # Ok(())
961 /// # }
962 /// ```
inputs(&self) -> crate::Result<&[QFunctionField]>963 pub fn inputs(&self) -> crate::Result<&[QFunctionField]> {
964 self.qf_core.inputs()
965 }
966
967 /// Get a slice of QFunction outputs
968 ///
969 /// ```
970 /// # use libceed::prelude::*;
971 /// # fn main() -> libceed::Result<()> {
972 /// # let ceed = libceed::Ceed::default_init();
973 /// const Q: usize = 8;
974 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?;
975 ///
976 /// let outputs = qf.outputs()?;
977 ///
978 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array");
979 /// # Ok(())
980 /// # }
981 /// ```
outputs(&self) -> crate::Result<&[QFunctionField]>982 pub fn outputs(&self) -> crate::Result<&[QFunctionField]> {
983 self.qf_core.outputs()
984 }
985 }
986
987 // -----------------------------------------------------------------------------
988