1 // Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at 2 // the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights 3 // reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative 16 17 //! A Ceed QFunction represents the spatial terms of the point-wise functions 18 //! describing the physics at the quadrature points. 19 20 use std::pin::Pin; 21 22 use crate::prelude::*; 23 24 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 25 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 26 27 // ----------------------------------------------------------------------------- 28 // CeedQFunction Field context wrapper 29 // ----------------------------------------------------------------------------- 30 #[derive(Debug)] 31 pub struct QFunctionField<'a> { 32 ptr: bind_ceed::CeedQFunctionField, 33 _lifeline: PhantomData<&'a ()>, 34 } 35 36 // ----------------------------------------------------------------------------- 37 // Implementations 38 // ----------------------------------------------------------------------------- 39 impl<'a> QFunctionField<'a> { 40 /// Get the name of a QFunctionField 41 /// 42 /// ``` 43 /// # use libceed::prelude::*; 44 /// # fn main() -> libceed::Result<()> { 45 /// # let ceed = libceed::Ceed::default_init(); 46 /// const Q: usize = 8; 47 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 48 /// 49 /// let inputs = qf.inputs()?; 50 /// 51 /// assert_eq!(inputs[0].name(), "dx", "Incorrect input name"); 52 /// assert_eq!(inputs[1].name(), "weights", "Incorrect input name"); 53 /// # Ok(()) 54 /// # } 55 /// ``` 56 pub fn name(&self) -> &str { 57 let mut name_ptr: *mut std::os::raw::c_char = std::ptr::null_mut(); 58 unsafe { 59 bind_ceed::CeedQFunctionFieldGetName(self.ptr, &mut name_ptr); 60 } 61 unsafe { CStr::from_ptr(name_ptr) }.to_str().unwrap() 62 } 63 64 /// Get the size of a QFunctionField 65 /// 66 /// ``` 67 /// # use libceed::prelude::*; 68 /// # fn main() -> libceed::Result<()> { 69 /// # let ceed = libceed::Ceed::default_init(); 70 /// const Q: usize = 8; 71 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 72 /// 73 /// let inputs = qf.inputs()?; 74 /// 75 /// assert_eq!(inputs[0].size(), 4, "Incorrect input size"); 76 /// assert_eq!(inputs[1].size(), 1, "Incorrect input size"); 77 /// # Ok(()) 78 /// # } 79 /// ``` 80 pub fn size(&self) -> usize { 81 let mut size = 0; 82 unsafe { 83 bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut size); 84 } 85 usize::try_from(size).unwrap() 86 } 87 88 /// Get the evaluation mode of a QFunctionField 89 /// 90 /// ``` 91 /// # use libceed::prelude::*; 92 /// # fn main() -> libceed::Result<()> { 93 /// # let ceed = libceed::Ceed::default_init(); 94 /// const Q: usize = 8; 95 /// let qf = ceed.q_function_interior_by_name("Mass2DBuild")?; 96 /// 97 /// let inputs = qf.inputs()?; 98 /// 99 /// assert_eq!( 100 /// inputs[0].eval_mode(), 101 /// EvalMode::Grad, 102 /// "Incorrect input evaluation mode" 103 /// ); 104 /// assert_eq!( 105 /// inputs[1].eval_mode(), 106 /// EvalMode::Weight, 107 /// "Incorrect input evaluation mode" 108 /// ); 109 /// # Ok(()) 110 /// # } 111 /// ``` 112 pub fn eval_mode(&self) -> crate::EvalMode { 113 let mut mode = 0; 114 unsafe { 115 bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut mode); 116 } 117 crate::EvalMode::from_u32(mode as u32) 118 } 119 } 120 121 // ----------------------------------------------------------------------------- 122 // CeedQFunction option 123 // ----------------------------------------------------------------------------- 124 pub enum QFunctionOpt<'a> { 125 SomeQFunction(&'a QFunction<'a>), 126 SomeQFunctionByName(&'a QFunctionByName<'a>), 127 None, 128 } 129 130 /// Construct a QFunctionOpt reference from a QFunction reference 131 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 132 fn from(qfunc: &'a QFunction) -> Self { 133 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 134 Self::SomeQFunction(qfunc) 135 } 136 } 137 138 /// Construct a QFunctionOpt reference from a QFunction by Name reference 139 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 140 fn from(qfunc: &'a QFunctionByName) -> Self { 141 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 142 Self::SomeQFunctionByName(qfunc) 143 } 144 } 145 146 impl<'a> QFunctionOpt<'a> { 147 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 148 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 149 match self { 150 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 151 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 152 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 153 } 154 } 155 156 /// Check if a QFunctionOpt is Some 157 /// 158 /// ``` 159 /// # use libceed::prelude::*; 160 /// # fn main() -> libceed::Result<()> { 161 /// # let ceed = libceed::Ceed::default_init(); 162 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 163 /// // Iterate over quadrature points 164 /// v.iter_mut() 165 /// .zip(u.iter().zip(weights.iter())) 166 /// .for_each(|(v, (u, w))| *v = u * w); 167 /// 168 /// // Return clean error code 169 /// 0 170 /// }; 171 /// 172 /// let qf = ceed 173 /// .q_function_interior(1, Box::new(user_f))? 174 /// .input("u", 1, EvalMode::Interp)? 175 /// .input("weights", 1, EvalMode::Weight)? 176 /// .output("v", 1, EvalMode::Interp)?; 177 /// let qf_opt = QFunctionOpt::from(&qf); 178 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 179 /// 180 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 181 /// let qf_opt = QFunctionOpt::from(&qf); 182 /// assert!(qf_opt.is_some(), "Incorrect QFunctionOpt"); 183 /// 184 /// let qf_opt = QFunctionOpt::None; 185 /// assert!(!qf_opt.is_some(), "Incorrect QFunctionOpt"); 186 /// # Ok(()) 187 /// # } 188 /// ``` 189 pub fn is_some(&self) -> bool { 190 match self { 191 Self::SomeQFunction(_) => true, 192 Self::SomeQFunctionByName(_) => true, 193 Self::None => false, 194 } 195 } 196 197 /// Check if a QFunctionOpt is SomeQFunction 198 /// 199 /// ``` 200 /// # use libceed::prelude::*; 201 /// # fn main() -> libceed::Result<()> { 202 /// # let ceed = libceed::Ceed::default_init(); 203 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 204 /// // Iterate over quadrature points 205 /// v.iter_mut() 206 /// .zip(u.iter().zip(weights.iter())) 207 /// .for_each(|(v, (u, w))| *v = u * w); 208 /// 209 /// // Return clean error code 210 /// 0 211 /// }; 212 /// 213 /// let qf = ceed 214 /// .q_function_interior(1, Box::new(user_f))? 215 /// .input("u", 1, EvalMode::Interp)? 216 /// .input("weights", 1, EvalMode::Weight)? 217 /// .output("v", 1, EvalMode::Interp)?; 218 /// let qf_opt = QFunctionOpt::from(&qf); 219 /// assert!(qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 220 /// 221 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 222 /// let qf_opt = QFunctionOpt::from(&qf); 223 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 224 /// 225 /// let qf_opt = QFunctionOpt::None; 226 /// assert!(!qf_opt.is_some_q_function(), "Incorrect QFunctionOpt"); 227 /// # Ok(()) 228 /// # } 229 /// ``` 230 pub fn is_some_q_function(&self) -> bool { 231 match self { 232 Self::SomeQFunction(_) => true, 233 Self::SomeQFunctionByName(_) => false, 234 Self::None => false, 235 } 236 } 237 238 /// Check if a QFunctionOpt is SomeQFunctionByName 239 /// 240 /// ``` 241 /// # use libceed::prelude::*; 242 /// # fn main() -> libceed::Result<()> { 243 /// # let ceed = libceed::Ceed::default_init(); 244 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 245 /// // Iterate over quadrature points 246 /// v.iter_mut() 247 /// .zip(u.iter().zip(weights.iter())) 248 /// .for_each(|(v, (u, w))| *v = u * w); 249 /// 250 /// // Return clean error code 251 /// 0 252 /// }; 253 /// 254 /// let qf = ceed 255 /// .q_function_interior(1, Box::new(user_f))? 256 /// .input("u", 1, EvalMode::Interp)? 257 /// .input("weights", 1, EvalMode::Weight)? 258 /// .output("v", 1, EvalMode::Interp)?; 259 /// let qf_opt = QFunctionOpt::from(&qf); 260 /// assert!( 261 /// !qf_opt.is_some_q_function_by_name(), 262 /// "Incorrect QFunctionOpt" 263 /// ); 264 /// 265 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 266 /// let qf_opt = QFunctionOpt::from(&qf); 267 /// assert!( 268 /// qf_opt.is_some_q_function_by_name(), 269 /// "Incorrect QFunctionOpt" 270 /// ); 271 /// 272 /// let qf_opt = QFunctionOpt::None; 273 /// assert!( 274 /// !qf_opt.is_some_q_function_by_name(), 275 /// "Incorrect QFunctionOpt" 276 /// ); 277 /// # Ok(()) 278 /// # } 279 /// ``` 280 pub fn is_some_q_function_by_name(&self) -> bool { 281 match self { 282 Self::SomeQFunction(_) => false, 283 Self::SomeQFunctionByName(_) => true, 284 Self::None => false, 285 } 286 } 287 288 /// Check if a QFunctionOpt is None 289 /// 290 /// ``` 291 /// # use libceed::prelude::*; 292 /// # fn main() -> libceed::Result<()> { 293 /// # let ceed = libceed::Ceed::default_init(); 294 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 295 /// // Iterate over quadrature points 296 /// v.iter_mut() 297 /// .zip(u.iter().zip(weights.iter())) 298 /// .for_each(|(v, (u, w))| *v = u * w); 299 /// 300 /// // Return clean error code 301 /// 0 302 /// }; 303 /// 304 /// let qf = ceed 305 /// .q_function_interior(1, Box::new(user_f))? 306 /// .input("u", 1, EvalMode::Interp)? 307 /// .input("weights", 1, EvalMode::Weight)? 308 /// .output("v", 1, EvalMode::Interp)?; 309 /// let qf_opt = QFunctionOpt::from(&qf); 310 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 311 /// 312 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 313 /// let qf_opt = QFunctionOpt::from(&qf); 314 /// assert!(!qf_opt.is_none(), "Incorrect QFunctionOpt"); 315 /// 316 /// let qf_opt = QFunctionOpt::None; 317 /// assert!(qf_opt.is_none(), "Incorrect QFunctionOpt"); 318 /// # Ok(()) 319 /// # } 320 /// ``` 321 pub fn is_none(&self) -> bool { 322 match self { 323 Self::SomeQFunction(_) => false, 324 Self::SomeQFunctionByName(_) => false, 325 Self::None => true, 326 } 327 } 328 } 329 330 // ----------------------------------------------------------------------------- 331 // CeedQFunction context wrapper 332 // ----------------------------------------------------------------------------- 333 #[derive(Debug)] 334 pub(crate) struct QFunctionCore<'a> { 335 ptr: bind_ceed::CeedQFunction, 336 _lifeline: PhantomData<&'a ()>, 337 } 338 339 struct QFunctionTrampolineData { 340 number_inputs: usize, 341 number_outputs: usize, 342 input_sizes: [usize; MAX_QFUNCTION_FIELDS], 343 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 344 user_f: Box<QFunctionUserClosure>, 345 } 346 347 pub struct QFunction<'a> { 348 qf_core: QFunctionCore<'a>, 349 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 350 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 351 } 352 353 #[derive(Debug)] 354 pub struct QFunctionByName<'a> { 355 qf_core: QFunctionCore<'a>, 356 } 357 358 // ----------------------------------------------------------------------------- 359 // Destructor 360 // ----------------------------------------------------------------------------- 361 impl<'a> Drop for QFunctionCore<'a> { 362 fn drop(&mut self) { 363 unsafe { 364 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 365 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 366 } 367 } 368 } 369 } 370 371 impl<'a> Drop for QFunction<'a> { 372 fn drop(&mut self) { 373 unsafe { 374 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 375 } 376 } 377 } 378 379 // ----------------------------------------------------------------------------- 380 // Display 381 // ----------------------------------------------------------------------------- 382 impl<'a> fmt::Display for QFunctionCore<'a> { 383 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 384 let mut ptr = std::ptr::null_mut(); 385 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 386 let cstring = unsafe { 387 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 388 bind_ceed::CeedQFunctionView(self.ptr, file); 389 bind_ceed::fclose(file); 390 CString::from_raw(ptr) 391 }; 392 cstring.to_string_lossy().fmt(f) 393 } 394 } 395 /// View a QFunction 396 /// 397 /// ``` 398 /// # use libceed::prelude::*; 399 /// # fn main() -> libceed::Result<()> { 400 /// # let ceed = libceed::Ceed::default_init(); 401 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 402 /// // Iterate over quadrature points 403 /// v.iter_mut() 404 /// .zip(u.iter().zip(weights.iter())) 405 /// .for_each(|(v, (u, w))| *v = u * w); 406 /// 407 /// // Return clean error code 408 /// 0 409 /// }; 410 /// 411 /// let qf = ceed 412 /// .q_function_interior(1, Box::new(user_f))? 413 /// .input("u", 1, EvalMode::Interp)? 414 /// .input("weights", 1, EvalMode::Weight)? 415 /// .output("v", 1, EvalMode::Interp)?; 416 /// 417 /// println!("{}", qf); 418 /// # Ok(()) 419 /// # } 420 /// ``` 421 impl<'a> fmt::Display for QFunction<'a> { 422 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 423 self.qf_core.fmt(f) 424 } 425 } 426 427 /// View a QFunction by Name 428 /// 429 /// ``` 430 /// # use libceed::prelude::*; 431 /// # fn main() -> libceed::Result<()> { 432 /// # let ceed = libceed::Ceed::default_init(); 433 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 434 /// println!("{}", qf); 435 /// # Ok(()) 436 /// # } 437 /// ``` 438 impl<'a> fmt::Display for QFunctionByName<'a> { 439 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 440 self.qf_core.fmt(f) 441 } 442 } 443 444 // ----------------------------------------------------------------------------- 445 // Core functionality 446 // ----------------------------------------------------------------------------- 447 impl<'a> QFunctionCore<'a> { 448 // Error handling 449 #[doc(hidden)] 450 fn check_error(&self, ierr: i32) -> crate::Result<i32> { 451 let mut ptr = std::ptr::null_mut(); 452 unsafe { 453 bind_ceed::CeedQFunctionGetCeed(self.ptr, &mut ptr); 454 } 455 crate::check_error(ptr, ierr) 456 } 457 458 // Common implementation 459 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 460 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 461 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 462 u_c[i] = u[i].ptr; 463 } 464 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 465 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 466 v_c[i] = v[i].ptr; 467 } 468 let Q = i32::try_from(Q).unwrap(); 469 let ierr = unsafe { 470 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 471 }; 472 self.check_error(ierr) 473 } 474 475 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 476 // Get array of raw C pointers for inputs 477 let mut num_inputs = 0; 478 let mut inputs_ptr = std::ptr::null_mut(); 479 let ierr = unsafe { 480 bind_ceed::CeedQFunctionGetFields( 481 self.ptr, 482 &mut num_inputs, 483 &mut inputs_ptr, 484 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 485 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 486 ) 487 }; 488 self.check_error(ierr)?; 489 // Convert raw C pointers to fixed length slice 490 let inputs_slice = unsafe { 491 std::slice::from_raw_parts( 492 inputs_ptr as *const crate::QFunctionField, 493 num_inputs as usize, 494 ) 495 }; 496 Ok(inputs_slice) 497 } 498 499 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 500 // Get array of raw C pointers for outputs 501 let mut num_outputs = 0; 502 let mut outputs_ptr = std::ptr::null_mut(); 503 let ierr = unsafe { 504 bind_ceed::CeedQFunctionGetFields( 505 self.ptr, 506 std::ptr::null_mut() as *mut bind_ceed::CeedInt, 507 std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField, 508 &mut num_outputs, 509 &mut outputs_ptr, 510 ) 511 }; 512 self.check_error(ierr)?; 513 // Convert raw C pointers to fixed length slice 514 let outputs_slice = unsafe { 515 std::slice::from_raw_parts( 516 outputs_ptr as *const crate::QFunctionField, 517 num_outputs as usize, 518 ) 519 }; 520 Ok(outputs_slice) 521 } 522 } 523 524 // ----------------------------------------------------------------------------- 525 // User QFunction Closure 526 // ----------------------------------------------------------------------------- 527 pub type QFunctionUserClosure = dyn FnMut( 528 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 529 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 530 ) -> i32; 531 532 macro_rules! mut_max_fields { 533 ($e:expr) => { 534 [ 535 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 536 ] 537 }; 538 } 539 unsafe extern "C" fn trampoline( 540 ctx: *mut ::std::os::raw::c_void, 541 q: bind_ceed::CeedInt, 542 inputs: *const *const bind_ceed::CeedScalar, 543 outputs: *const *mut bind_ceed::CeedScalar, 544 ) -> ::std::os::raw::c_int { 545 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 546 547 // Inputs 548 let inputs_slice: &[*const bind_ceed::CeedScalar] = 549 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 550 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 551 inputs_slice 552 .iter() 553 .enumerate() 554 .map(|(i, &x)| { 555 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 556 as &[crate::Scalar] 557 }) 558 .zip(inputs_array.iter_mut()) 559 .for_each(|(x, a)| *a = x); 560 561 // Outputs 562 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 563 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 564 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 565 mut_max_fields!(&mut [0.0]); 566 outputs_slice 567 .iter() 568 .enumerate() 569 .map(|(i, &x)| { 570 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 571 as &mut [crate::Scalar] 572 }) 573 .zip(outputs_array.iter_mut()) 574 .for_each(|(x, a)| *a = x); 575 576 // User closure 577 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 578 } 579 580 // ----------------------------------------------------------------------------- 581 // QFunction 582 // ----------------------------------------------------------------------------- 583 impl<'a> QFunction<'a> { 584 // Constructor 585 pub fn create( 586 ceed: &crate::Ceed, 587 vlength: usize, 588 user_f: Box<QFunctionUserClosure>, 589 ) -> crate::Result<Self> { 590 let source_c = CString::new("").expect("CString::new failed"); 591 let mut ptr = std::ptr::null_mut(); 592 593 // Context for closure 594 let number_inputs = 0; 595 let number_outputs = 0; 596 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 597 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 598 let trampoline_data = unsafe { 599 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 600 number_inputs, 601 number_outputs, 602 input_sizes, 603 output_sizes, 604 user_f, 605 })) 606 }; 607 608 // Create QFunction 609 let vlength = i32::try_from(vlength).unwrap(); 610 let mut ierr = unsafe { 611 bind_ceed::CeedQFunctionCreateInterior( 612 ceed.ptr, 613 vlength, 614 Some(trampoline), 615 source_c.as_ptr(), 616 &mut ptr, 617 ) 618 }; 619 ceed.check_error(ierr)?; 620 621 // Set closure 622 let mut qf_ctx_ptr = std::ptr::null_mut(); 623 ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) }; 624 ceed.check_error(ierr)?; 625 ierr = unsafe { 626 bind_ceed::CeedQFunctionContextSetData( 627 qf_ctx_ptr, 628 crate::MemType::Host as bind_ceed::CeedMemType, 629 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 630 std::mem::size_of::<QFunctionTrampolineData>() as u64, 631 std::mem::transmute(trampoline_data.as_ref()), 632 ) 633 }; 634 ceed.check_error(ierr)?; 635 ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) }; 636 ceed.check_error(ierr)?; 637 Ok(Self { 638 qf_core: QFunctionCore { 639 ptr, 640 _lifeline: PhantomData, 641 }, 642 qf_ctx_ptr, 643 trampoline_data, 644 }) 645 } 646 647 /// Apply the action of a QFunction 648 /// 649 /// * `Q` - The number of quadrature points 650 /// * `input` - Array of input Vectors 651 /// * `output` - Array of output Vectors 652 /// 653 /// ``` 654 /// # use libceed::prelude::*; 655 /// # fn main() -> libceed::Result<()> { 656 /// # let ceed = libceed::Ceed::default_init(); 657 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 658 /// // Iterate over quadrature points 659 /// v.iter_mut() 660 /// .zip(u.iter().zip(weights.iter())) 661 /// .for_each(|(v, (u, w))| *v = u * w); 662 /// 663 /// // Return clean error code 664 /// 0 665 /// }; 666 /// 667 /// let qf = ceed 668 /// .q_function_interior(1, Box::new(user_f))? 669 /// .input("u", 1, EvalMode::Interp)? 670 /// .input("weights", 1, EvalMode::Weight)? 671 /// .output("v", 1, EvalMode::Interp)?; 672 /// 673 /// const Q: usize = 8; 674 /// let mut w = [0.; Q]; 675 /// let mut u = [0.; Q]; 676 /// let mut v = [0.; Q]; 677 /// 678 /// for i in 0..Q { 679 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 680 /// u[i] = 2. + 3. * x + 5. * x * x; 681 /// w[i] = 1. - x * x; 682 /// v[i] = u[i] * w[i]; 683 /// } 684 /// 685 /// let uu = ceed.vector_from_slice(&u)?; 686 /// let ww = ceed.vector_from_slice(&w)?; 687 /// let mut vv = ceed.vector(Q)?; 688 /// vv.set_value(0.0); 689 /// { 690 /// let input = vec![uu, ww]; 691 /// let mut output = vec![vv]; 692 /// qf.apply(Q, &input, &output)?; 693 /// vv = output.remove(0); 694 /// } 695 /// 696 /// vv.view()? 697 /// .iter() 698 /// .zip(v.iter()) 699 /// .for_each(|(computed, actual)| { 700 /// assert_eq!( 701 /// *computed, *actual, 702 /// "Incorrect value in QFunction application" 703 /// ); 704 /// }); 705 /// # Ok(()) 706 /// # } 707 /// ``` 708 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 709 self.qf_core.apply(Q, u, v) 710 } 711 712 /// Add a QFunction input 713 /// 714 /// * `fieldname` - Name of QFunction field 715 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 716 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 717 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 718 /// to use interpolated values, `EvalMode::Grad` to use 719 /// gradients, `EvalMode::Weight` to use quadrature weights 720 /// 721 /// ``` 722 /// # use libceed::prelude::*; 723 /// # fn main() -> libceed::Result<()> { 724 /// # let ceed = libceed::Ceed::default_init(); 725 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 726 /// // Iterate over quadrature points 727 /// v.iter_mut() 728 /// .zip(u.iter().zip(weights.iter())) 729 /// .for_each(|(v, (u, w))| *v = u * w); 730 /// 731 /// // Return clean error code 732 /// 0 733 /// }; 734 /// 735 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 736 /// 737 /// qf = qf.input("u", 1, EvalMode::Interp)?; 738 /// qf = qf.input("weights", 1, EvalMode::Weight)?; 739 /// # Ok(()) 740 /// # } 741 /// ``` 742 pub fn input( 743 mut self, 744 fieldname: &str, 745 size: usize, 746 emode: crate::EvalMode, 747 ) -> crate::Result<Self> { 748 let name_c = CString::new(fieldname).expect("CString::new failed"); 749 let idx = self.trampoline_data.number_inputs; 750 self.trampoline_data.input_sizes[idx] = size; 751 self.trampoline_data.number_inputs += 1; 752 let (size, emode) = ( 753 i32::try_from(size).unwrap(), 754 emode as bind_ceed::CeedEvalMode, 755 ); 756 let ierr = unsafe { 757 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 758 }; 759 self.qf_core.check_error(ierr)?; 760 Ok(self) 761 } 762 763 /// Add a QFunction output 764 /// 765 /// * `fieldname` - Name of QFunction field 766 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 767 /// `(ncomp * 1)` for `None` and `Interp` 768 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 769 /// to use interpolated values, `EvalMode::Grad` to use 770 /// gradients 771 /// 772 /// ``` 773 /// # use libceed::prelude::*; 774 /// # fn main() -> libceed::Result<()> { 775 /// # let ceed = libceed::Ceed::default_init(); 776 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 777 /// // Iterate over quadrature points 778 /// v.iter_mut() 779 /// .zip(u.iter().zip(weights.iter())) 780 /// .for_each(|(v, (u, w))| *v = u * w); 781 /// 782 /// // Return clean error code 783 /// 0 784 /// }; 785 /// 786 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 787 /// 788 /// qf.output("v", 1, EvalMode::Interp)?; 789 /// # Ok(()) 790 /// # } 791 /// ``` 792 pub fn output( 793 mut self, 794 fieldname: &str, 795 size: usize, 796 emode: crate::EvalMode, 797 ) -> crate::Result<Self> { 798 let name_c = CString::new(fieldname).expect("CString::new failed"); 799 let idx = self.trampoline_data.number_outputs; 800 self.trampoline_data.output_sizes[idx] = size; 801 self.trampoline_data.number_outputs += 1; 802 let (size, emode) = ( 803 i32::try_from(size).unwrap(), 804 emode as bind_ceed::CeedEvalMode, 805 ); 806 let ierr = unsafe { 807 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 808 }; 809 self.qf_core.check_error(ierr)?; 810 Ok(self) 811 } 812 813 /// Get a slice of QFunction inputs 814 /// 815 /// ``` 816 /// # use libceed::prelude::*; 817 /// # fn main() -> libceed::Result<()> { 818 /// # let ceed = libceed::Ceed::default_init(); 819 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 820 /// // Iterate over quadrature points 821 /// v.iter_mut() 822 /// .zip(u.iter().zip(weights.iter())) 823 /// .for_each(|(v, (u, w))| *v = u * w); 824 /// 825 /// // Return clean error code 826 /// 0 827 /// }; 828 /// 829 /// let mut qf = ceed 830 /// .q_function_interior(1, Box::new(user_f))? 831 /// .input("u", 1, EvalMode::Interp)? 832 /// .input("weights", 1, EvalMode::Weight)?; 833 /// 834 /// let inputs = qf.inputs()?; 835 /// 836 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 837 /// # Ok(()) 838 /// # } 839 /// ``` 840 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 841 self.qf_core.inputs() 842 } 843 844 /// Get a slice of QFunction outputs 845 /// 846 /// ``` 847 /// # use libceed::prelude::*; 848 /// # fn main() -> libceed::Result<()> { 849 /// # let ceed = libceed::Ceed::default_init(); 850 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 851 /// // Iterate over quadrature points 852 /// v.iter_mut() 853 /// .zip(u.iter().zip(weights.iter())) 854 /// .for_each(|(v, (u, w))| *v = u * w); 855 /// 856 /// // Return clean error code 857 /// 0 858 /// }; 859 /// 860 /// let mut qf = ceed 861 /// .q_function_interior(1, Box::new(user_f))? 862 /// .output("v", 1, EvalMode::Interp)?; 863 /// 864 /// let outputs = qf.outputs()?; 865 /// 866 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 867 /// # Ok(()) 868 /// # } 869 /// ``` 870 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 871 self.qf_core.outputs() 872 } 873 } 874 875 // ----------------------------------------------------------------------------- 876 // QFunction 877 // ----------------------------------------------------------------------------- 878 impl<'a> QFunctionByName<'a> { 879 // Constructor 880 pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> { 881 let name_c = CString::new(name).expect("CString::new failed"); 882 let mut ptr = std::ptr::null_mut(); 883 let ierr = unsafe { 884 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 885 }; 886 ceed.check_error(ierr)?; 887 Ok(Self { 888 qf_core: QFunctionCore { 889 ptr, 890 _lifeline: PhantomData, 891 }, 892 }) 893 } 894 895 /// Apply the action of a QFunction 896 /// 897 /// * `Q` - The number of quadrature points 898 /// * `input` - Array of input Vectors 899 /// * `output` - Array of output Vectors 900 /// 901 /// ``` 902 /// # use libceed::prelude::*; 903 /// # fn main() -> libceed::Result<()> { 904 /// # let ceed = libceed::Ceed::default_init(); 905 /// const Q: usize = 8; 906 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?; 907 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?; 908 /// 909 /// let mut j = [0.; Q]; 910 /// let mut w = [0.; Q]; 911 /// let mut u = [0.; Q]; 912 /// let mut v = [0.; Q]; 913 /// 914 /// for i in 0..Q { 915 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 916 /// j[i] = 1.; 917 /// w[i] = 1. - x * x; 918 /// u[i] = 2. + 3. * x + 5. * x * x; 919 /// v[i] = w[i] * u[i]; 920 /// } 921 /// 922 /// let jj = ceed.vector_from_slice(&j)?; 923 /// let ww = ceed.vector_from_slice(&w)?; 924 /// let uu = ceed.vector_from_slice(&u)?; 925 /// let mut vv = ceed.vector(Q)?; 926 /// vv.set_value(0.0); 927 /// let mut qdata = ceed.vector(Q)?; 928 /// qdata.set_value(0.0); 929 /// 930 /// { 931 /// let mut input = vec![jj, ww]; 932 /// let mut output = vec![qdata]; 933 /// qf_build.apply(Q, &input, &output)?; 934 /// qdata = output.remove(0); 935 /// } 936 /// 937 /// { 938 /// let mut input = vec![qdata, uu]; 939 /// let mut output = vec![vv]; 940 /// qf_mass.apply(Q, &input, &output)?; 941 /// vv = output.remove(0); 942 /// } 943 /// 944 /// vv.view()? 945 /// .iter() 946 /// .zip(v.iter()) 947 /// .for_each(|(computed, actual)| { 948 /// assert_eq!( 949 /// *computed, *actual, 950 /// "Incorrect value in QFunction application" 951 /// ); 952 /// }); 953 /// # Ok(()) 954 /// # } 955 /// ``` 956 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 957 self.qf_core.apply(Q, u, v) 958 } 959 960 /// Get a slice of QFunction inputs 961 /// 962 /// ``` 963 /// # use libceed::prelude::*; 964 /// # fn main() -> libceed::Result<()> { 965 /// # let ceed = libceed::Ceed::default_init(); 966 /// const Q: usize = 8; 967 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 968 /// 969 /// let inputs = qf.inputs()?; 970 /// 971 /// assert_eq!(inputs.len(), 2, "Incorrect inputs array"); 972 /// # Ok(()) 973 /// # } 974 /// ``` 975 pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> { 976 self.qf_core.inputs() 977 } 978 979 /// Get a slice of QFunction outputs 980 /// 981 /// ``` 982 /// # use libceed::prelude::*; 983 /// # fn main() -> libceed::Result<()> { 984 /// # let ceed = libceed::Ceed::default_init(); 985 /// const Q: usize = 8; 986 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 987 /// 988 /// let outputs = qf.outputs()?; 989 /// 990 /// assert_eq!(outputs.len(), 1, "Incorrect outputs array"); 991 /// # Ok(()) 992 /// # } 993 /// ``` 994 pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> { 995 self.qf_core.outputs() 996 } 997 } 998 999 // ----------------------------------------------------------------------------- 1000