1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 //! A Ceed QFunction represents the spatial terms of the point-wise functions 9 //! describing the physics at the quadrature points. 10 11 use std::pin::Pin; 12 13 use crate::prelude::*; 14 15 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 16 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 17 18 // ----------------------------------------------------------------------------- 19 // QFunction Field context wrapper 20 // ----------------------------------------------------------------------------- 21 #[derive(Debug)] 22 pub struct QFunctionField<'a> { 23 ptr: bind_ceed::CeedQFunctionField, 24 _lifeline: PhantomData<&'a ()>, 25 } 26 27 // ----------------------------------------------------------------------------- 28 // Implementations 29 // ----------------------------------------------------------------------------- 30 impl<'a> QFunctionField<'a> { 31 /// Get the name of a QFunctionField 32 /// 33 /// ``` 34 /// # use libceed::prelude::*; 35 /// # fn main() -> libceed::Result<()> { 36 /// # let ceed = libceed::Ceed::default_init(); 37 /// const Q: usize = 8; 38 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 39 /// 40 /// let inputs = qf.inputs()?; 41 /// 42 /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name"); 43 /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name"); 44 /// # Ok(()) 45 /// # } 46 /// ``` 47 pub fn name(&self) -> &str { 48 let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut(); 49 unsafe { 50 bind_ceed::CeedQFunctionFieldGetName(self.ptr, &mut name_ptr); 51 } 52 unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap() 53 } 54 55 /// Get the size of a QFunctionField 56 /// 57 /// ``` 58 /// # use libceed::prelude::*; 59 /// # fn main() -> libceed::Result<()> { 60 /// # let ceed = libceed::Ceed::default_init(); 61 /// const Q: usize = 8; 62 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 63 /// 64 /// let inputs = qf.inputs()?; 65 /// 66 /// assert_eq!(inputs[0].size(), 4, "Incorrect input size"); 67 /// assert_eq!(inputs[1].size(), 1, "Incorrect input size"); 68 /// # Ok(()) 69 /// # } 70 /// ``` 71 pub fn size(&self) -> usize { 72 let mut size = 0; 73 unsafe { 74 bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size); 75 } 76 usize::try_from(size).unwrap() 77 } 78 79 /// Get the evaluation mode of a QFunctionField 80 /// 81 /// ``` 82 /// # use libceed::prelude::*; 83 /// # fn main() -> libceed::Result<()> { 84 /// # let ceed = libceed::Ceed::default_init(); 85 /// const Q: usize = 8; 86 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 87 /// 88 /// let inputs = qf.inputs()?; 89 /// 90 /// assert_eq!( 91 /// inputs[0].eval_mode(), 92 /// EvalMode::Grad, 93 /// "Incorrect input evaluation mode" 94 /// ); 95 /// assert_eq!( 96 /// inputs[1].eval_mode(), 97 /// EvalMode::Weight, 98 /// "Incorrect input evaluation mode" 99 /// ); 100 /// # Ok(()) 101 /// # } 102 /// ``` 103 pub fn eval_mode(&self) -> crate::EvalMode { 104 let mut mode = 0; 105 unsafe { 106 bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode); 107 } 108 crate::EvalMode::from_u32(mode as u32) 109 } 110 } 111 112 // ----------------------------------------------------------------------------- 113 // QFunction option 114 // ----------------------------------------------------------------------------- 115 pub enum QFunctionOpt<'a> { 116 SomeQFunction(&'a QFunction<'a>), 117 SomeQFunctionByName(&'a QFunctionByName<'a>), 118 None, 119 } 120 121 /// Construct a QFunctionOpt reference from a QFunction reference 122 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 123 fn from(qfunc: &'a QFunction) -> Self { 124 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 125 Self::SomeQFunction(qfunc) 126 } 127 } 128 129 /// Construct a QFunctionOpt reference from a QFunction by Name reference 130 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 131 fn from(qfunc: &'a QFunctionByName) -> Self { 132 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 133 Self::SomeQFunctionByName(qfunc) 134 } 135 } 136 137 impl<'a> QFunctionOpt<'a> { 138 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 139 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 140 match self { 141 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 142 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 143 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 144 } 145 } 146 147 /// Check if a QFunctionOpt is Some 148 /// 149 /// ``` 150 /// # use libceed::prelude::*; 151 /// # fn main() -> libceed::Result<()> { 152 /// # let ceed = libceed::Ceed::default_init(); 153 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 154 /// // Iterate over quadrature points 155 /// v.iter_mut() 156 /// .zip(u.iter().zip(weights.iter())) 157 /// .for_each(|(v, (u, w))| *v = u * w); 158 /// 159 /// // Return clean error code 160 /// 0 161 /// }; 162 /// 163 /// let qf = ceed 164 /// .q_function_interior(1, Box::new(user_f))? 165 /// .input("u", 1, EvalMode::Interp)? 166 /// .input("weights", 1, EvalMode::Weight)? 167 /// .output("v", 1, EvalMode::Interp)?; 168 /// let qf_opt = QFunctionOpt::from(&qf); 169 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 170 /// 171 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 172 /// let qf_opt = QFunctionOpt::from(&qf); 173 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 174 /// 175 /// let qf_opt = QFunctionOpt::None; 176 /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt"); 177 /// # Ok(()) 178 /// # } 179 /// ``` 180 pub fn is_some(&self) -> bool { 181 match self { 182 Self::SomeQFunction(_) => true, 183 Self::SomeQFunctionByName(_) => true, 184 Self::None => false, 185 } 186 } 187 188 /// Check if a QFunctionOpt is SomeQFunction 189 /// 190 /// ``` 191 /// # use libceed::prelude::*; 192 /// # fn main() -> libceed::Result<()> { 193 /// # let ceed = libceed::Ceed::default_init(); 194 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 195 /// // Iterate over quadrature points 196 /// v.iter_mut() 197 /// .zip(u.iter().zip(weights.iter())) 198 /// .for_each(|(v, (u, w))| *v = u * w); 199 /// 200 /// // Return clean error code 201 /// 0 202 /// }; 203 /// 204 /// let qf = ceed 205 /// .q_function_interior(1, Box::new(user_f))? 206 /// .input("u", 1, EvalMode::Interp)? 207 /// .input("weights", 1, EvalMode::Weight)? 208 /// .output("v", 1, EvalMode::Interp)?; 209 /// let qf_opt = QFunctionOpt::from(&qf); 210 /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 211 /// 212 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 213 /// let qf_opt = QFunctionOpt::from(&qf); 214 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 215 /// 216 /// let qf_opt = QFunctionOpt::None; 217 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 218 /// # Ok(()) 219 /// # } 220 /// ``` 221 pub fn is_some_q_function(&self) -> bool { 222 match self { 223 Self::SomeQFunction(_) => true, 224 Self::SomeQFunctionByName(_) => false, 225 Self::None => false, 226 } 227 } 228 229 /// Check if a QFunctionOpt is SomeQFunctionByName 230 /// 231 /// ``` 232 /// # use libceed::prelude::*; 233 /// # fn main() -> libceed::Result<()> { 234 /// # let ceed = libceed::Ceed::default_init(); 235 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 236 /// // Iterate over quadrature points 237 /// v.iter_mut() 238 /// .zip(u.iter().zip(weights.iter())) 239 /// .for_each(|(v, (u, w))| *v = u * w); 240 /// 241 /// // Return clean error code 242 /// 0 243 /// }; 244 /// 245 /// let qf = ceed 246 /// .q_function_interior(1, Box::new(user_f))? 247 /// .input("u", 1, EvalMode::Interp)? 248 /// .input("weights", 1, EvalMode::Weight)? 249 /// .output("v", 1, EvalMode::Interp)?; 250 /// let qf_opt = QFunctionOpt::from(&qf); 251 /// assert!( 252 /// !qf_opt.is_some_q_function_by_name(), 253 /// "Incorrect QFunctionOpt" 254 /// ); 255 /// 256 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 257 /// let qf_opt = QFunctionOpt::from(&qf); 258 /// assert!( 259 /// qf_opt.is_some_q_function_by_name(), 260 /// "Incorrect QFunctionOpt" 261 /// ); 262 /// 263 /// let qf_opt = QFunctionOpt::None; 264 /// assert!( 265 /// !qf_opt.is_some_q_function_by_name(), 266 /// "Incorrect QFunctionOpt" 267 /// ); 268 /// # Ok(()) 269 /// # } 270 /// ``` 271 pub fn is_some_q_function_by_name(&self) -> bool { 272 match self { 273 Self::SomeQFunction(_) => false, 274 Self::SomeQFunctionByName(_) => true, 275 Self::None => false, 276 } 277 } 278 279 /// Check if a QFunctionOpt is None 280 /// 281 /// ``` 282 /// # use libceed::prelude::*; 283 /// # fn main() -> libceed::Result<()> { 284 /// # let ceed = libceed::Ceed::default_init(); 285 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 286 /// // Iterate over quadrature points 287 /// v.iter_mut() 288 /// .zip(u.iter().zip(weights.iter())) 289 /// .for_each(|(v, (u, w))| *v = u * w); 290 /// 291 /// // Return clean error code 292 /// 0 293 /// }; 294 /// 295 /// let qf = ceed 296 /// .q_function_interior(1, Box::new(user_f))? 297 /// .input("u", 1, EvalMode::Interp)? 298 /// .input("weights", 1, EvalMode::Weight)? 299 /// .output("v", 1, EvalMode::Interp)?; 300 /// let qf_opt = QFunctionOpt::from(&qf); 301 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 302 /// 303 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 304 /// let qf_opt = QFunctionOpt::from(&qf); 305 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 306 /// 307 /// let qf_opt = QFunctionOpt::None; 308 /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt"); 309 /// # Ok(()) 310 /// # } 311 /// ``` 312 pub fn is_none(&self) -> bool { 313 match self { 314 Self::SomeQFunction(_) => false, 315 Self::SomeQFunctionByName(_) => false, 316 Self::None => true, 317 } 318 } 319 } 320 321 // ----------------------------------------------------------------------------- 322 // QFunction context wrapper 323 // ----------------------------------------------------------------------------- 324 #[derive(Debug)] 325 pub(crate) struct QFunctionCore<'a> { 326 ptr: bind_ceed::CeedQFunction, 327 _lifeline: PhantomData<&'a ()>, 328 } 329 330 struct QFunctionTrampolineData { 331 number_inputs: usize, 332 number_outputs: usize, 333 input_sizes: [usize; MAX_QFUNCTION_FIELDS], 334 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 335 user_f: Box<QFunctionUserClosure>, 336 } 337 338 pub struct QFunction<'a> { 339 qf_core: QFunctionCore<'a>, 340 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 341 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 342 } 343 344 #[derive(Debug)] 345 pub struct QFunctionByName<'a> { 346 qf_core: QFunctionCore<'a>, 347 } 348 349 // ----------------------------------------------------------------------------- 350 // Destructor 351 // ----------------------------------------------------------------------------- 352 impl<'a> Drop for QFunctionCore<'a> { 353 fn drop(&mut self) { 354 unsafe { 355 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 356 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 357 } 358 } 359 } 360 } 361 362 impl<'a> Drop for QFunction<'a> { 363 fn drop(&mut self) { 364 unsafe { 365 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 366 } 367 } 368 } 369 370 // ----------------------------------------------------------------------------- 371 // Display 372 // ----------------------------------------------------------------------------- 373 impl<'a> fmt::Display for QFunctionCore<'a> { 374 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 375 let mut ptr = std::ptr::null_mut(); 376 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 377 let cstring = unsafe { 378 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 379 bind_ceed::CeedQFunctionView(self.ptr, file); 380 bind_ceed::fclose(file); 381 CString::from_raw(ptr) 382 }; 383 cstring.to_string_lossy().fmt(f) 384 } 385 } 386 /// View a QFunction 387 /// 388 /// ``` 389 /// # use libceed::prelude::*; 390 /// # fn main() -> libceed::Result<()> { 391 /// # let ceed = libceed::Ceed::default_init(); 392 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 393 /// // Iterate over quadrature points 394 /// v.iter_mut() 395 /// .zip(u.iter().zip(weights.iter())) 396 /// .for_each(|(v, (u, w))| *v = u * w); 397 /// 398 /// // Return clean error code 399 /// 0 400 /// }; 401 /// 402 /// let qf = ceed 403 /// .q_function_interior(1, Box::new(user_f))? 404 /// .input("u", 1, EvalMode::Interp)? 405 /// .input("weights", 1, EvalMode::Weight)? 406 /// .output("v", 1, EvalMode::Interp)?; 407 /// 408 /// println!("{}", qf); 409 /// # Ok(()) 410 /// # } 411 /// ``` 412 impl<'a> fmt::Display for QFunction<'a> { 413 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 414 self.qf_core.fmt(f) 415 } 416 } 417 418 /// View a QFunction by Name 419 /// 420 /// ``` 421 /// # use libceed::prelude::*; 422 /// # fn main() -> libceed::Result<()> { 423 /// # let ceed = libceed::Ceed::default_init(); 424 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 425 /// println!("{}", qf); 426 /// # Ok(()) 427 /// # } 428 /// ``` 429 impl<'a> fmt::Display for QFunctionByName<'a> { 430 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 431 self.qf_core.fmt(f) 432 } 433 } 434 435 // ----------------------------------------------------------------------------- 436 // Core functionality 437 // ----------------------------------------------------------------------------- 438 impl<'a> QFunctionCore<'a> { 439 // Error handling 440 #[doc(hidden)] 441 fn check_error(&self, ierr: i32) -> crate::Result<i32> { 442 let mut ptr = std::ptr::null_mut(); 443 unsafe { 444 bind_ceed::CeedQFunctionGetCeed(self.ptr, &mut ptr); 445 } 446 crate::check_error(ptr, ierr) 447 } 448 449 // Common implementation 450 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 451 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 452 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 453 u_c[i] = u[i].ptr; 454 } 455 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 456 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 457 v_c[i] = v[i].ptr; 458 } 459 let Q = i32::try_from(Q).unwrap(); 460 let ierr = unsafe { 461 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 462 }; 463 self.check_error(ierr) 464 } 465 466 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 467 // Get array of raw C pointers for inputs 468 let mut num_inputs = 0; 469 let mut inputs_ptr = std::ptr::null_mut(); 470 let ierr = unsafe { 471 bind_ceed::CeedQFunctionGetFields( 472 self.ptr, 473 &mut num_inputs, 474 &mut inputs_ptr, 475 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 476 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 477 ) 478 }; 479 self.check_error(ierr)?; 480 // Convert raw C pointers to fixed length slice 481 let inputs_slice = unsafe { 482 std::slice::from_raw_parts( 483 inputs_ptr as *const crate::QFunctionField, 484 num_inputs as usize, 485 ) 486 }; 487 Ok(inputs_slice) 488 } 489 490 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 491 // Get array of raw C pointers for outputs 492 let mut num_outputs = 0; 493 let mut outputs_ptr = std::ptr::null_mut(); 494 let ierr = unsafe { 495 bind_ceed::CeedQFunctionGetFields( 496 self.ptr, 497 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 498 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 499 &mut num_outputs, 500 &mut outputs_ptr, 501 ) 502 }; 503 self.check_error(ierr)?; 504 // Convert raw C pointers to fixed length slice 505 let outputs_slice = unsafe { 506 std::slice::from_raw_parts( 507 outputs_ptr as *const crate::QFunctionField, 508 num_outputs as usize, 509 ) 510 }; 511 Ok(outputs_slice) 512 } 513 } 514 515 // ----------------------------------------------------------------------------- 516 // User QFunction Closure 517 // ----------------------------------------------------------------------------- 518 pub type QFunctionUserClosure = dyn FnMut( 519 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 520 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 521 ) -> i32; 522 523 macro_rules! mut_max_fields { 524 ($e:expr) => { 525 [ 526 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 527 ] 528 }; 529 } 530 unsafe extern "C" fn trampoline( 531 ctx: *mut ::std::os::raw::c_void, 532 q: bind_ceed::CeedInt, 533 inputs: *const *const bind_ceed::CeedScalar, 534 outputs: *const *mut bind_ceed::CeedScalar, 535 ) -> ::std::os::raw::c_int { 536 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 537 538 // Inputs 539 let inputs_slice: &[*const bind_ceed::CeedScalar] = 540 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 541 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 542 inputs_slice 543 .iter() 544 .enumerate() 545 .map(|(i, &x)| { 546 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 547 as &[crate::Scalar] 548 }) 549 .zip(inputs_array.iter_mut()) 550 .for_each(|(x, a)| *a = x); 551 552 // Outputs 553 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 554 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 555 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 556 mut_max_fields!(&mut [0.0]); 557 outputs_slice 558 .iter() 559 .enumerate() 560 .map(|(i, &x)| { 561 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 562 as &mut [crate::Scalar] 563 }) 564 .zip(outputs_array.iter_mut()) 565 .for_each(|(x, a)| *a = x); 566 567 // User closure 568 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 569 } 570 571 unsafe extern "C" fn destroy_trampoline(ctx: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int { 572 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 573 drop(trampoline_data); 574 0 // Clean error code 575 } 576 577 // ----------------------------------------------------------------------------- 578 // QFunction 579 // ----------------------------------------------------------------------------- 580 impl<'a> QFunction<'a> { 581 // Constructor 582 pub fn create( 583 ceed: &crate::Ceed, 584 vlength: usize, 585 user_f: Box<QFunctionUserClosure>, 586 ) -> crate::Result<Self> { 587 let source_c = CString::new("").expect("CString::new failed"); 588 let mut ptr = std::ptr::null_mut(); 589 590 // Context for closure 591 let number_inputs = 0; 592 let number_outputs = 0; 593 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 594 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 595 let trampoline_data = unsafe { 596 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 597 number_inputs, 598 number_outputs, 599 input_sizes, 600 output_sizes, 601 user_f, 602 })) 603 }; 604 605 // Create QFunction 606 let vlength = i32::try_from(vlength).unwrap(); 607 let mut ierr = unsafe { 608 bind_ceed::CeedQFunctionCreateInterior( 609 ceed.ptr, 610 vlength, 611 Some(trampoline), 612 source_c.as_ptr(), 613 &mut ptr, 614 ) 615 }; 616 ceed.check_error(ierr)?; 617 618 // Set closure 619 let mut qf_ctx_ptr = std::ptr::null_mut(); 620 ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) }; 621 ceed.check_error(ierr)?; 622 ierr = unsafe { 623 bind_ceed::CeedQFunctionContextSetData( 624 qf_ctx_ptr, 625 crate::MemType::Host as bind_ceed::CeedMemType, 626 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 627 std::mem::size_of::<QFunctionTrampolineData>(), 628 std::mem::transmute(trampoline_data.as_ref()), 629 ) 630 }; 631 ceed.check_error(ierr)?; 632 ierr = unsafe { 633 bind_ceed::CeedQFunctionContextSetDataDestroy( 634 qf_ctx_ptr, 635 crate::MemType::Host as bind_ceed::CeedMemType, 636 Some(destroy_trampoline), 637 ) 638 }; 639 ceed.check_error(ierr)?; 640 ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) }; 641 ceed.check_error(ierr)?; 642 Ok(Self { 643 qf_core: QFunctionCore { 644 ptr, 645 _lifeline: PhantomData, 646 }, 647 qf_ctx_ptr, 648 trampoline_data, 649 }) 650 } 651 652 /// Apply the action of a QFunction 653 /// 654 /// * `Q` - The number of quadrature points 655 /// * `input` - Array of input Vectors 656 /// * `output` - Array of output Vectors 657 /// 658 /// ``` 659 /// # use libceed::prelude::*; 660 /// # fn main() -> libceed::Result<()> { 661 /// # let ceed = libceed::Ceed::default_init(); 662 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 663 /// // Iterate over quadrature points 664 /// v.iter_mut() 665 /// .zip(u.iter().zip(weights.iter())) 666 /// .for_each(|(v, (u, w))| *v = u * w); 667 /// 668 /// // Return clean error code 669 /// 0 670 /// }; 671 /// 672 /// let qf = ceed 673 /// .q_function_interior(1, Box::new(user_f))? 674 /// .input("u", 1, EvalMode::Interp)? 675 /// .input("weights", 1, EvalMode::Weight)? 676 /// .output("v", 1, EvalMode::Interp)?; 677 /// 678 /// const Q: usize = 8; 679 /// let mut w = [0.; Q]; 680 /// let mut u = [0.; Q]; 681 /// let mut v = [0.; Q]; 682 /// 683 /// for i in 0..Q { 684 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 685 /// u[i] = 2. + 3. * x + 5. * x * x; 686 /// w[i] = 1. - x * x; 687 /// v[i] = u[i] * w[i]; 688 /// } 689 /// 690 /// let uu = ceed.vector_from_slice(&u)?; 691 /// let ww = ceed.vector_from_slice(&w)?; 692 /// let mut vv = ceed.vector(Q)?; 693 /// vv.set_value(0.0); 694 /// { 695 /// let input = vec![uu, ww]; 696 /// let mut output = vec![vv]; 697 /// qf.apply(Q, &input, &output)?; 698 /// vv = output.remove(0); 699 /// } 700 /// 701 /// vv.view()? 702 /// .iter() 703 /// .zip(v.iter()) 704 /// .for_each(|(computed, actual)| { 705 /// assert_eq!( 706 /// *computed, *actual, 707 /// "Incorrect value in QFunction application" 708 /// ); 709 /// }); 710 /// # Ok(()) 711 /// # } 712 /// ``` 713 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 714 self.qf_core.apply(Q, u, v) 715 } 716 717 /// Add a QFunction input 718 /// 719 /// * `fieldname` - Name of QFunction field 720 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 721 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 722 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 723 /// to use interpolated values, `EvalMode::Grad` to use 724 /// gradients, `EvalMode::Weight` to use quadrature weights 725 /// 726 /// ``` 727 /// # use libceed::prelude::*; 728 /// # fn main() -> libceed::Result<()> { 729 /// # let ceed = libceed::Ceed::default_init(); 730 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 731 /// // Iterate over quadrature points 732 /// v.iter_mut() 733 /// .zip(u.iter().zip(weights.iter())) 734 /// .for_each(|(v, (u, w))| *v = u * w); 735 /// 736 /// // Return clean error code 737 /// 0 738 /// }; 739 /// 740 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 741 /// 742 /// qf = qf.input("u", 1, EvalMode::Interp)?; 743 /// qf = qf.input("weights", 1, EvalMode::Weight)?; 744 /// # Ok(()) 745 /// # } 746 /// ``` 747 pub fn input( 748 mut self, 749 fieldname: &str, 750 size: usize, 751 emode: crate::EvalMode, 752 ) -> crate::Result<Self> { 753 let name_c = CString::new(fieldname).expect("CString::new failed"); 754 let idx = self.trampoline_data.number_inputs; 755 self.trampoline_data.input_sizes[idx] = size; 756 self.trampoline_data.number_inputs += 1; 757 let (size, emode) = ( 758 i32::try_from(size).unwrap(), 759 emode as bind_ceed::CeedEvalMode, 760 ); 761 let ierr = unsafe { 762 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 763 }; 764 self.qf_core.check_error(ierr)?; 765 Ok(self) 766 } 767 768 /// Add a QFunction output 769 /// 770 /// * `fieldname` - Name of QFunction field 771 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 772 /// `(ncomp * 1)` for `None` and `Interp` 773 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 774 /// to use interpolated values, `EvalMode::Grad` to use 775 /// gradients 776 /// 777 /// ``` 778 /// # use libceed::prelude::*; 779 /// # fn main() -> libceed::Result<()> { 780 /// # let ceed = libceed::Ceed::default_init(); 781 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 782 /// // Iterate over quadrature points 783 /// v.iter_mut() 784 /// .zip(u.iter().zip(weights.iter())) 785 /// .for_each(|(v, (u, w))| *v = u * w); 786 /// 787 /// // Return clean error code 788 /// 0 789 /// }; 790 /// 791 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 792 /// 793 /// qf.output("v", 1, EvalMode::Interp)?; 794 /// # Ok(()) 795 /// # } 796 /// ``` 797 pub fn output( 798 mut self, 799 fieldname: &str, 800 size: usize, 801 emode: crate::EvalMode, 802 ) -> crate::Result<Self> { 803 let name_c = CString::new(fieldname).expect("CString::new failed"); 804 let idx = self.trampoline_data.number_outputs; 805 self.trampoline_data.output_sizes[idx] = size; 806 self.trampoline_data.number_outputs += 1; 807 let (size, emode) = ( 808 i32::try_from(size).unwrap(), 809 emode as bind_ceed::CeedEvalMode, 810 ); 811 let ierr = unsafe { 812 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 813 }; 814 self.qf_core.check_error(ierr)?; 815 Ok(self) 816 } 817 818 /// Get a slice of QFunction inputs 819 /// 820 /// ``` 821 /// # use libceed::prelude::*; 822 /// # fn main() -> libceed::Result<()> { 823 /// # let ceed = libceed::Ceed::default_init(); 824 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 825 /// // Iterate over quadrature points 826 /// v.iter_mut() 827 /// .zip(u.iter().zip(weights.iter())) 828 /// .for_each(|(v, (u, w))| *v = u * w); 829 /// 830 /// // Return clean error code 831 /// 0 832 /// }; 833 /// 834 /// let mut qf = ceed 835 /// .q_function_interior(1, Box::new(user_f))? 836 /// .input("u", 1, EvalMode::Interp)? 837 /// .input("weights", 1, EvalMode::Weight)?; 838 /// 839 /// let inputs = qf.inputs()?; 840 /// 841 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 842 /// # Ok(()) 843 /// # } 844 /// ``` 845 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 846 self.qf_core.inputs() 847 } 848 849 /// Get a slice of QFunction outputs 850 /// 851 /// ``` 852 /// # use libceed::prelude::*; 853 /// # fn main() -> libceed::Result<()> { 854 /// # let ceed = libceed::Ceed::default_init(); 855 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 856 /// // Iterate over quadrature points 857 /// v.iter_mut() 858 /// .zip(u.iter().zip(weights.iter())) 859 /// .for_each(|(v, (u, w))| *v = u * w); 860 /// 861 /// // Return clean error code 862 /// 0 863 /// }; 864 /// 865 /// let mut qf = ceed 866 /// .q_function_interior(1, Box::new(user_f))? 867 /// .output("v", 1, EvalMode::Interp)?; 868 /// 869 /// let outputs = qf.outputs()?; 870 /// 871 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 872 /// # Ok(()) 873 /// # } 874 /// ``` 875 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 876 self.qf_core.outputs() 877 } 878 } 879 880 // ----------------------------------------------------------------------------- 881 // QFunction 882 // ----------------------------------------------------------------------------- 883 impl<'a> QFunctionByName<'a> { 884 // Constructor 885 pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> { 886 let name_c = CString::new(name).expect("CString::new failed"); 887 let mut ptr = std::ptr::null_mut(); 888 let ierr = unsafe { 889 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 890 }; 891 ceed.check_error(ierr)?; 892 Ok(Self { 893 qf_core: QFunctionCore { 894 ptr, 895 _lifeline: PhantomData, 896 }, 897 }) 898 } 899 900 /// Apply the action of a QFunction 901 /// 902 /// * `Q` - The number of quadrature points 903 /// * `input` - Array of input Vectors 904 /// * `output` - Array of output Vectors 905 /// 906 /// ``` 907 /// # use libceed::prelude::*; 908 /// # fn main() -> libceed::Result<()> { 909 /// # let ceed = libceed::Ceed::default_init(); 910 /// const Q: usize = 8; 911 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?; 912 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?; 913 /// 914 /// let mut j = [0.; Q]; 915 /// let mut w = [0.; Q]; 916 /// let mut u = [0.; Q]; 917 /// let mut v = [0.; Q]; 918 /// 919 /// for i in 0..Q { 920 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 921 /// j[i] = 1.; 922 /// w[i] = 1. - x * x; 923 /// u[i] = 2. + 3. * x + 5. * x * x; 924 /// v[i] = w[i] * u[i]; 925 /// } 926 /// 927 /// let jj = ceed.vector_from_slice(&j)?; 928 /// let ww = ceed.vector_from_slice(&w)?; 929 /// let uu = ceed.vector_from_slice(&u)?; 930 /// let mut vv = ceed.vector(Q)?; 931 /// vv.set_value(0.0); 932 /// let mut qdata = ceed.vector(Q)?; 933 /// qdata.set_value(0.0); 934 /// 935 /// { 936 /// let mut input = vec![jj, ww]; 937 /// let mut output = vec![qdata]; 938 /// qf_build.apply(Q, &input, &output)?; 939 /// qdata = output.remove(0); 940 /// } 941 /// 942 /// { 943 /// let mut input = vec![qdata, uu]; 944 /// let mut output = vec![vv]; 945 /// qf_mass.apply(Q, &input, &output)?; 946 /// vv = output.remove(0); 947 /// } 948 /// 949 /// vv.view()? 950 /// .iter() 951 /// .zip(v.iter()) 952 /// .for_each(|(computed, actual)| { 953 /// assert_eq!( 954 /// *computed, *actual, 955 /// "Incorrect value in QFunction application" 956 /// ); 957 /// }); 958 /// # Ok(()) 959 /// # } 960 /// ``` 961 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 962 self.qf_core.apply(Q, u, v) 963 } 964 965 /// Get a slice of QFunction inputs 966 /// 967 /// ``` 968 /// # use libceed::prelude::*; 969 /// # fn main() -> libceed::Result<()> { 970 /// # let ceed = libceed::Ceed::default_init(); 971 /// const Q: usize = 8; 972 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 973 /// 974 /// let inputs = qf.inputs()?; 975 /// 976 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 977 /// # Ok(()) 978 /// # } 979 /// ``` 980 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 981 self.qf_core.inputs() 982 } 983 984 /// Get a slice of QFunction outputs 985 /// 986 /// ``` 987 /// # use libceed::prelude::*; 988 /// # fn main() -> libceed::Result<()> { 989 /// # let ceed = libceed::Ceed::default_init(); 990 /// const Q: usize = 8; 991 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 992 /// 993 /// let outputs = qf.outputs()?; 994 /// 995 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 996 /// # Ok(()) 997 /// # } 998 /// ``` 999 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 1000 self.qf_core.outputs() 1001 } 1002 } 1003 1004 // ----------------------------------------------------------------------------- 1005