1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 //! A Ceed QFunction represents the spatial terms of the point-wise functions 9 //! describing the physics at the quadrature points. 10 11 use std::pin::Pin; 12 13 use crate::{prelude::*, vector::Vector, MAX_QFUNCTION_FIELDS}; 14 15 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 16 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 17 18 // ----------------------------------------------------------------------------- 19 // QFunction Field context wrapper 20 // ----------------------------------------------------------------------------- 21 #[derive(Debug)] 22 pub struct QFunctionField<'a> { 23 ptr: bind_ceed::CeedQFunctionField, 24 _lifeline: PhantomData<&'a ()>, 25 } 26 27 // ----------------------------------------------------------------------------- 28 // Implementations 29 // ----------------------------------------------------------------------------- 30 impl<'a> QFunctionField<'a> { 31 /// Get the name of a QFunctionField 32 /// 33 /// ``` 34 /// # use libceed::prelude::*; 35 /// # fn main() -> libceed::Result<()> { 36 /// # let ceed = libceed::Ceed::default_init(); 37 /// const Q: usize = 8; 38 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 39 /// 40 /// let inputs = qf.inputs()?; 41 /// 42 /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name"); 43 /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name"); 44 /// # Ok(()) 45 /// # } 46 /// ``` 47 pub fn name(&self) -> &str { 48 let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut(); 49 unsafe { 50 bind_ceed::CeedQFunctionFieldGetName( 51 self.ptr, 52 &mut name_ptr as *const _ as *mut *const _, 53 ); 54 } 55 unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap() 56 } 57 58 /// Get the size of a QFunctionField 59 /// 60 /// ``` 61 /// # use libceed::prelude::*; 62 /// # fn main() -> libceed::Result<()> { 63 /// # let ceed = libceed::Ceed::default_init(); 64 /// const Q: usize = 8; 65 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 66 /// 67 /// let inputs = qf.inputs()?; 68 /// 69 /// assert_eq!(inputs[0].size(), 4, "Incorrect input size"); 70 /// assert_eq!(inputs[1].size(), 1, "Incorrect input size"); 71 /// # Ok(()) 72 /// # } 73 /// ``` 74 pub fn size(&self) -> usize { 75 let mut size = 0; 76 unsafe { 77 bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size); 78 } 79 usize::try_from(size).unwrap() 80 } 81 82 /// Get the evaluation mode of a QFunctionField 83 /// 84 /// ``` 85 /// # use libceed::{prelude::*, EvalMode}; 86 /// # fn main() -> libceed::Result<()> { 87 /// # let ceed = libceed::Ceed::default_init(); 88 /// const Q: usize = 8; 89 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 90 /// 91 /// let inputs = qf.inputs()?; 92 /// 93 /// assert_eq!( 94 /// inputs[0].eval_mode(), 95 /// EvalMode::Grad, 96 /// "Incorrect input evaluation mode" 97 /// ); 98 /// assert_eq!( 99 /// inputs[1].eval_mode(), 100 /// EvalMode::Weight, 101 /// "Incorrect input evaluation mode" 102 /// ); 103 /// # Ok(()) 104 /// # } 105 /// ``` 106 pub fn eval_mode(&self) -> crate::EvalMode { 107 let mut mode = 0; 108 unsafe { 109 bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode); 110 } 111 crate::EvalMode::from_u32(mode as u32) 112 } 113 } 114 115 // ----------------------------------------------------------------------------- 116 // QFunction option 117 // ----------------------------------------------------------------------------- 118 pub enum QFunctionOpt<'a> { 119 SomeQFunction(&'a QFunction<'a>), 120 SomeQFunctionByName(&'a QFunctionByName<'a>), 121 None, 122 } 123 124 /// Construct a QFunctionOpt reference from a QFunction reference 125 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 126 fn from(qfunc: &'a QFunction) -> Self { 127 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 128 Self::SomeQFunction(qfunc) 129 } 130 } 131 132 /// Construct a QFunctionOpt reference from a QFunction by Name reference 133 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 134 fn from(qfunc: &'a QFunctionByName) -> Self { 135 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 136 Self::SomeQFunctionByName(qfunc) 137 } 138 } 139 140 impl<'a> QFunctionOpt<'a> { 141 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 142 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 143 match self { 144 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 145 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 146 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 147 } 148 } 149 150 /// Check if a QFunctionOpt is Some 151 /// 152 /// ``` 153 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs}; 154 /// # fn main() -> libceed::Result<()> { 155 /// # let ceed = libceed::Ceed::default_init(); 156 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 157 /// // Iterate over quadrature points 158 /// v.iter_mut() 159 /// .zip(u.iter().zip(weights.iter())) 160 /// .for_each(|(v, (u, w))| *v = u * w); 161 /// 162 /// // Return clean error code 163 /// 0 164 /// }; 165 /// 166 /// let qf = ceed 167 /// .q_function_interior(1, Box::new(user_f))? 168 /// .input("u", 1, EvalMode::Interp)? 169 /// .input("weights", 1, EvalMode::Weight)? 170 /// .output("v", 1, EvalMode::Interp)?; 171 /// let qf_opt = QFunctionOpt::from(&qf); 172 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 173 /// 174 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 175 /// let qf_opt = QFunctionOpt::from(&qf); 176 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 177 /// 178 /// let qf_opt = QFunctionOpt::None; 179 /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt"); 180 /// # Ok(()) 181 /// # } 182 /// ``` 183 pub fn is_some(&self) -> bool { 184 match self { 185 Self::SomeQFunction(_) => true, 186 Self::SomeQFunctionByName(_) => true, 187 Self::None => false, 188 } 189 } 190 191 /// Check if a QFunctionOpt is SomeQFunction 192 /// 193 /// ``` 194 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs}; 195 /// # fn main() -> libceed::Result<()> { 196 /// # let ceed = libceed::Ceed::default_init(); 197 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 198 /// // Iterate over quadrature points 199 /// v.iter_mut() 200 /// .zip(u.iter().zip(weights.iter())) 201 /// .for_each(|(v, (u, w))| *v = u * w); 202 /// 203 /// // Return clean error code 204 /// 0 205 /// }; 206 /// 207 /// let qf = ceed 208 /// .q_function_interior(1, Box::new(user_f))? 209 /// .input("u", 1, EvalMode::Interp)? 210 /// .input("weights", 1, EvalMode::Weight)? 211 /// .output("v", 1, EvalMode::Interp)?; 212 /// let qf_opt = QFunctionOpt::from(&qf); 213 /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 214 /// 215 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 216 /// let qf_opt = QFunctionOpt::from(&qf); 217 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 218 /// 219 /// let qf_opt = QFunctionOpt::None; 220 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 221 /// # Ok(()) 222 /// # } 223 /// ``` 224 pub fn is_some_q_function(&self) -> bool { 225 match self { 226 Self::SomeQFunction(_) => true, 227 Self::SomeQFunctionByName(_) => false, 228 Self::None => false, 229 } 230 } 231 232 /// Check if a QFunctionOpt is SomeQFunctionByName 233 /// 234 /// ``` 235 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs}; 236 /// # fn main() -> libceed::Result<()> { 237 /// # let ceed = libceed::Ceed::default_init(); 238 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 239 /// // Iterate over quadrature points 240 /// v.iter_mut() 241 /// .zip(u.iter().zip(weights.iter())) 242 /// .for_each(|(v, (u, w))| *v = u * w); 243 /// 244 /// // Return clean error code 245 /// 0 246 /// }; 247 /// 248 /// let qf = ceed 249 /// .q_function_interior(1, Box::new(user_f))? 250 /// .input("u", 1, EvalMode::Interp)? 251 /// .input("weights", 1, EvalMode::Weight)? 252 /// .output("v", 1, EvalMode::Interp)?; 253 /// let qf_opt = QFunctionOpt::from(&qf); 254 /// assert!( 255 /// !qf_opt.is_some_q_function_by_name(), 256 /// "Incorrect QFunctionOpt" 257 /// ); 258 /// 259 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 260 /// let qf_opt = QFunctionOpt::from(&qf); 261 /// assert!( 262 /// qf_opt.is_some_q_function_by_name(), 263 /// "Incorrect QFunctionOpt" 264 /// ); 265 /// 266 /// let qf_opt = QFunctionOpt::None; 267 /// assert!( 268 /// !qf_opt.is_some_q_function_by_name(), 269 /// "Incorrect QFunctionOpt" 270 /// ); 271 /// # Ok(()) 272 /// # } 273 /// ``` 274 pub fn is_some_q_function_by_name(&self) -> bool { 275 match self { 276 Self::SomeQFunction(_) => false, 277 Self::SomeQFunctionByName(_) => true, 278 Self::None => false, 279 } 280 } 281 282 /// Check if a QFunctionOpt is None 283 /// 284 /// ``` 285 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOpt, QFunctionOutputs}; 286 /// # fn main() -> libceed::Result<()> { 287 /// # let ceed = libceed::Ceed::default_init(); 288 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 289 /// // Iterate over quadrature points 290 /// v.iter_mut() 291 /// .zip(u.iter().zip(weights.iter())) 292 /// .for_each(|(v, (u, w))| *v = u * w); 293 /// 294 /// // Return clean error code 295 /// 0 296 /// }; 297 /// 298 /// let qf = ceed 299 /// .q_function_interior(1, Box::new(user_f))? 300 /// .input("u", 1, EvalMode::Interp)? 301 /// .input("weights", 1, EvalMode::Weight)? 302 /// .output("v", 1, EvalMode::Interp)?; 303 /// let qf_opt = QFunctionOpt::from(&qf); 304 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 305 /// 306 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 307 /// let qf_opt = QFunctionOpt::from(&qf); 308 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 309 /// 310 /// let qf_opt = QFunctionOpt::None; 311 /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt"); 312 /// # Ok(()) 313 /// # } 314 /// ``` 315 pub fn is_none(&self) -> bool { 316 match self { 317 Self::SomeQFunction(_) => false, 318 Self::SomeQFunctionByName(_) => false, 319 Self::None => true, 320 } 321 } 322 } 323 324 // ----------------------------------------------------------------------------- 325 // QFunction context wrapper 326 // ----------------------------------------------------------------------------- 327 #[derive(Debug)] 328 pub(crate) struct QFunctionCore<'a> { 329 ptr: bind_ceed::CeedQFunction, 330 _lifeline: PhantomData<&'a ()>, 331 } 332 333 struct QFunctionTrampolineData { 334 number_inputs: usize, 335 number_outputs: usize, 336 input_sizes: [usize; MAX_QFUNCTION_FIELDS], 337 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 338 user_f: Box<QFunctionUserClosure>, 339 } 340 341 pub struct QFunction<'a> { 342 qf_core: QFunctionCore<'a>, 343 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 344 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 345 } 346 347 #[derive(Debug)] 348 pub struct QFunctionByName<'a> { 349 qf_core: QFunctionCore<'a>, 350 } 351 352 // ----------------------------------------------------------------------------- 353 // Destructor 354 // ----------------------------------------------------------------------------- 355 impl<'a> Drop for QFunctionCore<'a> { 356 fn drop(&mut self) { 357 unsafe { 358 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 359 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 360 } 361 } 362 } 363 } 364 365 impl<'a> Drop for QFunction<'a> { 366 fn drop(&mut self) { 367 unsafe { 368 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 369 } 370 } 371 } 372 373 // ----------------------------------------------------------------------------- 374 // Display 375 // ----------------------------------------------------------------------------- 376 impl<'a> fmt::Display for QFunctionCore<'a> { 377 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 378 let mut ptr = std::ptr::null_mut(); 379 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 380 let cstring = unsafe { 381 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 382 bind_ceed::CeedQFunctionView(self.ptr, file); 383 bind_ceed::fclose(file); 384 CString::from_raw(ptr) 385 }; 386 cstring.to_string_lossy().fmt(f) 387 } 388 } 389 /// View a QFunction 390 /// 391 /// ``` 392 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs}; 393 /// # fn main() -> libceed::Result<()> { 394 /// # let ceed = libceed::Ceed::default_init(); 395 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 396 /// // Iterate over quadrature points 397 /// v.iter_mut() 398 /// .zip(u.iter().zip(weights.iter())) 399 /// .for_each(|(v, (u, w))| *v = u * w); 400 /// 401 /// // Return clean error code 402 /// 0 403 /// }; 404 /// 405 /// let qf = ceed 406 /// .q_function_interior(1, Box::new(user_f))? 407 /// .input("u", 1, EvalMode::Interp)? 408 /// .input("weights", 1, EvalMode::Weight)? 409 /// .output("v", 1, EvalMode::Interp)?; 410 /// 411 /// println!("{}", qf); 412 /// # Ok(()) 413 /// # } 414 /// ``` 415 impl<'a> fmt::Display for QFunction<'a> { 416 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 417 self.qf_core.fmt(f) 418 } 419 } 420 421 /// View a QFunction by Name 422 /// 423 /// ``` 424 /// # use libceed::prelude::*; 425 /// # fn main() -> libceed::Result<()> { 426 /// # let ceed = libceed::Ceed::default_init(); 427 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 428 /// println!("{}", qf); 429 /// # Ok(()) 430 /// # } 431 /// ``` 432 impl<'a> fmt::Display for QFunctionByName<'a> { 433 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 434 self.qf_core.fmt(f) 435 } 436 } 437 438 // ----------------------------------------------------------------------------- 439 // Core functionality 440 // ----------------------------------------------------------------------------- 441 impl<'a> QFunctionCore<'a> { 442 // Raw Ceed for error handling 443 #[doc(hidden)] 444 fn ceed(&self) -> bind_ceed::Ceed { 445 unsafe { bind_ceed::CeedQFunctionReturnCeed(self.ptr) } 446 } 447 448 // Error handling 449 #[doc(hidden)] 450 fn check_error(&self, ierr: i32) -> crate::Result<i32> { 451 crate::check_error(|| self.ceed(), ierr) 452 } 453 454 // Common implementation 455 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 456 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 457 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 458 u_c[i] = u[i].ptr; 459 } 460 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 461 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 462 v_c[i] = v[i].ptr; 463 } 464 let Q = i32::try_from(Q).unwrap(); 465 self.check_error(unsafe { 466 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 467 }) 468 } 469 470 pub fn inputs(&self) -> crate::Result<&[QFunctionField]> { 471 // Get array of raw C pointers for inputs 472 let mut num_inputs = 0; 473 let mut inputs_ptr = std::ptr::null_mut(); 474 self.check_error(unsafe { 475 bind_ceed::CeedQFunctionGetFields( 476 self.ptr, 477 &mut num_inputs, 478 &mut inputs_ptr, 479 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 480 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 481 ) 482 })?; 483 // Convert raw C pointers to fixed length slice 484 let inputs_slice = unsafe { 485 std::slice::from_raw_parts(inputs_ptr as *const QFunctionField, num_inputs as usize) 486 }; 487 Ok(inputs_slice) 488 } 489 490 pub fn outputs(&self) -> crate::Result<&[QFunctionField]> { 491 // Get array of raw C pointers for outputs 492 let mut num_outputs = 0; 493 let mut outputs_ptr = std::ptr::null_mut(); 494 self.check_error(unsafe { 495 bind_ceed::CeedQFunctionGetFields( 496 self.ptr, 497 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 498 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 499 &mut num_outputs, 500 &mut outputs_ptr, 501 ) 502 })?; 503 // Convert raw C pointers to fixed length slice 504 let outputs_slice = unsafe { 505 std::slice::from_raw_parts(outputs_ptr as *const QFunctionField, num_outputs as usize) 506 }; 507 Ok(outputs_slice) 508 } 509 } 510 511 // ----------------------------------------------------------------------------- 512 // User QFunction Closure 513 // ----------------------------------------------------------------------------- 514 pub type QFunctionUserClosure = dyn FnMut( 515 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 516 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 517 ) -> i32; 518 519 macro_rules! mut_max_fields { 520 ($e:expr) => { 521 [ 522 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 523 ] 524 }; 525 } 526 unsafe extern "C" fn trampoline( 527 ctx: *mut ::std::os::raw::c_void, 528 q: bind_ceed::CeedInt, 529 inputs: *const *const bind_ceed::CeedScalar, 530 outputs: *const *mut bind_ceed::CeedScalar, 531 ) -> ::std::os::raw::c_int { 532 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 533 534 // Inputs 535 let inputs_slice: &[*const bind_ceed::CeedScalar] = 536 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 537 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 538 inputs_slice 539 .iter() 540 .take(trampoline_data.number_inputs) 541 .enumerate() 542 .map(|(i, &x)| { 543 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 544 as &[crate::Scalar] 545 }) 546 .zip(inputs_array.iter_mut()) 547 .for_each(|(x, a)| *a = x); 548 549 // Outputs 550 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 551 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 552 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 553 mut_max_fields!(&mut [0.0]); 554 outputs_slice 555 .iter() 556 .take(trampoline_data.number_outputs) 557 .enumerate() 558 .map(|(i, &x)| { 559 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 560 as &mut [crate::Scalar] 561 }) 562 .zip(outputs_array.iter_mut()) 563 .for_each(|(x, a)| *a = x); 564 565 // User closure 566 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 567 } 568 569 unsafe extern "C" fn destroy_trampoline(ctx: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int { 570 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 571 drop(trampoline_data); 572 0 // Clean error code 573 } 574 575 // ----------------------------------------------------------------------------- 576 // QFunction 577 // ----------------------------------------------------------------------------- 578 impl<'a> QFunction<'a> { 579 // Constructor 580 pub fn create( 581 ceed: &crate::Ceed, 582 vlength: usize, 583 user_f: Box<QFunctionUserClosure>, 584 ) -> crate::Result<Self> { 585 let source_c = CString::new("").expect("CString::new failed"); 586 let mut ptr = std::ptr::null_mut(); 587 588 // Context for closure 589 let number_inputs = 0; 590 let number_outputs = 0; 591 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 592 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 593 let trampoline_data = unsafe { 594 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 595 number_inputs, 596 number_outputs, 597 input_sizes, 598 output_sizes, 599 user_f, 600 })) 601 }; 602 603 // Create QFunction 604 let vlength = i32::try_from(vlength).unwrap(); 605 ceed.check_error(unsafe { 606 bind_ceed::CeedQFunctionCreateInterior( 607 ceed.ptr, 608 vlength, 609 Some(trampoline), 610 source_c.as_ptr(), 611 &mut ptr, 612 ) 613 })?; 614 615 // Set closure 616 let mut qf_ctx_ptr = std::ptr::null_mut(); 617 ceed.check_error(unsafe { 618 bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) 619 })?; 620 ceed.check_error(unsafe { 621 bind_ceed::CeedQFunctionContextSetData( 622 qf_ctx_ptr, 623 crate::MemType::Host as bind_ceed::CeedMemType, 624 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 625 std::mem::size_of::<QFunctionTrampolineData>(), 626 std::mem::transmute(trampoline_data.as_ref()), 627 ) 628 })?; 629 ceed.check_error(unsafe { 630 bind_ceed::CeedQFunctionContextSetDataDestroy( 631 qf_ctx_ptr, 632 crate::MemType::Host as bind_ceed::CeedMemType, 633 Some(destroy_trampoline), 634 ) 635 })?; 636 ceed.check_error(unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) })?; 637 ceed.check_error(unsafe { bind_ceed::CeedQFunctionContextDestroy(&mut qf_ctx_ptr) })?; 638 Ok(Self { 639 qf_core: QFunctionCore { 640 ptr, 641 _lifeline: PhantomData, 642 }, 643 qf_ctx_ptr, 644 trampoline_data, 645 }) 646 } 647 648 /// Apply the action of a QFunction 649 /// 650 /// * `Q` - The number of quadrature points 651 /// * `input` - Array of input Vectors 652 /// * `output` - Array of output Vectors 653 /// 654 /// ``` 655 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs, Scalar}; 656 /// # fn main() -> libceed::Result<()> { 657 /// # let ceed = libceed::Ceed::default_init(); 658 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 659 /// // Iterate over quadrature points 660 /// v.iter_mut() 661 /// .zip(u.iter().zip(weights.iter())) 662 /// .for_each(|(v, (u, w))| *v = u * w); 663 /// 664 /// // Return clean error code 665 /// 0 666 /// }; 667 /// 668 /// let qf = ceed 669 /// .q_function_interior(1, Box::new(user_f))? 670 /// .input("u", 1, EvalMode::Interp)? 671 /// .input("weights", 1, EvalMode::Weight)? 672 /// .output("v", 1, EvalMode::Interp)?; 673 /// 674 /// const Q: usize = 8; 675 /// let mut w = [0.; Q]; 676 /// let mut u = [0.; Q]; 677 /// let mut v = [0.; Q]; 678 /// 679 /// for i in 0..Q { 680 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 681 /// u[i] = 2. + 3. * x + 5. * x * x; 682 /// w[i] = 1. - x * x; 683 /// v[i] = u[i] * w[i]; 684 /// } 685 /// 686 /// let uu = ceed.vector_from_slice(&u)?; 687 /// let ww = ceed.vector_from_slice(&w)?; 688 /// let mut vv = ceed.vector(Q)?; 689 /// vv.set_value(0.0); 690 /// { 691 /// let input = vec![uu, ww]; 692 /// let mut output = vec![vv]; 693 /// qf.apply(Q, &input, &output)?; 694 /// vv = output.remove(0); 695 /// } 696 /// 697 /// vv.view()? 698 /// .iter() 699 /// .zip(v.iter()) 700 /// .for_each(|(computed, actual)| { 701 /// assert_eq!( 702 /// *computed, *actual, 703 /// "Incorrect value in QFunction application" 704 /// ); 705 /// }); 706 /// # Ok(()) 707 /// # } 708 /// ``` 709 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 710 self.qf_core.apply(Q, u, v) 711 } 712 713 /// Add a QFunction input 714 /// 715 /// * `fieldname` - Name of QFunction field 716 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 717 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 718 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 719 /// to use interpolated values, `EvalMode::Grad` to use 720 /// gradients, `EvalMode::Weight` to use quadrature weights 721 /// 722 /// ``` 723 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs}; 724 /// # fn main() -> libceed::Result<()> { 725 /// # let ceed = libceed::Ceed::default_init(); 726 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 727 /// // Iterate over quadrature points 728 /// v.iter_mut() 729 /// .zip(u.iter().zip(weights.iter())) 730 /// .for_each(|(v, (u, w))| *v = u * w); 731 /// 732 /// // Return clean error code 733 /// 0 734 /// }; 735 /// 736 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 737 /// 738 /// qf = qf.input("u", 1, EvalMode::Interp)?; 739 /// qf = qf.input("weights", 1, EvalMode::Weight)?; 740 /// # Ok(()) 741 /// # } 742 /// ``` 743 pub fn input( 744 mut self, 745 fieldname: &str, 746 size: usize, 747 emode: crate::EvalMode, 748 ) -> crate::Result<Self> { 749 let name_c = CString::new(fieldname).expect("CString::new failed"); 750 let idx = self.trampoline_data.number_inputs; 751 self.trampoline_data.input_sizes[idx] = size; 752 self.trampoline_data.number_inputs += 1; 753 let (size, emode) = ( 754 i32::try_from(size).unwrap(), 755 emode as bind_ceed::CeedEvalMode, 756 ); 757 self.qf_core.check_error(unsafe { 758 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 759 })?; 760 Ok(self) 761 } 762 763 /// Add a QFunction output 764 /// 765 /// * `fieldname` - Name of QFunction field 766 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 767 /// `(ncomp * 1)` for `None` and `Interp` 768 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 769 /// to use interpolated values, `EvalMode::Grad` to use 770 /// gradients 771 /// 772 /// ``` 773 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs}; 774 /// # fn main() -> libceed::Result<()> { 775 /// # let ceed = libceed::Ceed::default_init(); 776 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 777 /// // Iterate over quadrature points 778 /// v.iter_mut() 779 /// .zip(u.iter().zip(weights.iter())) 780 /// .for_each(|(v, (u, w))| *v = u * w); 781 /// 782 /// // Return clean error code 783 /// 0 784 /// }; 785 /// 786 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 787 /// 788 /// qf.output("v", 1, EvalMode::Interp)?; 789 /// # Ok(()) 790 /// # } 791 /// ``` 792 pub fn output( 793 mut self, 794 fieldname: &str, 795 size: usize, 796 emode: crate::EvalMode, 797 ) -> crate::Result<Self> { 798 let name_c = CString::new(fieldname).expect("CString::new failed"); 799 let idx = self.trampoline_data.number_outputs; 800 self.trampoline_data.output_sizes[idx] = size; 801 self.trampoline_data.number_outputs += 1; 802 let (size, emode) = ( 803 i32::try_from(size).unwrap(), 804 emode as bind_ceed::CeedEvalMode, 805 ); 806 self.qf_core.check_error(unsafe { 807 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 808 })?; 809 Ok(self) 810 } 811 812 /// Get a slice of QFunction inputs 813 /// 814 /// ``` 815 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs}; 816 /// # fn main() -> libceed::Result<()> { 817 /// # let ceed = libceed::Ceed::default_init(); 818 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 819 /// // Iterate over quadrature points 820 /// v.iter_mut() 821 /// .zip(u.iter().zip(weights.iter())) 822 /// .for_each(|(v, (u, w))| *v = u * w); 823 /// 824 /// // Return clean error code 825 /// 0 826 /// }; 827 /// 828 /// let mut qf = ceed 829 /// .q_function_interior(1, Box::new(user_f))? 830 /// .input("u", 1, EvalMode::Interp)? 831 /// .input("weights", 1, EvalMode::Weight)?; 832 /// 833 /// let inputs = qf.inputs()?; 834 /// 835 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 836 /// # Ok(()) 837 /// # } 838 /// ``` 839 pub fn inputs(&self) -> crate::Result<&[QFunctionField]> { 840 self.qf_core.inputs() 841 } 842 843 /// Get a slice of QFunction outputs 844 /// 845 /// ``` 846 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs}; 847 /// # fn main() -> libceed::Result<()> { 848 /// # let ceed = libceed::Ceed::default_init(); 849 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 850 /// // Iterate over quadrature points 851 /// v.iter_mut() 852 /// .zip(u.iter().zip(weights.iter())) 853 /// .for_each(|(v, (u, w))| *v = u * w); 854 /// 855 /// // Return clean error code 856 /// 0 857 /// }; 858 /// 859 /// let mut qf = ceed 860 /// .q_function_interior(1, Box::new(user_f))? 861 /// .output("v", 1, EvalMode::Interp)?; 862 /// 863 /// let outputs = qf.outputs()?; 864 /// 865 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 866 /// # Ok(()) 867 /// # } 868 /// ``` 869 pub fn outputs(&self) -> crate::Result<&[QFunctionField]> { 870 self.qf_core.outputs() 871 } 872 } 873 874 // ----------------------------------------------------------------------------- 875 // QFunction 876 // ----------------------------------------------------------------------------- 877 impl<'a> QFunctionByName<'a> { 878 // Constructor 879 pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> { 880 let name_c = CString::new(name).expect("CString::new failed"); 881 let mut ptr = std::ptr::null_mut(); 882 ceed.check_error(unsafe { 883 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 884 })?; 885 Ok(Self { 886 qf_core: QFunctionCore { 887 ptr, 888 _lifeline: PhantomData, 889 }, 890 }) 891 } 892 893 /// Apply the action of a QFunction 894 /// 895 /// * `Q` - The number of quadrature points 896 /// * `input` - Array of input Vectors 897 /// * `output` - Array of output Vectors 898 /// 899 /// ``` 900 /// # use libceed::{prelude::*, EvalMode, QFunctionInputs, QFunctionOutputs, Scalar}; 901 /// # fn main() -> libceed::Result<()> { 902 /// # let ceed = libceed::Ceed::default_init(); 903 /// const Q: usize = 8; 904 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?; 905 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?; 906 /// 907 /// let mut j = [0.; Q]; 908 /// let mut w = [0.; Q]; 909 /// let mut u = [0.; Q]; 910 /// let mut v = [0.; Q]; 911 /// 912 /// for i in 0..Q { 913 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 914 /// j[i] = 1.; 915 /// w[i] = 1. - x * x; 916 /// u[i] = 2. + 3. * x + 5. * x * x; 917 /// v[i] = w[i] * u[i]; 918 /// } 919 /// 920 /// let jj = ceed.vector_from_slice(&j)?; 921 /// let ww = ceed.vector_from_slice(&w)?; 922 /// let uu = ceed.vector_from_slice(&u)?; 923 /// let mut vv = ceed.vector(Q)?; 924 /// vv.set_value(0.0); 925 /// let mut qdata = ceed.vector(Q)?; 926 /// qdata.set_value(0.0); 927 /// 928 /// { 929 /// let mut input = vec![jj, ww]; 930 /// let mut output = vec![qdata]; 931 /// qf_build.apply(Q, &input, &output)?; 932 /// qdata = output.remove(0); 933 /// } 934 /// 935 /// { 936 /// let mut input = vec![qdata, uu]; 937 /// let mut output = vec![vv]; 938 /// qf_mass.apply(Q, &input, &output)?; 939 /// vv = output.remove(0); 940 /// } 941 /// 942 /// vv.view()? 943 /// .iter() 944 /// .zip(v.iter()) 945 /// .for_each(|(computed, actual)| { 946 /// assert_eq!( 947 /// *computed, *actual, 948 /// "Incorrect value in QFunction application" 949 /// ); 950 /// }); 951 /// # Ok(()) 952 /// # } 953 /// ``` 954 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 955 self.qf_core.apply(Q, u, v) 956 } 957 958 /// Get a slice of QFunction inputs 959 /// 960 /// ``` 961 /// # use libceed::prelude::*; 962 /// # fn main() -> libceed::Result<()> { 963 /// # let ceed = libceed::Ceed::default_init(); 964 /// const Q: usize = 8; 965 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 966 /// 967 /// let inputs = qf.inputs()?; 968 /// 969 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 970 /// # Ok(()) 971 /// # } 972 /// ``` 973 pub fn inputs(&self) -> crate::Result<&[QFunctionField]> { 974 self.qf_core.inputs() 975 } 976 977 /// Get a slice of QFunction outputs 978 /// 979 /// ``` 980 /// # use libceed::prelude::*; 981 /// # fn main() -> libceed::Result<()> { 982 /// # let ceed = libceed::Ceed::default_init(); 983 /// const Q: usize = 8; 984 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 985 /// 986 /// let outputs = qf.outputs()?; 987 /// 988 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 989 /// # Ok(()) 990 /// # } 991 /// ``` 992 pub fn outputs(&self) -> crate::Result<&[QFunctionField]> { 993 self.qf_core.outputs() 994 } 995 } 996 997 // ----------------------------------------------------------------------------- 998