// Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
// the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
// reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative

//! A Ceed QFunction represents the spatial terms of the point-wise functions
//! describing the physics at the quadrature points.
19 20 use std::pin::Pin; 21 22 use crate::prelude::*; 23 24 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 25 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 26 27 // ----------------------------------------------------------------------------- 28 // CeedQFunction option 29 // ----------------------------------------------------------------------------- 30 pub enum QFunctionOpt<'a> { 31 SomeQFunction(&'a QFunction<'a>), 32 SomeQFunctionByName(&'a QFunctionByName<'a>), 33 None, 34 } 35 36 /// Construct a QFunctionOpt reference from a QFunction reference 37 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 38 fn from(qfunc: &'a QFunction) -> Self { 39 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 40 Self::SomeQFunction(qfunc) 41 } 42 } 43 44 /// Construct a QFunctionOpt reference from a QFunction by Name reference 45 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 46 fn from(qfunc: &'a QFunctionByName) -> Self { 47 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 48 Self::SomeQFunctionByName(qfunc) 49 } 50 } 51 52 impl<'a> QFunctionOpt<'a> { 53 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 54 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 55 match self { 56 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 57 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 58 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 59 } 60 } 61 } 62 63 // ----------------------------------------------------------------------------- 64 // CeedQFunction context wrapper 65 // ----------------------------------------------------------------------------- 66 #[derive(Debug)] 67 pub(crate) struct QFunctionCore<'a> { 68 ceed: &'a crate::Ceed, 69 ptr: bind_ceed::CeedQFunction, 70 } 71 72 struct QFunctionTrampolineData { 73 number_inputs: usize, 74 number_outputs: usize, 75 input_sizes: [usize; 
MAX_QFUNCTION_FIELDS], 76 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 77 user_f: Box<QFunctionUserClosure>, 78 } 79 80 pub struct QFunction<'a> { 81 qf_core: QFunctionCore<'a>, 82 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 83 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 84 } 85 86 #[derive(Debug)] 87 pub struct QFunctionByName<'a> { 88 qf_core: QFunctionCore<'a>, 89 } 90 91 // ----------------------------------------------------------------------------- 92 // Destructor 93 // ----------------------------------------------------------------------------- 94 impl<'a> Drop for QFunctionCore<'a> { 95 fn drop(&mut self) { 96 unsafe { 97 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 98 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 99 } 100 } 101 } 102 } 103 104 impl<'a> Drop for QFunction<'a> { 105 fn drop(&mut self) { 106 unsafe { 107 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 108 } 109 } 110 } 111 112 // ----------------------------------------------------------------------------- 113 // Display 114 // ----------------------------------------------------------------------------- 115 impl<'a> fmt::Display for QFunctionCore<'a> { 116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 117 let mut ptr = std::ptr::null_mut(); 118 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 119 let cstring = unsafe { 120 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 121 bind_ceed::CeedQFunctionView(self.ptr, file); 122 bind_ceed::fclose(file); 123 CString::from_raw(ptr) 124 }; 125 cstring.to_string_lossy().fmt(f) 126 } 127 } 128 /// View a QFunction 129 /// 130 /// ``` 131 /// # use libceed::prelude::*; 132 /// # fn main() -> Result<(), libceed::CeedError> { 133 /// # let ceed = libceed::Ceed::default_init(); 134 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 135 /// // Iterate over quadrature points 136 /// v.iter_mut() 137 /// .zip(u.iter().zip(weights.iter())) 138 /// .for_each(|(v, (u, w))| *v = u 
* w); 139 /// 140 /// // Return clean error code 141 /// 0 142 /// }; 143 /// 144 /// let qf = ceed 145 /// .q_function_interior(1, Box::new(user_f))? 146 /// .input("u", 1, EvalMode::Interp)? 147 /// .input("weights", 1, EvalMode::Weight)? 148 /// .output("v", 1, EvalMode::Interp)?; 149 /// 150 /// println!("{}", qf); 151 /// # Ok(()) 152 /// # } 153 /// ``` 154 impl<'a> fmt::Display for QFunction<'a> { 155 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 156 self.qf_core.fmt(f) 157 } 158 } 159 160 /// View a QFunction by Name 161 /// 162 /// ``` 163 /// # use libceed::prelude::*; 164 /// # fn main() -> Result<(), libceed::CeedError> { 165 /// # let ceed = libceed::Ceed::default_init(); 166 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild")?; 167 /// println!("{}", qf); 168 /// # Ok(()) 169 /// # } 170 /// ``` 171 impl<'a> fmt::Display for QFunctionByName<'a> { 172 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 173 self.qf_core.fmt(f) 174 } 175 } 176 177 // ----------------------------------------------------------------------------- 178 // Core functionality 179 // ----------------------------------------------------------------------------- 180 impl<'a> QFunctionCore<'a> { 181 // Common implementation 182 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 183 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 184 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 185 u_c[i] = u[i].ptr; 186 } 187 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 188 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 189 v_c[i] = v[i].ptr; 190 } 191 let Q = i32::try_from(Q).unwrap(); 192 let ierr = unsafe { 193 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 194 }; 195 self.ceed.check_error(ierr) 196 } 197 } 198 199 // ----------------------------------------------------------------------------- 200 // User QFunction Closure 201 // 
----------------------------------------------------------------------------- 202 pub type QFunctionUserClosure = dyn FnMut( 203 [&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 204 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 205 ) -> i32; 206 207 macro_rules! mut_max_fields { 208 ($e:expr) => { 209 [ 210 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 211 ] 212 }; 213 } 214 unsafe extern "C" fn trampoline( 215 ctx: *mut ::std::os::raw::c_void, 216 q: bind_ceed::CeedInt, 217 inputs: *const *const bind_ceed::CeedScalar, 218 outputs: *const *mut bind_ceed::CeedScalar, 219 ) -> ::std::os::raw::c_int { 220 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 221 222 // Inputs 223 let inputs_slice: &[*const bind_ceed::CeedScalar] = 224 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 225 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 226 inputs_slice 227 .iter() 228 .enumerate() 229 .map(|(i, &x)| { 230 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 231 as &[crate::Scalar] 232 }) 233 .zip(inputs_array.iter_mut()) 234 .for_each(|(x, a)| *a = x); 235 236 // Outputs 237 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 238 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 239 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 240 mut_max_fields!(&mut [0.0]); 241 outputs_slice 242 .iter() 243 .enumerate() 244 .map(|(i, &x)| { 245 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 246 as &mut [crate::Scalar] 247 }) 248 .zip(outputs_array.iter_mut()) 249 .for_each(|(x, a)| *a = x); 250 251 // User closure 252 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 253 } 254 255 // ----------------------------------------------------------------------------- 256 // QFunction 257 // ----------------------------------------------------------------------------- 258 impl<'a> 
QFunction<'a> { 259 // Constructor 260 pub fn create( 261 ceed: &'a crate::Ceed, 262 vlength: usize, 263 user_f: Box<QFunctionUserClosure>, 264 ) -> crate::Result<Self> { 265 let source_c = CString::new("").expect("CString::new failed"); 266 let mut ptr = std::ptr::null_mut(); 267 268 // Context for closure 269 let number_inputs = 0; 270 let number_outputs = 0; 271 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 272 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 273 let trampoline_data = unsafe { 274 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 275 number_inputs, 276 number_outputs, 277 input_sizes, 278 output_sizes, 279 user_f, 280 })) 281 }; 282 283 // Create QFunction 284 let vlength = i32::try_from(vlength).unwrap(); 285 let mut ierr = unsafe { 286 bind_ceed::CeedQFunctionCreateInterior( 287 ceed.ptr, 288 vlength, 289 Some(trampoline), 290 source_c.as_ptr(), 291 &mut ptr, 292 ) 293 }; 294 ceed.check_error(ierr)?; 295 296 // Set closure 297 let mut qf_ctx_ptr = std::ptr::null_mut(); 298 ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) }; 299 ceed.check_error(ierr)?; 300 ierr = unsafe { 301 bind_ceed::CeedQFunctionContextSetData( 302 qf_ctx_ptr, 303 crate::MemType::Host as bind_ceed::CeedMemType, 304 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 305 std::mem::size_of::<QFunctionTrampolineData>() as u64, 306 std::mem::transmute(trampoline_data.as_ref()), 307 ) 308 }; 309 ceed.check_error(ierr)?; 310 ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) }; 311 ceed.check_error(ierr)?; 312 Ok(Self { 313 qf_core: QFunctionCore { ceed, ptr }, 314 qf_ctx_ptr, 315 trampoline_data, 316 }) 317 } 318 319 /// Apply the action of a QFunction 320 /// 321 /// * `Q` - The number of quadrature points 322 /// * `input` - Array of input Vectors 323 /// * `output` - Array of output Vectors 324 /// 325 /// ``` 326 /// # use libceed::prelude::*; 327 /// # fn main() -> Result<(), libceed::CeedError> { 328 /// # let ceed = 
libceed::Ceed::default_init(); 329 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 330 /// // Iterate over quadrature points 331 /// v.iter_mut() 332 /// .zip(u.iter().zip(weights.iter())) 333 /// .for_each(|(v, (u, w))| *v = u * w); 334 /// 335 /// // Return clean error code 336 /// 0 337 /// }; 338 /// 339 /// let qf = ceed 340 /// .q_function_interior(1, Box::new(user_f))? 341 /// .input("u", 1, EvalMode::Interp)? 342 /// .input("weights", 1, EvalMode::Weight)? 343 /// .output("v", 1, EvalMode::Interp)?; 344 /// 345 /// const Q: usize = 8; 346 /// let mut w = [0.; Q]; 347 /// let mut u = [0.; Q]; 348 /// let mut v = [0.; Q]; 349 /// 350 /// for i in 0..Q { 351 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 352 /// u[i] = 2. + 3. * x + 5. * x * x; 353 /// w[i] = 1. - x * x; 354 /// v[i] = u[i] * w[i]; 355 /// } 356 /// 357 /// let uu = ceed.vector_from_slice(&u)?; 358 /// let ww = ceed.vector_from_slice(&w)?; 359 /// let mut vv = ceed.vector(Q)?; 360 /// vv.set_value(0.0); 361 /// { 362 /// let input = vec![uu, ww]; 363 /// let mut output = vec![vv]; 364 /// qf.apply(Q, &input, &output)?; 365 /// vv = output.remove(0); 366 /// } 367 /// 368 /// vv.view() 369 /// .iter() 370 /// .zip(v.iter()) 371 /// .for_each(|(computed, actual)| { 372 /// assert_eq!( 373 /// *computed, *actual, 374 /// "Incorrect value in QFunction application" 375 /// ); 376 /// }); 377 /// # Ok(()) 378 /// # } 379 /// ``` 380 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 381 self.qf_core.apply(Q, u, v) 382 } 383 384 /// Add a QFunction input 385 /// 386 /// * `fieldname` - Name of QFunction field 387 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 388 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 389 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 390 /// to use interpolated values, `EvalMode::Grad` to use 391 /// gradients, `EvalMode::Weight` to use 
quadrature weights 392 /// 393 /// ``` 394 /// # use libceed::prelude::*; 395 /// # fn main() -> Result<(), libceed::CeedError> { 396 /// # let ceed = libceed::Ceed::default_init(); 397 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 398 /// // Iterate over quadrature points 399 /// v.iter_mut() 400 /// .zip(u.iter().zip(weights.iter())) 401 /// .for_each(|(v, (u, w))| *v = u * w); 402 /// 403 /// // Return clean error code 404 /// 0 405 /// }; 406 /// 407 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 408 /// 409 /// qf = qf.input("u", 1, EvalMode::Interp)?; 410 /// qf = qf.input("weights", 1, EvalMode::Weight)?; 411 /// # Ok(()) 412 /// # } 413 /// ``` 414 pub fn input( 415 mut self, 416 fieldname: &str, 417 size: usize, 418 emode: crate::EvalMode, 419 ) -> crate::Result<Self> { 420 let name_c = CString::new(fieldname).expect("CString::new failed"); 421 let idx = self.trampoline_data.number_inputs; 422 self.trampoline_data.input_sizes[idx] = size; 423 self.trampoline_data.number_inputs += 1; 424 let (size, emode) = ( 425 i32::try_from(size).unwrap(), 426 emode as bind_ceed::CeedEvalMode, 427 ); 428 let ierr = unsafe { 429 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 430 }; 431 self.qf_core.ceed.check_error(ierr)?; 432 Ok(self) 433 } 434 435 /// Add a QFunction output 436 /// 437 /// * `fieldname` - Name of QFunction field 438 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 439 /// `(ncomp * 1)` for `None` and `Interp` 440 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 441 /// to use interpolated values, `EvalMode::Grad` to use 442 /// gradients 443 /// 444 /// ``` 445 /// # use libceed::prelude::*; 446 /// # fn main() -> Result<(), libceed::CeedError> { 447 /// # let ceed = libceed::Ceed::default_init(); 448 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 449 /// // Iterate over 
quadrature points 450 /// v.iter_mut() 451 /// .zip(u.iter().zip(weights.iter())) 452 /// .for_each(|(v, (u, w))| *v = u * w); 453 /// 454 /// // Return clean error code 455 /// 0 456 /// }; 457 /// 458 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f))?; 459 /// 460 /// qf.output("v", 1, EvalMode::Interp)?; 461 /// # Ok(()) 462 /// # } 463 /// ``` 464 pub fn output( 465 mut self, 466 fieldname: &str, 467 size: usize, 468 emode: crate::EvalMode, 469 ) -> crate::Result<Self> { 470 let name_c = CString::new(fieldname).expect("CString::new failed"); 471 let idx = self.trampoline_data.number_outputs; 472 self.trampoline_data.output_sizes[idx] = size; 473 self.trampoline_data.number_outputs += 1; 474 let (size, emode) = ( 475 i32::try_from(size).unwrap(), 476 emode as bind_ceed::CeedEvalMode, 477 ); 478 let ierr = unsafe { 479 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 480 }; 481 self.qf_core.ceed.check_error(ierr)?; 482 Ok(self) 483 } 484 } 485 486 // ----------------------------------------------------------------------------- 487 // QFunction 488 // ----------------------------------------------------------------------------- 489 impl<'a> QFunctionByName<'a> { 490 // Constructor 491 pub fn create(ceed: &'a crate::Ceed, name: &str) -> crate::Result<Self> { 492 let name_c = CString::new(name).expect("CString::new failed"); 493 let mut ptr = std::ptr::null_mut(); 494 let ierr = unsafe { 495 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 496 }; 497 ceed.check_error(ierr)?; 498 Ok(Self { 499 qf_core: QFunctionCore { ceed, ptr }, 500 }) 501 } 502 503 /// Apply the action of a QFunction 504 /// 505 /// * `Q` - The number of quadrature points 506 /// * `input` - Array of input Vectors 507 /// * `output` - Array of output Vectors 508 /// 509 /// ``` 510 /// # use libceed::prelude::*; 511 /// # fn main() -> Result<(), libceed::CeedError> { 512 /// # let ceed = libceed::Ceed::default_init(); 
513 /// const Q: usize = 8; 514 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild")?; 515 /// let qf_mass = ceed.q_function_interior_by_name("MassApply")?; 516 /// 517 /// let mut j = [0.; Q]; 518 /// let mut w = [0.; Q]; 519 /// let mut u = [0.; Q]; 520 /// let mut v = [0.; Q]; 521 /// 522 /// for i in 0..Q { 523 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 524 /// j[i] = 1.; 525 /// w[i] = 1. - x * x; 526 /// u[i] = 2. + 3. * x + 5. * x * x; 527 /// v[i] = w[i] * u[i]; 528 /// } 529 /// 530 /// let jj = ceed.vector_from_slice(&j)?; 531 /// let ww = ceed.vector_from_slice(&w)?; 532 /// let uu = ceed.vector_from_slice(&u)?; 533 /// let mut vv = ceed.vector(Q)?; 534 /// vv.set_value(0.0); 535 /// let mut qdata = ceed.vector(Q)?; 536 /// qdata.set_value(0.0); 537 /// 538 /// { 539 /// let mut input = vec![jj, ww]; 540 /// let mut output = vec![qdata]; 541 /// qf_build.apply(Q, &input, &output)?; 542 /// qdata = output.remove(0); 543 /// } 544 /// 545 /// { 546 /// let mut input = vec![qdata, uu]; 547 /// let mut output = vec![vv]; 548 /// qf_mass.apply(Q, &input, &output)?; 549 /// vv = output.remove(0); 550 /// } 551 /// 552 /// vv.view() 553 /// .iter() 554 /// .zip(v.iter()) 555 /// .for_each(|(computed, actual)| { 556 /// assert_eq!( 557 /// *computed, *actual, 558 /// "Incorrect value in QFunction application" 559 /// ); 560 /// }); 561 /// # Ok(()) 562 /// # } 563 /// ``` 564 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 565 self.qf_core.apply(Q, u, v) 566 } 567 } 568 569 // ----------------------------------------------------------------------------- 570