1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 //! A Ceed QFunction represents the spatial terms of the point-wise functions 9 //! describing the physics at the quadrature points. 10 11 use std::pin::Pin; 12 13 use crate::prelude::*; 14 15 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 16 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 17 18 // ----------------------------------------------------------------------------- 19 // QFunction Field context wrapper 20 // ----------------------------------------------------------------------------- 21 #[derive(Debug)] 22 pub struct QFunctionField<'a> { 23 ptr: bind_ceed::CeedQFunctionField, 24 _lifeline: PhantomData<&'a ()>, 25 } 26 27 // ----------------------------------------------------------------------------- 28 // Implementations 29 // ----------------------------------------------------------------------------- 30 impl<'a> QFunctionField<'a> { 31 /// Get the name of a QFunctionField 32 /// 33 /// ``` 34 /// # use libceed::prelude::*; 35 /// # fn main() -> libceed::Result<()> { 36 /// # let ceed = libceed::Ceed::default_init(); 37 /// const Q: usize = 8; 38 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 39 /// 40 /// let inputs = qf.inputs()?; 41 /// 42 /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name"); 43 /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name"); 44 /// # Ok(()) 45 /// # } 46 /// ``` 47 pub fn name(&self) -> &str { 48 let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut(); 49 unsafe { 50 bind_ceed::CeedQFunctionFieldGetName(self.ptr, &mut name_ptr); 51 } 52 unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap() 53 } 54 55 /// Get the size of a QFunctionField 56 /// 57 /// ``` 58 /// # use libceed::prelude::*; 59 /// # fn main() -> libceed::Result<()> { 60 /// # let ceed = libceed::Ceed::default_init(); 61 /// const Q: usize = 8; 62 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 63 /// 64 /// let inputs = qf.inputs()?; 65 /// 66 /// assert_eq!(inputs[0].size(), 4, "Incorrect input size"); 67 /// assert_eq!(inputs[1].size(), 1, "Incorrect input size"); 68 /// # Ok(()) 69 /// # } 70 /// ``` 71 pub fn size(&self) -> usize { 72 let mut size = 0; 73 unsafe { 74 bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size); 75 } 76 usize::try_from(size).unwrap() 77 } 78 79 /// Get the evaluation mode of a QFunctionField 80 /// 81 /// ``` 82 /// # use libceed::prelude::*; 83 /// # fn main() -> libceed::Result<()> { 84 /// # let ceed = libceed::Ceed::default_init(); 85 /// const Q: usize = 8; 86 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 87 /// 88 /// let inputs = qf.inputs()?; 89 /// 90 /// assert_eq!( 91 /// inputs[0].eval_mode(), 92 /// EvalMode::Grad, 93 /// "Incorrect input evaluation mode" 94 /// ); 95 /// assert_eq!( 96 /// inputs[1].eval_mode(), 97 /// EvalMode::Weight, 98 /// "Incorrect input evaluation mode" 99 /// ); 100 /// # Ok(()) 101 /// # } 102 /// ``` 103 pub fn eval_mode(&self) -> crate::EvalMode { 104 let mut mode = 0; 105 unsafe { 106 bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode); 107 } 108 crate::EvalMode::from_u32(mode as u32) 109 } 110 } 111 112 // ----------------------------------------------------------------------------- 113 // QFunction option 114 // ----------------------------------------------------------------------------- 115 pub enum QFunctionOpt<'a> { 116 SomeQFunction(&'a QFunction<'a>), 117 SomeQFunctionByName(&'a QFunctionByName<'a>), 118 None, 119 } 120 121 /// Construct a QFunctionOpt reference from a QFunction reference 122 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 123 fn from(qfunc: &'a QFunction) -> Self { 124 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 125 Self::SomeQFunction(qfunc) 126 } 127 } 128 129 /// Construct a QFunctionOpt reference from a QFunction by Name reference 130 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 131 fn from(qfunc: &'a QFunctionByName) -> Self { 132 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 133 Self::SomeQFunctionByName(qfunc) 134 } 135 } 136 137 impl<'a> QFunctionOpt<'a> { 138 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 139 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 140 match self { 141 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 142 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 143 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 144 } 145 } 146 147 /// Check if a QFunctionOpt is Some 148 /// 149 /// ``` 150 /// # use libceed::prelude::*; 151 /// # fn main() -> libceed::Result<()> { 152 /// # let ceed = libceed::Ceed::default_init(); 153 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 154 /// // Iterate over quadrature points 155 /// v.iter_mut() 156 /// .zip(u.iter().zip(weights.iter())) 157 /// .for_each(|(v, (u, w))| *v = u * w); 158 /// 159 /// // Return clean error code 160 /// 0 161 /// }; 162 /// 163 /// let qf = ceed 164 /// .q_function_interior(1, Box::new(user_f))? 165 /// .input("u", 1, EvalMode::Interp)? 166 /// .input("weights", 1, EvalMode::Weight)? 167 /// .output("v", 1, EvalMode::Interp)?; 168 /// let qf_opt = QFunctionOpt::from(&qf); 169 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 170 /// 171 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 172 /// let qf_opt = QFunctionOpt::from(&qf); 173 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 174 /// 175 /// let qf_opt = QFunctionOpt::None; 176 /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt"); 177 /// # Ok(()) 178 /// # } 179 /// ``` 180 pub fn is_some(&self) -> bool { 181 match self { 182 Self::SomeQFunction(_) => true, 183 Self::SomeQFunctionByName(_) => true, 184 Self::None => false, 185 } 186 } 187 188 /// Check if a QFunctionOpt is SomeQFunction 189 /// 190 /// ``` 191 /// # use libceed::prelude::*; 192 /// # fn main() -> libceed::Result<()> { 193 /// # let ceed = libceed::Ceed::default_init(); 194 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 195 /// // Iterate over quadrature points 196 /// v.iter_mut() 197 /// .zip(u.iter().zip(weights.iter())) 198 /// .for_each(|(v, (u, w))| *v = u * w); 199 /// 200 /// // Return clean error code 201 /// 0 202 /// }; 203 /// 204 /// let qf = ceed 205 /// .q_function_interior(1, Box::new(user_f))? 206 /// .input("u", 1, EvalMode::Interp)? 207 /// .input("weights", 1, EvalMode::Weight)? 208 /// .output("v", 1, EvalMode::Interp)?; 209 /// let qf_opt = QFunctionOpt::from(&qf); 210 /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 211 /// 212 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 213 /// let qf_opt = QFunctionOpt::from(&qf); 214 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 215 /// 216 /// let qf_opt = QFunctionOpt::None; 217 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 218 /// # Ok(()) 219 /// # } 220 /// ``` 221 pub fn is_some_q_function(&self) -> bool { 222 match self { 223 Self::SomeQFunction(_) => true, 224 Self::SomeQFunctionByName(_) => false, 225 Self::None => false, 226 } 227 } 228 229 /// Check if a QFunctionOpt is SomeQFunctionByName 230 /// 231 /// ``` 232 /// # use libceed::prelude::*; 233 /// # fn main() -> libceed::Result<()> { 234 /// # let ceed = libceed::Ceed::default_init(); 235 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 236 /// // Iterate over quadrature points 237 /// v.iter_mut() 238 /// .zip(u.iter().zip(weights.iter())) 239 /// .for_each(|(v, (u, w))| *v = u * w); 240 /// 241 /// // Return clean error code 242 /// 0 243 /// }; 244 /// 245 /// let qf = ceed 246 /// .q_function_interior(1, Box::new(user_f))? 247 /// .input("u", 1, EvalMode::Interp)? 248 /// .input("weights", 1, EvalMode::Weight)? 249 /// .output("v", 1, EvalMode::Interp)?; 250 /// let qf_opt = QFunctionOpt::from(&qf); 251 /// assert!( 252 /// !qf_opt.is_some_q_function_by_name(), 253 /// "Incorrect QFunctionOpt" 254 /// ); 255 /// 256 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 257 /// let qf_opt = QFunctionOpt::from(&qf); 258 /// assert!( 259 /// qf_opt.is_some_q_function_by_name(), 260 /// "Incorrect QFunctionOpt" 261 /// ); 262 /// 263 /// let qf_opt = QFunctionOpt::None; 264 /// assert!( 265 /// !qf_opt.is_some_q_function_by_name(), 266 /// "Incorrect QFunctionOpt" 267 /// ); 268 /// # Ok(()) 269 /// # } 270 /// ``` 271 pub fn is_some_q_function_by_name(&self) -> bool { 272 match self { 273 Self::SomeQFunction(_) => false, 274 Self::SomeQFunctionByName(_) => true, 275 Self::None => false, 276 } 277 } 278 279 /// Check if a QFunctionOpt is None 280 /// 281 /// ``` 282 /// # use libceed::prelude::*; 283 /// # fn main() -> libceed::Result<()> { 284 /// # let ceed = libceed::Ceed::default_init(); 285 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 286 /// // Iterate over quadrature points 287 /// v.iter_mut() 288 /// .zip(u.iter().zip(weights.iter())) 289 /// .for_each(|(v, (u, w))| *v = u * w); 290 /// 291 /// // Return clean error code 292 /// 0 293 /// }; 294 /// 295 /// let qf = ceed 296 /// .q_function_interior(1, Box::new(user_f))? 297 /// .input("u", 1, EvalMode::Interp)? 298 /// .input("weights", 1, EvalMode::Weight)? 299 /// .output("v", 1, EvalMode::Interp)?; 300 /// let qf_opt = QFunctionOpt::from(&qf); 301 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 302 /// 303 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 304 /// let qf_opt = QFunctionOpt::from(&qf); 305 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 306 /// 307 /// let qf_opt = QFunctionOpt::None; 308 /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt"); 309 /// # Ok(()) 310 /// # } 311 /// ``` 312 pub fn is_none(&self) -> bool { 313 match self { 314 Self::SomeQFunction(_) => false, 315 Self::SomeQFunctionByName(_) => false, 316 Self::None => true, 317 } 318 } 319 } 320 321 // ----------------------------------------------------------------------------- 322 // QFunction context wrapper 323 // ----------------------------------------------------------------------------- 324 #[derive(Debug)] 325 pub(crate) struct QFunctionCore<'a> { 326 ptr: bind_ceed::CeedQFunction, 327 _lifeline: PhantomData<&'a ()>, 328 } 329 330 struct QFunctionTrampolineData { 331 number_inputs: usize, 332 number_outputs: usize, 333 input_sizes: [usize; MAX_QFUNCTION_FIELDS], 334 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 335 user_f: Box<QFunctionUserClosure>, 336 } 337 338 pub struct QFunction<'a> { 339 qf_core: QFunctionCore<'a>, 340 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 341 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 342 } 343 344 #[derive(Debug)] 345 pub struct QFunctionByName<'a> { 346 qf_core: QFunctionCore<'a>, 347 } 348 349 // ----------------------------------------------------------------------------- 350 // Destructor 351 // ----------------------------------------------------------------------------- 352 impl<'a> Drop for QFunctionCore<'a> { 353 fn drop(&mut self) { 354 unsafe { 355 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 356 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 357 } 358 } 359 } 360 } 361 362 impl<'a> Drop for QFunction<'a> { 363 fn drop(&mut self) { 364 unsafe { 365 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 366 } 367 } 368 } 369 370 // ----------------------------------------------------------------------------- 371 // Display 372 // ----------------------------------------------------------------------------- 373 impl<'a> fmt::Display for QFunctionCore<'a> { 374 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 375 let mut ptr = std::ptr::null_mut(); 376 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 377 let cstring = unsafe { 378 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 379 bind_ceed::CeedQFunctionView(self.ptr, file); 380 bind_ceed::fclose(file); 381 CString::from_raw(ptr) 382 }; 383 cstring.to_string_lossy().fmt(f) 384 } 385 } 386 /// View a QFunction 387 /// 388 /// ``` 389 /// # use libceed::prelude::*; 390 /// # fn main() -> libceed::Result<()> { 391 /// # let ceed = libceed::Ceed::default_init(); 392 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 393 /// // Iterate over quadrature points 394 /// v.iter_mut() 395 /// .zip(u.iter().zip(weights.iter())) 396 /// .for_each(|(v, (u, w))| *v = u * w); 397 /// 398 /// // Return clean error code 399 /// 0 400 /// }; 401 /// 402 /// let qf = ceed 403 /// .q_function_interior(1, Box::new(user_f))? 404 /// .input("u", 1, EvalMode::Interp)? 405 /// .input("weights", 1, EvalMode::Weight)? 406 /// .output("v", 1, EvalMode::Interp)?; 407 /// 408 /// println!("{}", qf); 409 /// # Ok(()) 410 /// # } 411 /// ``` 412 impl<'a> fmt::Display for QFunction<'a> { 413 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 414 self.qf_core.fmt(f) 415 } 416 } 417 418 /// View a QFunction by Name 419 /// 420 /// ``` 421 /// # use libceed::prelude::*; 422 /// # fn main() -> libceed::Result<()> { 423 /// # let ceed = libceed::Ceed::default_init(); 424 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 425 /// println!("{}", qf); 426 /// # Ok(()) 427 /// # } 428 /// ``` 429 impl<'a> fmt::Display for QFunctionByName<'a> { 430 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 431 self.qf_core.fmt(f) 432 } 433 } 434 435 // ----------------------------------------------------------------------------- 436 // Core functionality 437 // ----------------------------------------------------------------------------- 438 impl<'a> QFunctionCore<'a> { 439 // Error handling 440 #[doc(hidden)] 441 fn check_error(&self, ierr: i32) -> crate::Result<i32> { 442 let mut ptr = std::ptr::null_mut(); 443 unsafe { 444 bind_ceed::CeedQFunctionGetCeed(self.ptr, &mut ptr); 445 } 446 crate::check_error(ptr, ierr) 447 } 448 449 // Common implementation 450 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 451 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 452 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 453 u_c[i] = u[i].ptr; 454 } 455 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 456 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 457 v_c[i] = v[i].ptr; 458 } 459 let Q = i32::try_from(Q).unwrap(); 460 let ierr = unsafe { 461 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 462 }; 463 self.check_error(ierr) 464 } 465 466 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 467 // Get array of raw C pointers for inputs 468 let mut num_inputs = 0; 469 let mut inputs_ptr = std::ptr::null_mut(); 470 let ierr = unsafe { 471 bind_ceed::CeedQFunctionGetFields( 472 self.ptr, 473 &mut num_inputs, 474 &mut inputs_ptr, 475 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 476 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 477 ) 478 }; 479 self.check_error(ierr)?; 480 // Convert raw C pointers to fixed length slice 481 let inputs_slice = unsafe { 482 std::slice::from_raw_parts( 483 inputs_ptr as *const crate::QFunctionField, 484 num_inputs as usize, 485 ) 486 }; 487 Ok(inputs_slice) 488 } 489 490 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 491 // Get array of raw C pointers for outputs 492 let mut num_outputs = 0; 493 let mut outputs_ptr = std::ptr::null_mut(); 494 let ierr = unsafe { 495 bind_ceed::CeedQFunctionGetFields( 496 self.ptr, 497 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 498 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 499 &mut num_outputs, 500 &mut outputs_ptr, 501 ) 502 }; 503 self.check_error(ierr)?; 504 // Convert raw C pointers to fixed length slice 505 let outputs_slice = unsafe { 506 std::slice::from_raw_parts( 507 outputs_ptr as *const crate::QFunctionField, 508 num_outputs as usize, 509 ) 510 }; 511 Ok(outputs_slice) 512 } 513 } 514 515 // ----------------------------------------------------------------------------- 516 // User QFunction Closure 517 // ----------------------------------------------------------------------------- 518 pub type QFunctionUserClosure = dyn FnMut( 519 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 520 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 521 ) -> i32; 522 523 macro_rules! mut_max_fields { 524 ($e:expr) => { 525 [ 526 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 527 ] 528 }; 529 } 530 unsafe extern "C" fn trampoline( 531 ctx: *mut ::std::os::raw::c_void, 532 q: bind_ceed::CeedInt, 533 inputs: *const *const bind_ceed::CeedScalar, 534 outputs: *const *mut bind_ceed::CeedScalar, 535 ) -> ::std::os::raw::c_int { 536 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 537 538 // Inputs 539 let inputs_slice: &[*const bind_ceed::CeedScalar] = 540 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 541 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 542 inputs_slice 543 .iter() 544 .take(trampoline_data.number_inputs) 545 .enumerate() 546 .map(|(i, &x)| { 547 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 548 as &[crate::Scalar] 549 }) 550 .zip(inputs_array.iter_mut()) 551 .for_each(|(x, a)| *a = x); 552 553 // Outputs 554 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 555 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 556 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 557 mut_max_fields!(&mut [0.0]); 558 outputs_slice 559 .iter() 560 .take(trampoline_data.number_outputs) 561 .enumerate() 562 .map(|(i, &x)| { 563 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 564 as &mut [crate::Scalar] 565 }) 566 .zip(outputs_array.iter_mut()) 567 .for_each(|(x, a)| *a = x); 568 569 // User closure 570 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 571 } 572 573 unsafe extern "C" fn destroy_trampoline(ctx: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int { 574 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 575 drop(trampoline_data); 576 0 // Clean error code 577 } 578 579 // ----------------------------------------------------------------------------- 580 // QFunction 581 // ----------------------------------------------------------------------------- 582 impl<'a> QFunction<'a> { 583 // Constructor 584 pub fn create( 585 ceed: &crate::Ceed, 586 vlength: usize, 587 user_f: Box<QFunctionUserClosure>, 588 ) -> crate::Result<Self> { 589 let source_c = CString::new("").expect("CString::new failed"); 590 let mut ptr = std::ptr::null_mut(); 591 592 // Context for closure 593 let number_inputs = 0; 594 let number_outputs = 0; 595 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 596 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 597 let trampoline_data = unsafe { 598 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 599 number_inputs, 600 number_outputs, 601 input_sizes, 602 output_sizes, 603 user_f, 604 })) 605 }; 606 607 // Create QFunction 608 let vlength = i32::try_from(vlength).unwrap(); 609 let mut ierr = unsafe { 610 bind_ceed::CeedQFunctionCreateInterior( 611 ceed.ptr, 612 vlength, 613 Some(trampoline), 614 source_c.as_ptr(), 615 &mut ptr, 616 ) 617 }; 618 ceed.check_error(ierr)?; 619 620 // Set closure 621 let mut qf_ctx_ptr = std::ptr::null_mut(); 622 ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) }; 623 ceed.check_error(ierr)?; 624 ierr = unsafe { 625 bind_ceed::CeedQFunctionContextSetData( 626 qf_ctx_ptr, 627 crate::MemType::Host as bind_ceed::CeedMemType, 628 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 629 std::mem::size_of::<QFunctionTrampolineData>(), 630 std::mem::transmute(trampoline_data.as_ref()), 631 ) 632 }; 633 ceed.check_error(ierr)?; 634 ierr = unsafe { 635 bind_ceed::CeedQFunctionContextSetDataDestroy( 636 qf_ctx_ptr, 637 crate::MemType::Host as bind_ceed::CeedMemType, 638 Some(destroy_trampoline), 639 ) 640 }; 641 ceed.check_error(ierr)?; 642 ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) }; 643 ceed.check_error(ierr)?; 644 Ok(Self { 645 qf_core: QFunctionCore { 646 ptr, 647 _lifeline: PhantomData, 648 }, 649 qf_ctx_ptr, 650 trampoline_data, 651 }) 652 } 653 654 /// Apply the action of a QFunction 655 /// 656 /// * `Q` - The number of quadrature points 657 /// * `input` - Array of input Vectors 658 /// * `output` - Array of output Vectors 659 /// 660 /// ``` 661 /// # use libceed::prelude::*; 662 /// # fn main() -> libceed::Result<()> { 663 /// # let ceed = libceed::Ceed::default_init(); 664 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 665 /// // Iterate over quadrature points 666 /// v.iter_mut() 667 /// .zip(u.iter().zip(weights.iter())) 668 /// .for_each(|(v, (u, w))| *v = u * w); 669 /// 670 /// // Return clean error code 671 /// 0 672 /// }; 673 /// 674 /// let qf = ceed 675 /// .q_function_interior(1, Box::new(user_f))? 676 /// .input("u", 1, EvalMode::Interp)? 677 /// .input("weights", 1, EvalMode::Weight)? 678 /// .output("v", 1, EvalMode::Interp)?; 679 /// 680 /// const Q: usize = 8; 681 /// let mut w = [0.; Q]; 682 /// let mut u = [0.; Q]; 683 /// let mut v = [0.; Q]; 684 /// 685 /// for i in 0..Q { 686 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 687 /// u[i] = 2. + 3. * x + 5. * x * x; 688 /// w[i] = 1. - x * x; 689 /// v[i] = u[i] * w[i]; 690 /// } 691 /// 692 /// let uu = ceed.vector_from_slice(&u)?; 693 /// let ww = ceed.vector_from_slice(&w)?; 694 /// let mut vv = ceed.vector(Q)?; 695 /// vv.set_value(0.0); 696 /// { 697 /// let input = vec![uu, ww]; 698 /// let mut output = vec![vv]; 699 /// qf.apply(Q, &input, &output)?; 700 /// vv = output.remove(0); 701 /// } 702 /// 703 /// vv.view()? 704 /// .iter() 705 /// .zip(v.iter()) 706 /// .for_each(|(computed, actual)| { 707 /// assert_eq!( 708 /// *computed, *actual, 709 /// "Incorrect value in QFunction application" 710 /// ); 711 /// }); 712 /// # Ok(()) 713 /// # } 714 /// ``` 715 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 716 self.qf_core.apply(Q, u, v) 717 } 718 719 /// Add a QFunction input 720 /// 721 /// * `fieldname` - Name of QFunction field 722 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 723 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 724 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 725 /// to use interpolated values, `EvalMode::Grad` to use 726 /// gradients, `EvalMode::Weight` to use quadrature weights 727 /// 728 /// ``` 729 /// # use libceed::prelude::*; 730 /// # fn main() -> libceed::Result<()> { 731 /// # let ceed = libceed::Ceed::default_init(); 732 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 733 /// // Iterate over quadrature points 734 /// v.iter_mut() 735 /// .zip(u.iter().zip(weights.iter())) 736 /// .for_each(|(v, (u, w))| *v = u * w); 737 /// 738 /// // Return clean error code 739 /// 0 740 /// }; 741 /// 742 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 743 /// 744 /// qf = qf.input("u", 1, EvalMode::Interp)?; 745 /// qf = qf.input("weights", 1, EvalMode::Weight)?; 746 /// # Ok(()) 747 /// # } 748 /// ``` 749 pub fn input( 750 mut self, 751 fieldname: &str, 752 size: usize, 753 emode: crate::EvalMode, 754 ) -> crate::Result<Self> { 755 let name_c = CString::new(fieldname).expect("CString::new failed"); 756 let idx = self.trampoline_data.number_inputs; 757 self.trampoline_data.input_sizes[idx] = size; 758 self.trampoline_data.number_inputs += 1; 759 let (size, emode) = ( 760 i32::try_from(size).unwrap(), 761 emode as bind_ceed::CeedEvalMode, 762 ); 763 let ierr = unsafe { 764 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 765 }; 766 self.qf_core.check_error(ierr)?; 767 Ok(self) 768 } 769 770 /// Add a QFunction output 771 /// 772 /// * `fieldname` - Name of QFunction field 773 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 774 /// `(ncomp * 1)` for `None` and `Interp` 775 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 776 /// to use interpolated values, `EvalMode::Grad` to use 777 /// gradients 778 /// 779 /// ``` 780 /// # use libceed::prelude::*; 781 /// # fn main() -> libceed::Result<()> { 782 /// # let ceed = libceed::Ceed::default_init(); 783 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 784 /// // Iterate over quadrature points 785 /// v.iter_mut() 786 /// .zip(u.iter().zip(weights.iter())) 787 /// .for_each(|(v, (u, w))| *v = u * w); 788 /// 789 /// // Return clean error code 790 /// 0 791 /// }; 792 /// 793 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 794 /// 795 /// qf.output("v", 1, EvalMode::Interp)?; 796 /// # Ok(()) 797 /// # } 798 /// ``` 799 pub fn output( 800 mut self, 801 fieldname: &str, 802 size: usize, 803 emode: crate::EvalMode, 804 ) -> crate::Result<Self> { 805 let name_c = CString::new(fieldname).expect("CString::new failed"); 806 let idx = self.trampoline_data.number_outputs; 807 self.trampoline_data.output_sizes[idx] = size; 808 self.trampoline_data.number_outputs += 1; 809 let (size, emode) = ( 810 i32::try_from(size).unwrap(), 811 emode as bind_ceed::CeedEvalMode, 812 ); 813 let ierr = unsafe { 814 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 815 }; 816 self.qf_core.check_error(ierr)?; 817 Ok(self) 818 } 819 820 /// Get a slice of QFunction inputs 821 /// 822 /// ``` 823 /// # use libceed::prelude::*; 824 /// # fn main() -> libceed::Result<()> { 825 /// # let ceed = libceed::Ceed::default_init(); 826 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 827 /// // Iterate over quadrature points 828 /// v.iter_mut() 829 /// .zip(u.iter().zip(weights.iter())) 830 /// .for_each(|(v, (u, w))| *v = u * w); 831 /// 832 /// // Return clean error code 833 /// 0 834 /// }; 835 /// 836 /// let mut qf = ceed 837 /// .q_function_interior(1, Box::new(user_f))? 838 /// .input("u", 1, EvalMode::Interp)? 839 /// .input("weights", 1, EvalMode::Weight)?; 840 /// 841 /// let inputs = qf.inputs()?; 842 /// 843 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 844 /// # Ok(()) 845 /// # } 846 /// ``` 847 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 848 self.qf_core.inputs() 849 } 850 851 /// Get a slice of QFunction outputs 852 /// 853 /// ``` 854 /// # use libceed::prelude::*; 855 /// # fn main() -> libceed::Result<()> { 856 /// # let ceed = libceed::Ceed::default_init(); 857 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 858 /// // Iterate over quadrature points 859 /// v.iter_mut() 860 /// .zip(u.iter().zip(weights.iter())) 861 /// .for_each(|(v, (u, w))| *v = u * w); 862 /// 863 /// // Return clean error code 864 /// 0 865 /// }; 866 /// 867 /// let mut qf = ceed 868 /// .q_function_interior(1, Box::new(user_f))? 869 /// .output("v", 1, EvalMode::Interp)?; 870 /// 871 /// let outputs = qf.outputs()?; 872 /// 873 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 874 /// # Ok(()) 875 /// # } 876 /// ``` 877 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 878 self.qf_core.outputs() 879 } 880 } 881 882 // ----------------------------------------------------------------------------- 883 // QFunction 884 // ----------------------------------------------------------------------------- 885 impl<'a> QFunctionByName<'a> { 886 // Constructor 887 pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> { 888 let name_c = CString::new(name).expect("CString::new failed"); 889 let mut ptr = std::ptr::null_mut(); 890 let ierr = unsafe { 891 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 892 }; 893 ceed.check_error(ierr)?; 894 Ok(Self { 895 qf_core: QFunctionCore { 896 ptr, 897 _lifeline: PhantomData, 898 }, 899 }) 900 } 901 902 /// Apply the action of a QFunction 903 /// 904 /// * `Q` - The number of quadrature points 905 /// * `input` - Array of input Vectors 906 /// * `output` - Array of output Vectors 907 /// 908 /// ``` 909 /// # use libceed::prelude::*; 910 /// # fn main() -> libceed::Result<()> { 911 /// # let ceed = libceed::Ceed::default_init(); 912 /// const Q: usize = 8; 913 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?; 914 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?; 915 /// 916 /// let mut j = [0.; Q]; 917 /// let mut w = [0.; Q]; 918 /// let mut u = [0.; Q]; 919 /// let mut v = [0.; Q]; 920 /// 921 /// for i in 0..Q { 922 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 923 /// j[i] = 1.; 924 /// w[i] = 1. - x * x; 925 /// u[i] = 2. + 3. * x + 5. * x * x; 926 /// v[i] = w[i] * u[i]; 927 /// } 928 /// 929 /// let jj = ceed.vector_from_slice(&j)?; 930 /// let ww = ceed.vector_from_slice(&w)?; 931 /// let uu = ceed.vector_from_slice(&u)?; 932 /// let mut vv = ceed.vector(Q)?; 933 /// vv.set_value(0.0); 934 /// let mut qdata = ceed.vector(Q)?; 935 /// qdata.set_value(0.0); 936 /// 937 /// { 938 /// let mut input = vec![jj, ww]; 939 /// let mut output = vec![qdata]; 940 /// qf_build.apply(Q, &input, &output)?; 941 /// qdata = output.remove(0); 942 /// } 943 /// 944 /// { 945 /// let mut input = vec![qdata, uu]; 946 /// let mut output = vec![vv]; 947 /// qf_mass.apply(Q, &input, &output)?; 948 /// vv = output.remove(0); 949 /// } 950 /// 951 /// vv.view()? 952 /// .iter() 953 /// .zip(v.iter()) 954 /// .for_each(|(computed, actual)| { 955 /// assert_eq!( 956 /// *computed, *actual, 957 /// "Incorrect value in QFunction application" 958 /// ); 959 /// }); 960 /// # Ok(()) 961 /// # } 962 /// ``` 963 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 964 self.qf_core.apply(Q, u, v) 965 } 966 967 /// Get a slice of QFunction inputs 968 /// 969 /// ``` 970 /// # use libceed::prelude::*; 971 /// # fn main() -> libceed::Result<()> { 972 /// # let ceed = libceed::Ceed::default_init(); 973 /// const Q: usize = 8; 974 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 975 /// 976 /// let inputs = qf.inputs()?; 977 /// 978 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 979 /// # Ok(()) 980 /// # } 981 /// ``` 982 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 983 self.qf_core.inputs() 984 } 985 986 /// Get a slice of QFunction outputs 987 /// 988 /// ``` 989 /// # use libceed::prelude::*; 990 /// # fn main() -> libceed::Result<()> { 991 /// # let ceed = libceed::Ceed::default_init(); 992 /// const Q: usize = 8; 993 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 994 /// 995 /// let outputs = qf.outputs()?; 996 /// 997 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 998 /// # Ok(()) 999 /// # } 1000 /// ``` 1001 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 1002 self.qf_core.outputs() 1003 } 1004 } 1005 1006 // ----------------------------------------------------------------------------- 1007