1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 //! A Ceed QFunction represents the spatial terms of the point-wise functions 9 //! describing the physics at the quadrature points. 10 11 use std::pin::Pin; 12 13 use crate::prelude::*; 14 15 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 16 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 17 18 // ----------------------------------------------------------------------------- 19 // QFunction Field context wrapper 20 // ----------------------------------------------------------------------------- 21 #[derive(Debug)] 22 pub struct QFunctionField<'a> { 23 ptr: bind_ceed::CeedQFunctionField, 24 _lifeline: PhantomData<&'a ()>, 25 } 26 27 // ----------------------------------------------------------------------------- 28 // Implementations 29 // ----------------------------------------------------------------------------- 30 impl<'a> QFunctionField<'a> { 31 /// Get the name of a QFunctionField 32 /// 33 /// ``` 34 /// # use libceed::prelude::*; 35 /// # fn main() -> libceed::Result<()> { 36 /// # let ceed = libceed::Ceed::default_init(); 37 /// const Q: usize = 8; 38 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 39 /// 40 /// let inputs = qf.inputs()?; 41 /// 42 /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name"); 43 /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name"); 44 /// # Ok(()) 45 /// # } 46 /// ``` 47 pub fn name(&self) -> &str { 48 let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut(); 49 unsafe { 50 bind_ceed::CeedQFunctionFieldGetName(self.ptr, &mut name_ptr); 51 } 52 unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap() 53 } 54 55 /// Get the size of a QFunctionField 56 /// 57 /// ``` 58 /// # use libceed::prelude::*; 59 /// # fn main() -> libceed::Result<()> { 60 /// # let ceed = libceed::Ceed::default_init(); 61 /// const Q: usize = 8; 62 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 63 /// 64 /// let inputs = qf.inputs()?; 65 /// 66 /// assert_eq!(inputs[0].size(), 4, "Incorrect input size"); 67 /// assert_eq!(inputs[1].size(), 1, "Incorrect input size"); 68 /// # Ok(()) 69 /// # } 70 /// ``` 71 pub fn size(&self) -> usize { 72 let mut size = 0; 73 unsafe { 74 bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size); 75 } 76 usize::try_from(size).unwrap() 77 } 78 79 /// Get the evaluation mode of a QFunctionField 80 /// 81 /// ``` 82 /// # use libceed::prelude::*; 83 /// # fn main() -> libceed::Result<()> { 84 /// # let ceed = libceed::Ceed::default_init(); 85 /// const Q: usize = 8; 86 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 87 /// 88 /// let inputs = qf.inputs()?; 89 /// 90 /// assert_eq!( 91 /// inputs[0].eval_mode(), 92 /// EvalMode::Grad, 93 /// "Incorrect input evaluation mode" 94 /// ); 95 /// assert_eq!( 96 /// inputs[1].eval_mode(), 97 /// EvalMode::Weight, 98 /// "Incorrect input evaluation mode" 99 /// ); 100 /// # Ok(()) 101 /// # } 102 /// ``` 103 pub fn eval_mode(&self) -> crate::EvalMode { 104 let mut mode = 0; 105 unsafe { 106 bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode); 107 } 108 crate::EvalMode::from_u32(mode as u32) 109 } 110 } 111 112 // ----------------------------------------------------------------------------- 113 // QFunction option 114 // ----------------------------------------------------------------------------- 115 pub enum QFunctionOpt<'a> { 116 SomeQFunction(&'a QFunction<'a>), 117 SomeQFunctionByName(&'a QFunctionByName<'a>), 118 None, 119 } 120 121 /// Construct a QFunctionOpt reference from a QFunction reference 122 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 123 fn from(qfunc: &'a QFunction) -> Self { 124 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 125 Self::SomeQFunction(qfunc) 126 } 127 } 128 129 /// Construct a QFunctionOpt reference from a QFunction by Name reference 130 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 131 fn from(qfunc: &'a QFunctionByName) -> Self { 132 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 133 Self::SomeQFunctionByName(qfunc) 134 } 135 } 136 137 impl<'a> QFunctionOpt<'a> { 138 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 139 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 140 match self { 141 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 142 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 143 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 144 } 145 } 146 147 /// Check if a QFunctionOpt is Some 148 /// 149 /// ``` 150 /// # use libceed::prelude::*; 151 /// # fn main() -> libceed::Result<()> { 152 /// # let ceed = libceed::Ceed::default_init(); 153 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 154 /// // Iterate over quadrature points 155 /// v.iter_mut() 156 /// .zip(u.iter().zip(weights.iter())) 157 /// .for_each(|(v, (u, w))| *v = u * w); 158 /// 159 /// // Return clean error code 160 /// 0 161 /// }; 162 /// 163 /// let qf = ceed 164 /// .q_function_interior(1, Box::new(user_f))? 165 /// .input("u", 1, EvalMode::Interp)? 166 /// .input("weights", 1, EvalMode::Weight)? 167 /// .output("v", 1, EvalMode::Interp)?; 168 /// let qf_opt = QFunctionOpt::from(&qf); 169 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 170 /// 171 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 172 /// let qf_opt = QFunctionOpt::from(&qf); 173 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 174 /// 175 /// let qf_opt = QFunctionOpt::None; 176 /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt"); 177 /// # Ok(()) 178 /// # } 179 /// ``` 180 pub fn is_some(&self) -> bool { 181 match self { 182 Self::SomeQFunction(_) => true, 183 Self::SomeQFunctionByName(_) => true, 184 Self::None => false, 185 } 186 } 187 188 /// Check if a QFunctionOpt is SomeQFunction 189 /// 190 /// ``` 191 /// # use libceed::prelude::*; 192 /// # fn main() -> libceed::Result<()> { 193 /// # let ceed = libceed::Ceed::default_init(); 194 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 195 /// // Iterate over quadrature points 196 /// v.iter_mut() 197 /// .zip(u.iter().zip(weights.iter())) 198 /// .for_each(|(v, (u, w))| *v = u * w); 199 /// 200 /// // Return clean error code 201 /// 0 202 /// }; 203 /// 204 /// let qf = ceed 205 /// .q_function_interior(1, Box::new(user_f))? 206 /// .input("u", 1, EvalMode::Interp)? 207 /// .input("weights", 1, EvalMode::Weight)? 208 /// .output("v", 1, EvalMode::Interp)?; 209 /// let qf_opt = QFunctionOpt::from(&qf); 210 /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 211 /// 212 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 213 /// let qf_opt = QFunctionOpt::from(&qf); 214 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 215 /// 216 /// let qf_opt = QFunctionOpt::None; 217 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 218 /// # Ok(()) 219 /// # } 220 /// ``` 221 pub fn is_some_q_function(&self) -> bool { 222 match self { 223 Self::SomeQFunction(_) => true, 224 Self::SomeQFunctionByName(_) => false, 225 Self::None => false, 226 } 227 } 228 229 /// Check if a QFunctionOpt is SomeQFunctionByName 230 /// 231 /// ``` 232 /// # use libceed::prelude::*; 233 /// # fn main() -> libceed::Result<()> { 234 /// # let ceed = libceed::Ceed::default_init(); 235 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 236 /// // Iterate over quadrature points 237 /// v.iter_mut() 238 /// .zip(u.iter().zip(weights.iter())) 239 /// .for_each(|(v, (u, w))| *v = u * w); 240 /// 241 /// // Return clean error code 242 /// 0 243 /// }; 244 /// 245 /// let qf = ceed 246 /// .q_function_interior(1, Box::new(user_f))? 247 /// .input("u", 1, EvalMode::Interp)? 248 /// .input("weights", 1, EvalMode::Weight)? 249 /// .output("v", 1, EvalMode::Interp)?; 250 /// let qf_opt = QFunctionOpt::from(&qf); 251 /// assert!( 252 /// !qf_opt.is_some_q_function_by_name(), 253 /// "Incorrect QFunctionOpt" 254 /// ); 255 /// 256 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 257 /// let qf_opt = QFunctionOpt::from(&qf); 258 /// assert!( 259 /// qf_opt.is_some_q_function_by_name(), 260 /// "Incorrect QFunctionOpt" 261 /// ); 262 /// 263 /// let qf_opt = QFunctionOpt::None; 264 /// assert!( 265 /// !qf_opt.is_some_q_function_by_name(), 266 /// "Incorrect QFunctionOpt" 267 /// ); 268 /// # Ok(()) 269 /// # } 270 /// ``` 271 pub fn is_some_q_function_by_name(&self) -> bool { 272 match self { 273 Self::SomeQFunction(_) => false, 274 Self::SomeQFunctionByName(_) => true, 275 Self::None => false, 276 } 277 } 278 279 /// Check if a QFunctionOpt is None 280 /// 281 /// ``` 282 /// # use libceed::prelude::*; 283 /// # fn main() -> libceed::Result<()> { 284 /// # let ceed = libceed::Ceed::default_init(); 285 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 286 /// // Iterate over quadrature points 287 /// v.iter_mut() 288 /// .zip(u.iter().zip(weights.iter())) 289 /// .for_each(|(v, (u, w))| *v = u * w); 290 /// 291 /// // Return clean error code 292 /// 0 293 /// }; 294 /// 295 /// let qf = ceed 296 /// .q_function_interior(1, Box::new(user_f))? 297 /// .input("u", 1, EvalMode::Interp)? 298 /// .input("weights", 1, EvalMode::Weight)? 299 /// .output("v", 1, EvalMode::Interp)?; 300 /// let qf_opt = QFunctionOpt::from(&qf); 301 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 302 /// 303 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 304 /// let qf_opt = QFunctionOpt::from(&qf); 305 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 306 /// 307 /// let qf_opt = QFunctionOpt::None; 308 /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt"); 309 /// # Ok(()) 310 /// # } 311 /// ``` 312 pub fn is_none(&self) -> bool { 313 match self { 314 Self::SomeQFunction(_) => false, 315 Self::SomeQFunctionByName(_) => false, 316 Self::None => true, 317 } 318 } 319 } 320 321 // ----------------------------------------------------------------------------- 322 // QFunction context wrapper 323 // ----------------------------------------------------------------------------- 324 #[derive(Debug)] 325 pub(crate) struct QFunctionCore<'a> { 326 ptr: bind_ceed::CeedQFunction, 327 _lifeline: PhantomData<&'a ()>, 328 } 329 330 struct QFunctionTrampolineData { 331 number_inputs: usize, 332 number_outputs: usize, 333 input_sizes: [usize; MAX_QFUNCTION_FIELDS], 334 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 335 user_f: Box<QFunctionUserClosure>, 336 } 337 338 pub struct QFunction<'a> { 339 qf_core: QFunctionCore<'a>, 340 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 341 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 342 } 343 344 #[derive(Debug)] 345 pub struct QFunctionByName<'a> { 346 qf_core: QFunctionCore<'a>, 347 } 348 349 // ----------------------------------------------------------------------------- 350 // Destructor 351 // ----------------------------------------------------------------------------- 352 impl<'a> Drop for QFunctionCore<'a> { 353 fn drop(&mut self) { 354 unsafe { 355 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 356 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 357 } 358 } 359 } 360 } 361 362 impl<'a> Drop for QFunction<'a> { 363 fn drop(&mut self) { 364 unsafe { 365 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 366 } 367 } 368 } 369 370 // ----------------------------------------------------------------------------- 371 // Display 372 // ----------------------------------------------------------------------------- 373 impl<'a> fmt::Display for QFunctionCore<'a> { 374 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 375 let mut ptr = std::ptr::null_mut(); 376 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 377 let cstring = unsafe { 378 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 379 bind_ceed::CeedQFunctionView(self.ptr, file); 380 bind_ceed::fclose(file); 381 CString::from_raw(ptr) 382 }; 383 cstring.to_string_lossy().fmt(f) 384 } 385 } 386 /// View a QFunction 387 /// 388 /// ``` 389 /// # use libceed::prelude::*; 390 /// # fn main() -> libceed::Result<()> { 391 /// # let ceed = libceed::Ceed::default_init(); 392 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 393 /// // Iterate over quadrature points 394 /// v.iter_mut() 395 /// .zip(u.iter().zip(weights.iter())) 396 /// .for_each(|(v, (u, w))| *v = u * w); 397 /// 398 /// // Return clean error code 399 /// 0 400 /// }; 401 /// 402 /// let qf = ceed 403 /// .q_function_interior(1, Box::new(user_f))? 404 /// .input("u", 1, EvalMode::Interp)? 405 /// .input("weights", 1, EvalMode::Weight)? 406 /// .output("v", 1, EvalMode::Interp)?; 407 /// 408 /// println!("{}", qf); 409 /// # Ok(()) 410 /// # } 411 /// ``` 412 impl<'a> fmt::Display for QFunction<'a> { 413 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 414 self.qf_core.fmt(f) 415 } 416 } 417 418 /// View a QFunction by Name 419 /// 420 /// ``` 421 /// # use libceed::prelude::*; 422 /// # fn main() -> libceed::Result<()> { 423 /// # let ceed = libceed::Ceed::default_init(); 424 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 425 /// println!("{}", qf); 426 /// # Ok(()) 427 /// # } 428 /// ``` 429 impl<'a> fmt::Display for QFunctionByName<'a> { 430 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 431 self.qf_core.fmt(f) 432 } 433 } 434 435 // ----------------------------------------------------------------------------- 436 // Core functionality 437 // ----------------------------------------------------------------------------- 438 impl<'a> QFunctionCore<'a> { 439 // Error handling 440 #[doc(hidden)] 441 fn check_error(&self, ierr: i32) -> crate::Result<i32> { 442 let mut ptr = std::ptr::null_mut(); 443 unsafe { 444 bind_ceed::CeedQFunctionGetCeed(self.ptr, &mut ptr); 445 } 446 crate::check_error(ptr, ierr) 447 } 448 449 // Common implementation 450 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 451 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 452 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 453 u_c[i] = u[i].ptr; 454 } 455 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 456 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 457 v_c[i] = v[i].ptr; 458 } 459 let Q = i32::try_from(Q).unwrap(); 460 let ierr = unsafe { 461 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 462 }; 463 self.check_error(ierr) 464 } 465 466 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 467 // Get array of raw C pointers for inputs 468 let mut num_inputs = 0; 469 let mut inputs_ptr = std::ptr::null_mut(); 470 let ierr = unsafe { 471 bind_ceed::CeedQFunctionGetFields( 472 self.ptr, 473 &mut num_inputs, 474 &mut inputs_ptr, 475 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 476 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 477 ) 478 }; 479 self.check_error(ierr)?; 480 // Convert raw C pointers to fixed length slice 481 let inputs_slice = unsafe { 482 std::slice::from_raw_parts( 483 inputs_ptr as *const crate::QFunctionField, 484 num_inputs as usize, 485 ) 486 }; 487 Ok(inputs_slice) 488 } 489 490 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 491 // Get array of raw C pointers for outputs 492 let mut num_outputs = 0; 493 let mut outputs_ptr = std::ptr::null_mut(); 494 let ierr = unsafe { 495 bind_ceed::CeedQFunctionGetFields( 496 self.ptr, 497 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 498 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 499 &mut num_outputs, 500 &mut outputs_ptr, 501 ) 502 }; 503 self.check_error(ierr)?; 504 // Convert raw C pointers to fixed length slice 505 let outputs_slice = unsafe { 506 std::slice::from_raw_parts( 507 outputs_ptr as *const crate::QFunctionField, 508 num_outputs as usize, 509 ) 510 }; 511 Ok(outputs_slice) 512 } 513 } 514 515 // ----------------------------------------------------------------------------- 516 // User QFunction Closure 517 // ----------------------------------------------------------------------------- 518 pub type QFunctionUserClosure = dyn FnMut( 519 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 520 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 521 ) -> i32; 522 523 macro_rules! mut_max_fields { 524 ($e:expr) => { 525 [ 526 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 527 ] 528 }; 529 } 530 unsafe extern "C" fn trampoline( 531 ctx: *mut ::std::os::raw::c_void, 532 q: bind_ceed::CeedInt, 533 inputs: *const *const bind_ceed::CeedScalar, 534 outputs: *const *mut bind_ceed::CeedScalar, 535 ) -> ::std::os::raw::c_int { 536 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 537 538 // Inputs 539 let inputs_slice: &[*const bind_ceed::CeedScalar] = 540 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 541 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 542 inputs_slice 543 .iter() 544 .enumerate() 545 .map(|(i, &x)| { 546 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 547 as &[crate::Scalar] 548 }) 549 .zip(inputs_array.iter_mut()) 550 .for_each(|(x, a)| *a = x); 551 552 // Outputs 553 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 554 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 555 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 556 mut_max_fields!(&mut [0.0]); 557 outputs_slice 558 .iter() 559 .enumerate() 560 .map(|(i, &x)| { 561 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 562 as &mut [crate::Scalar] 563 }) 564 .zip(outputs_array.iter_mut()) 565 .for_each(|(x, a)| *a = x); 566 567 // User closure 568 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 569 } 570 571 // ----------------------------------------------------------------------------- 572 // QFunction 573 // ----------------------------------------------------------------------------- 574 impl<'a> QFunction<'a> { 575 // Constructor 576 pub fn create( 577 ceed: &crate::Ceed, 578 vlength: usize, 579 user_f: Box<QFunctionUserClosure>, 580 ) -> crate::Result<Self> { 581 let source_c = CString::new("").expect("CString::new failed"); 582 let mut ptr = std::ptr::null_mut(); 583 584 // Context for closure 585 let number_inputs = 0; 586 let number_outputs = 0; 587 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 588 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 589 let trampoline_data = unsafe { 590 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 591 number_inputs, 592 number_outputs, 593 input_sizes, 594 output_sizes, 595 user_f, 596 })) 597 }; 598 599 // Create QFunction 600 let vlength = i32::try_from(vlength).unwrap(); 601 let mut ierr = unsafe { 602 bind_ceed::CeedQFunctionCreateInterior( 603 ceed.ptr, 604 vlength, 605 Some(trampoline), 606 source_c.as_ptr(), 607 &mut ptr, 608 ) 609 }; 610 ceed.check_error(ierr)?; 611 612 // Set closure 613 let mut qf_ctx_ptr = std::ptr::null_mut(); 614 ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) }; 615 ceed.check_error(ierr)?; 616 ierr = unsafe { 617 bind_ceed::CeedQFunctionContextSetData( 618 qf_ctx_ptr, 619 crate::MemType::Host as bind_ceed::CeedMemType, 620 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 621 std::mem::size_of::<QFunctionTrampolineData>() as u64, 622 std::mem::transmute(trampoline_data.as_ref()), 623 ) 624 }; 625 ceed.check_error(ierr)?; 626 ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) }; 627 ceed.check_error(ierr)?; 628 Ok(Self { 629 qf_core: QFunctionCore { 630 ptr, 631 _lifeline: PhantomData, 632 }, 633 qf_ctx_ptr, 634 trampoline_data, 635 }) 636 } 637 638 /// Apply the action of a QFunction 639 /// 640 /// * `Q` - The number of quadrature points 641 /// * `input` - Array of input Vectors 642 /// * `output` - Array of output Vectors 643 /// 644 /// ``` 645 /// # use libceed::prelude::*; 646 /// # fn main() -> libceed::Result<()> { 647 /// # let ceed = libceed::Ceed::default_init(); 648 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 649 /// // Iterate over quadrature points 650 /// v.iter_mut() 651 /// .zip(u.iter().zip(weights.iter())) 652 /// .for_each(|(v, (u, w))| *v = u * w); 653 /// 654 /// // Return clean error code 655 /// 0 656 /// }; 657 /// 658 /// let qf = ceed 659 /// .q_function_interior(1, Box::new(user_f))? 660 /// .input("u", 1, EvalMode::Interp)? 661 /// .input("weights", 1, EvalMode::Weight)? 662 /// .output("v", 1, EvalMode::Interp)?; 663 /// 664 /// const Q: usize = 8; 665 /// let mut w = [0.; Q]; 666 /// let mut u = [0.; Q]; 667 /// let mut v = [0.; Q]; 668 /// 669 /// for i in 0..Q { 670 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 671 /// u[i] = 2. + 3. * x + 5. * x * x; 672 /// w[i] = 1. - x * x; 673 /// v[i] = u[i] * w[i]; 674 /// } 675 /// 676 /// let uu = ceed.vector_from_slice(&u)?; 677 /// let ww = ceed.vector_from_slice(&w)?; 678 /// let mut vv = ceed.vector(Q)?; 679 /// vv.set_value(0.0); 680 /// { 681 /// let input = vec![uu, ww]; 682 /// let mut output = vec![vv]; 683 /// qf.apply(Q, &input, &output)?; 684 /// vv = output.remove(0); 685 /// } 686 /// 687 /// vv.view()? 688 /// .iter() 689 /// .zip(v.iter()) 690 /// .for_each(|(computed, actual)| { 691 /// assert_eq!( 692 /// *computed, *actual, 693 /// "Incorrect value in QFunction application" 694 /// ); 695 /// }); 696 /// # Ok(()) 697 /// # } 698 /// ``` 699 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 700 self.qf_core.apply(Q, u, v) 701 } 702 703 /// Add a QFunction input 704 /// 705 /// * `fieldname` - Name of QFunction field 706 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 707 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 708 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 709 /// to use interpolated values, `EvalMode::Grad` to use 710 /// gradients, `EvalMode::Weight` to use quadrature weights 711 /// 712 /// ``` 713 /// # use libceed::prelude::*; 714 /// # fn main() -> libceed::Result<()> { 715 /// # let ceed = libceed::Ceed::default_init(); 716 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 717 /// // Iterate over quadrature points 718 /// v.iter_mut() 719 /// .zip(u.iter().zip(weights.iter())) 720 /// .for_each(|(v, (u, w))| *v = u * w); 721 /// 722 /// // Return clean error code 723 /// 0 724 /// }; 725 /// 726 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 727 /// 728 /// qf = qf.input("u", 1, EvalMode::Interp)?; 729 /// qf = qf.input("weights", 1, EvalMode::Weight)?; 730 /// # Ok(()) 731 /// # } 732 /// ``` 733 pub fn input( 734 mut self, 735 fieldname: &str, 736 size: usize, 737 emode: crate::EvalMode, 738 ) -> crate::Result<Self> { 739 let name_c = CString::new(fieldname).expect("CString::new failed"); 740 let idx = self.trampoline_data.number_inputs; 741 self.trampoline_data.input_sizes[idx] = size; 742 self.trampoline_data.number_inputs += 1; 743 let (size, emode) = ( 744 i32::try_from(size).unwrap(), 745 emode as bind_ceed::CeedEvalMode, 746 ); 747 let ierr = unsafe { 748 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 749 }; 750 self.qf_core.check_error(ierr)?; 751 Ok(self) 752 } 753 754 /// Add a QFunction output 755 /// 756 /// * `fieldname` - Name of QFunction field 757 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 758 /// `(ncomp * 1)` for `None` and `Interp` 759 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 760 /// to use interpolated values, `EvalMode::Grad` to use 761 /// gradients 762 /// 763 /// ``` 764 /// # use libceed::prelude::*; 765 /// # fn main() -> libceed::Result<()> { 766 /// # let ceed = libceed::Ceed::default_init(); 767 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 768 /// // Iterate over quadrature points 769 /// v.iter_mut() 770 /// .zip(u.iter().zip(weights.iter())) 771 /// .for_each(|(v, (u, w))| *v = u * w); 772 /// 773 /// // Return clean error code 774 /// 0 775 /// }; 776 /// 777 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 778 /// 779 /// qf.output("v", 1, EvalMode::Interp)?; 780 /// # Ok(()) 781 /// # } 782 /// ``` 783 pub fn output( 784 mut self, 785 fieldname: &str, 786 size: usize, 787 emode: crate::EvalMode, 788 ) -> crate::Result<Self> { 789 let name_c = CString::new(fieldname).expect("CString::new failed"); 790 let idx = self.trampoline_data.number_outputs; 791 self.trampoline_data.output_sizes[idx] = size; 792 self.trampoline_data.number_outputs += 1; 793 let (size, emode) = ( 794 i32::try_from(size).unwrap(), 795 emode as bind_ceed::CeedEvalMode, 796 ); 797 let ierr = unsafe { 798 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 799 }; 800 self.qf_core.check_error(ierr)?; 801 Ok(self) 802 } 803 804 /// Get a slice of QFunction inputs 805 /// 806 /// ``` 807 /// # use libceed::prelude::*; 808 /// # fn main() -> libceed::Result<()> { 809 /// # let ceed = libceed::Ceed::default_init(); 810 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 811 /// // Iterate over quadrature points 812 /// v.iter_mut() 813 /// .zip(u.iter().zip(weights.iter())) 814 /// .for_each(|(v, (u, w))| *v = u * w); 815 /// 816 /// // Return clean error code 817 /// 0 818 /// }; 819 /// 820 /// let mut qf = ceed 821 /// .q_function_interior(1, Box::new(user_f))? 822 /// .input("u", 1, EvalMode::Interp)? 823 /// .input("weights", 1, EvalMode::Weight)?; 824 /// 825 /// let inputs = qf.inputs()?; 826 /// 827 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 828 /// # Ok(()) 829 /// # } 830 /// ``` 831 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 832 self.qf_core.inputs() 833 } 834 835 /// Get a slice of QFunction outputs 836 /// 837 /// ``` 838 /// # use libceed::prelude::*; 839 /// # fn main() -> libceed::Result<()> { 840 /// # let ceed = libceed::Ceed::default_init(); 841 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 842 /// // Iterate over quadrature points 843 /// v.iter_mut() 844 /// .zip(u.iter().zip(weights.iter())) 845 /// .for_each(|(v, (u, w))| *v = u * w); 846 /// 847 /// // Return clean error code 848 /// 0 849 /// }; 850 /// 851 /// let mut qf = ceed 852 /// .q_function_interior(1, Box::new(user_f))? 853 /// .output("v", 1, EvalMode::Interp)?; 854 /// 855 /// let outputs = qf.outputs()?; 856 /// 857 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 858 /// # Ok(()) 859 /// # } 860 /// ``` 861 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 862 self.qf_core.outputs() 863 } 864 } 865 866 // ----------------------------------------------------------------------------- 867 // QFunction 868 // ----------------------------------------------------------------------------- 869 impl<'a> QFunctionByName<'a> { 870 // Constructor 871 pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> { 872 let name_c = CString::new(name).expect("CString::new failed"); 873 let mut ptr = std::ptr::null_mut(); 874 let ierr = unsafe { 875 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 876 }; 877 ceed.check_error(ierr)?; 878 Ok(Self { 879 qf_core: QFunctionCore { 880 ptr, 881 _lifeline: PhantomData, 882 }, 883 }) 884 } 885 886 /// Apply the action of a QFunction 887 /// 888 /// * `Q` - The number of quadrature points 889 /// * `input` - Array of input Vectors 890 /// * `output` - Array of output Vectors 891 /// 892 /// ``` 893 /// # use libceed::prelude::*; 894 /// # fn main() -> libceed::Result<()> { 895 /// # let ceed = libceed::Ceed::default_init(); 896 /// const Q: usize = 8; 897 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?; 898 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?; 899 /// 900 /// let mut j = [0.; Q]; 901 /// let mut w = [0.; Q]; 902 /// let mut u = [0.; Q]; 903 /// let mut v = [0.; Q]; 904 /// 905 /// for i in 0..Q { 906 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 907 /// j[i] = 1.; 908 /// w[i] = 1. - x * x; 909 /// u[i] = 2. + 3. * x + 5. * x * x; 910 /// v[i] = w[i] * u[i]; 911 /// } 912 /// 913 /// let jj = ceed.vector_from_slice(&j)?; 914 /// let ww = ceed.vector_from_slice(&w)?; 915 /// let uu = ceed.vector_from_slice(&u)?; 916 /// let mut vv = ceed.vector(Q)?; 917 /// vv.set_value(0.0); 918 /// let mut qdata = ceed.vector(Q)?; 919 /// qdata.set_value(0.0); 920 /// 921 /// { 922 /// let mut input = vec![jj, ww]; 923 /// let mut output = vec![qdata]; 924 /// qf_build.apply(Q, &input, &output)?; 925 /// qdata = output.remove(0); 926 /// } 927 /// 928 /// { 929 /// let mut input = vec![qdata, uu]; 930 /// let mut output = vec![vv]; 931 /// qf_mass.apply(Q, &input, &output)?; 932 /// vv = output.remove(0); 933 /// } 934 /// 935 /// vv.view()? 936 /// .iter() 937 /// .zip(v.iter()) 938 /// .for_each(|(computed, actual)| { 939 /// assert_eq!( 940 /// *computed, *actual, 941 /// "Incorrect value in QFunction application" 942 /// ); 943 /// }); 944 /// # Ok(()) 945 /// # } 946 /// ``` 947 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 948 self.qf_core.apply(Q, u, v) 949 } 950 951 /// Get a slice of QFunction inputs 952 /// 953 /// ``` 954 /// # use libceed::prelude::*; 955 /// # fn main() -> libceed::Result<()> { 956 /// # let ceed = libceed::Ceed::default_init(); 957 /// const Q: usize = 8; 958 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 959 /// 960 /// let inputs = qf.inputs()?; 961 /// 962 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 963 /// # Ok(()) 964 /// # } 965 /// ``` 966 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 967 self.qf_core.inputs() 968 } 969 970 /// Get a slice of QFunction outputs 971 /// 972 /// ``` 973 /// # use libceed::prelude::*; 974 /// # fn main() -> libceed::Result<()> { 975 /// # let ceed = libceed::Ceed::default_init(); 976 /// const Q: usize = 8; 977 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 978 /// 979 /// let outputs = qf.outputs()?; 980 /// 981 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 982 /// # Ok(()) 983 /// # } 984 /// ``` 985 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 986 self.qf_core.outputs() 987 } 988 } 989 990 // ----------------------------------------------------------------------------- 991