// Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
// the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
// reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative

//! A Ceed QFunction represents the spatial terms of the point-wise functions
//! describing the physics at the quadrature points.
19 20 use std::pin::Pin; 21 22 use crate::prelude::*; 23 24 pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 25 pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS]; 26 27 // ----------------------------------------------------------------------------- 28 // CeedQFunction option 29 // ----------------------------------------------------------------------------- 30 #[derive(Clone, Copy)] 31 pub enum QFunctionOpt<'a> { 32 SomeQFunction(&'a QFunction<'a>), 33 SomeQFunctionByName(&'a QFunctionByName<'a>), 34 None, 35 } 36 37 /// Construct a QFunctionOpt reference from a QFunction reference 38 impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> { 39 fn from(qfunc: &'a QFunction) -> Self { 40 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 41 Self::SomeQFunction(qfunc) 42 } 43 } 44 45 /// Construct a QFunctionOpt reference from a QFunction by Name reference 46 impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> { 47 fn from(qfunc: &'a QFunctionByName) -> Self { 48 debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE }); 49 Self::SomeQFunctionByName(qfunc) 50 } 51 } 52 53 impl<'a> QFunctionOpt<'a> { 54 /// Transform a Rust libCEED QFunctionOpt into C libCEED CeedQFunction 55 pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction { 56 match self { 57 Self::SomeQFunction(qfunc) => qfunc.qf_core.ptr, 58 Self::SomeQFunctionByName(qfunc) => qfunc.qf_core.ptr, 59 Self::None => unsafe { bind_ceed::CEED_QFUNCTION_NONE }, 60 } 61 } 62 } 63 64 // ----------------------------------------------------------------------------- 65 // CeedQFunction context wrapper 66 // ----------------------------------------------------------------------------- 67 pub(crate) struct QFunctionCore<'a> { 68 ceed: &'a crate::Ceed, 69 ptr: bind_ceed::CeedQFunction, 70 } 71 72 struct QFunctionTrampolineData { 73 number_inputs: usize, 74 number_outputs: usize, 75 input_sizes: [usize; 
MAX_QFUNCTION_FIELDS], 76 output_sizes: [usize; MAX_QFUNCTION_FIELDS], 77 user_f: Box<QFunctionUserClosure>, 78 } 79 80 pub struct QFunction<'a> { 81 qf_core: QFunctionCore<'a>, 82 qf_ctx_ptr: bind_ceed::CeedQFunctionContext, 83 trampoline_data: Pin<Box<QFunctionTrampolineData>>, 84 } 85 86 pub struct QFunctionByName<'a> { 87 qf_core: QFunctionCore<'a>, 88 } 89 90 // ----------------------------------------------------------------------------- 91 // Destructor 92 // ----------------------------------------------------------------------------- 93 impl<'a> Drop for QFunctionCore<'a> { 94 fn drop(&mut self) { 95 unsafe { 96 if self.ptr != bind_ceed::CEED_QFUNCTION_NONE { 97 bind_ceed::CeedQFunctionDestroy(&mut self.ptr); 98 } 99 } 100 } 101 } 102 103 impl<'a> Drop for QFunction<'a> { 104 fn drop(&mut self) { 105 unsafe { 106 bind_ceed::CeedQFunctionContextDestroy(&mut self.qf_ctx_ptr); 107 } 108 } 109 } 110 111 // ----------------------------------------------------------------------------- 112 // Display 113 // ----------------------------------------------------------------------------- 114 impl<'a> fmt::Display for QFunctionCore<'a> { 115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 116 let mut ptr = std::ptr::null_mut(); 117 let mut sizeloc = crate::MAX_BUFFER_LENGTH; 118 let cstring = unsafe { 119 let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc); 120 bind_ceed::CeedQFunctionView(self.ptr, file); 121 bind_ceed::fclose(file); 122 CString::from_raw(ptr) 123 }; 124 cstring.to_string_lossy().fmt(f) 125 } 126 } 127 /// View a QFunction 128 /// 129 /// ``` 130 /// # use libceed::prelude::*; 131 /// # let ceed = libceed::Ceed::default_init(); 132 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 133 /// // Iterate over quadrature points 134 /// v.iter_mut() 135 /// .zip(u.iter().zip(weights.iter())) 136 /// .for_each(|(v, (u, w))| *v = u * w); 137 /// 138 /// // Return clean error code 139 /// 0 140 /// }; 141 /// 
142 /// let qf = ceed 143 /// .q_function_interior(1, Box::new(user_f)) 144 /// .unwrap() 145 /// .input("u", 1, EvalMode::Interp) 146 /// .unwrap() 147 /// .input("weights", 1, EvalMode::Weight) 148 /// .unwrap() 149 /// .output("v", 1, EvalMode::Interp) 150 /// .unwrap(); 151 /// 152 /// println!("{}", qf); 153 /// ``` 154 impl<'a> fmt::Display for QFunction<'a> { 155 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 156 self.qf_core.fmt(f) 157 } 158 } 159 160 /// View a QFunction by Name 161 /// 162 /// ``` 163 /// # use libceed::prelude::*; 164 /// # let ceed = libceed::Ceed::default_init(); 165 /// let qf = ceed.q_function_interior_by_name("Mass1DBuild").unwrap(); 166 /// println!("{}", qf); 167 /// ``` 168 impl<'a> fmt::Display for QFunctionByName<'a> { 169 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 170 self.qf_core.fmt(f) 171 } 172 } 173 174 // ----------------------------------------------------------------------------- 175 // Core functionality 176 // ----------------------------------------------------------------------------- 177 impl<'a> QFunctionCore<'a> { 178 // Common implementation 179 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 180 let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 181 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, u.len()) { 182 u_c[i] = u[i].ptr; 183 } 184 let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS]; 185 for i in 0..std::cmp::min(MAX_QFUNCTION_FIELDS, v.len()) { 186 v_c[i] = v[i].ptr; 187 } 188 let Q = i32::try_from(Q).unwrap(); 189 let ierr = unsafe { 190 bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr()) 191 }; 192 self.ceed.check_error(ierr) 193 } 194 } 195 196 // ----------------------------------------------------------------------------- 197 // User QFunction Closure 198 // ----------------------------------------------------------------------------- 199 pub type QFunctionUserClosure = dyn FnMut( 200 
[&[crate::Scalar]; MAX_QFUNCTION_FIELDS], 201 [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS], 202 ) -> i32; 203 204 macro_rules! mut_max_fields { 205 ($e:expr) => { 206 [ 207 $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, 208 ] 209 }; 210 } 211 unsafe extern "C" fn trampoline( 212 ctx: *mut ::std::os::raw::c_void, 213 q: bind_ceed::CeedInt, 214 inputs: *const *const bind_ceed::CeedScalar, 215 outputs: *const *mut bind_ceed::CeedScalar, 216 ) -> ::std::os::raw::c_int { 217 let trampoline_data: Pin<&mut QFunctionTrampolineData> = std::mem::transmute(ctx); 218 219 // Inputs 220 let inputs_slice: &[*const bind_ceed::CeedScalar] = 221 std::slice::from_raw_parts(inputs, MAX_QFUNCTION_FIELDS); 222 let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] = [&[0.0]; MAX_QFUNCTION_FIELDS]; 223 inputs_slice 224 .iter() 225 .enumerate() 226 .map(|(i, &x)| { 227 std::slice::from_raw_parts(x, trampoline_data.input_sizes[i] * q as usize) 228 as &[crate::Scalar] 229 }) 230 .zip(inputs_array.iter_mut()) 231 .for_each(|(x, a)| *a = x); 232 233 // Outputs 234 let outputs_slice: &[*mut bind_ceed::CeedScalar] = 235 std::slice::from_raw_parts(outputs, MAX_QFUNCTION_FIELDS); 236 let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] = 237 mut_max_fields!(&mut [0.0]); 238 outputs_slice 239 .iter() 240 .enumerate() 241 .map(|(i, &x)| { 242 std::slice::from_raw_parts_mut(x, trampoline_data.output_sizes[i] * q as usize) 243 as &mut [crate::Scalar] 244 }) 245 .zip(outputs_array.iter_mut()) 246 .for_each(|(x, a)| *a = x); 247 248 // User closure 249 (trampoline_data.get_unchecked_mut().user_f)(inputs_array, outputs_array) 250 } 251 252 // ----------------------------------------------------------------------------- 253 // QFunction 254 // ----------------------------------------------------------------------------- 255 impl<'a> QFunction<'a> { 256 // Constructor 257 pub fn create( 258 ceed: &'a crate::Ceed, 259 vlength: usize, 260 user_f: 
Box<QFunctionUserClosure>, 261 ) -> crate::Result<Self> { 262 let source_c = CString::new("").expect("CString::new failed"); 263 let mut ptr = std::ptr::null_mut(); 264 265 // Context for closure 266 let number_inputs = 0; 267 let number_outputs = 0; 268 let input_sizes = [0; MAX_QFUNCTION_FIELDS]; 269 let output_sizes = [0; MAX_QFUNCTION_FIELDS]; 270 let trampoline_data = unsafe { 271 Pin::new_unchecked(Box::new(QFunctionTrampolineData { 272 number_inputs, 273 number_outputs, 274 input_sizes, 275 output_sizes, 276 user_f, 277 })) 278 }; 279 280 // Create QFunction 281 let vlength = i32::try_from(vlength).unwrap(); 282 let mut ierr = unsafe { 283 bind_ceed::CeedQFunctionCreateInterior( 284 ceed.ptr, 285 vlength, 286 Some(trampoline), 287 source_c.as_ptr(), 288 &mut ptr, 289 ) 290 }; 291 ceed.check_error(ierr)?; 292 293 // Set closure 294 let mut qf_ctx_ptr = std::ptr::null_mut(); 295 ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) }; 296 ceed.check_error(ierr)?; 297 ierr = unsafe { 298 bind_ceed::CeedQFunctionContextSetData( 299 qf_ctx_ptr, 300 crate::MemType::Host as bind_ceed::CeedMemType, 301 crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode, 302 std::mem::size_of::<QFunctionTrampolineData>() as u64, 303 std::mem::transmute(trampoline_data.as_ref()), 304 ) 305 }; 306 ceed.check_error(ierr)?; 307 ierr = unsafe { bind_ceed::CeedQFunctionSetContext(ptr, qf_ctx_ptr) }; 308 ceed.check_error(ierr)?; 309 Ok(Self { 310 qf_core: QFunctionCore { ceed, ptr }, 311 qf_ctx_ptr, 312 trampoline_data, 313 }) 314 } 315 316 /// Apply the action of a QFunction 317 /// 318 /// * `Q` - The number of quadrature points 319 /// * `input` - Array of input Vectors 320 /// * `output` - Array of output Vectors 321 /// 322 /// ``` 323 /// # use libceed::prelude::*; 324 /// # let ceed = libceed::Ceed::default_init(); 325 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 326 /// // Iterate over quadrature points 327 /// 
v.iter_mut() 328 /// .zip(u.iter().zip(weights.iter())) 329 /// .for_each(|(v, (u, w))| *v = u * w); 330 /// 331 /// // Return clean error code 332 /// 0 333 /// }; 334 /// 335 /// let qf = ceed 336 /// .q_function_interior(1, Box::new(user_f)) 337 /// .unwrap() 338 /// .input("u", 1, EvalMode::Interp) 339 /// .unwrap() 340 /// .input("weights", 1, EvalMode::Weight) 341 /// .unwrap() 342 /// .output("v", 1, EvalMode::Interp) 343 /// .unwrap(); 344 /// 345 /// const Q: usize = 8; 346 /// let mut w = [0.; Q]; 347 /// let mut u = [0.; Q]; 348 /// let mut v = [0.; Q]; 349 /// 350 /// for i in 0..Q { 351 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 352 /// u[i] = 2. + 3. * x + 5. * x * x; 353 /// w[i] = 1. - x * x; 354 /// v[i] = u[i] * w[i]; 355 /// } 356 /// 357 /// let uu = ceed.vector_from_slice(&u).unwrap(); 358 /// let ww = ceed.vector_from_slice(&w).unwrap(); 359 /// let mut vv = ceed.vector(Q).unwrap(); 360 /// vv.set_value(0.0); 361 /// { 362 /// let input = vec![uu, ww]; 363 /// let mut output = vec![vv]; 364 /// qf.apply(Q, &input, &output).unwrap(); 365 /// vv = output.remove(0); 366 /// } 367 /// 368 /// vv.view() 369 /// .iter() 370 /// .zip(v.iter()) 371 /// .for_each(|(computed, actual)| { 372 /// assert_eq!( 373 /// *computed, *actual, 374 /// "Incorrect value in QFunction application" 375 /// ); 376 /// }); 377 /// ``` 378 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 379 self.qf_core.apply(Q, u, v) 380 } 381 382 /// Add a QFunction input 383 /// 384 /// * `fieldname` - Name of QFunction field 385 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 386 /// `(ncomp * 1)` for `None`, `Interp`, and `Weight` 387 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 388 /// to use interpolated values, `EvalMode::Grad` to use 389 /// gradients, `EvalMode::Weight` to use quadrature weights 390 /// 391 /// ``` 392 /// # use libceed::prelude::*; 393 /// # let ceed = 
libceed::Ceed::default_init(); 394 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 395 /// // Iterate over quadrature points 396 /// v.iter_mut() 397 /// .zip(u.iter().zip(weights.iter())) 398 /// .for_each(|(v, (u, w))| *v = u * w); 399 /// 400 /// // Return clean error code 401 /// 0 402 /// }; 403 /// 404 /// let mut qf = ceed.q_function_interior(1, Box::new(user_f)).unwrap(); 405 /// 406 /// qf = qf.input("u", 1, EvalMode::Interp).unwrap(); 407 /// qf = qf.input("weights", 1, EvalMode::Weight).unwrap(); 408 /// ``` 409 pub fn input( 410 mut self, 411 fieldname: &str, 412 size: usize, 413 emode: crate::EvalMode, 414 ) -> crate::Result<Self> { 415 let name_c = CString::new(fieldname).expect("CString::new failed"); 416 let idx = self.trampoline_data.number_inputs; 417 self.trampoline_data.input_sizes[idx] = size; 418 self.trampoline_data.number_inputs += 1; 419 let (size, emode) = ( 420 i32::try_from(size).unwrap(), 421 emode as bind_ceed::CeedEvalMode, 422 ); 423 let ierr = unsafe { 424 bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 425 }; 426 self.qf_core.ceed.check_error(ierr)?; 427 Ok(self) 428 } 429 430 /// Add a QFunction output 431 /// 432 /// * `fieldname` - Name of QFunction field 433 /// * `size` - Size of QFunction field, `(ncomp * dim)` for `Grad` or 434 /// `(ncomp * 1)` for `None` and `Interp` 435 /// * `emode` - `EvalMode::None` to use values directly, `EvalMode::Interp` 436 /// to use interpolated values, `EvalMode::Grad` to use 437 /// gradients 438 /// 439 /// ``` 440 /// # use libceed::prelude::*; 441 /// # let ceed = libceed::Ceed::default_init(); 442 /// let mut user_f = |[u, weights, ..]: QFunctionInputs, [v, ..]: QFunctionOutputs| { 443 /// // Iterate over quadrature points 444 /// v.iter_mut() 445 /// .zip(u.iter().zip(weights.iter())) 446 /// .for_each(|(v, (u, w))| *v = u * w); 447 /// 448 /// // Return clean error code 449 /// 0 450 /// }; 451 /// 452 /// let mut qf 
= ceed.q_function_interior(1, Box::new(user_f)).unwrap(); 453 /// 454 /// qf.output("v", 1, EvalMode::Interp).unwrap(); 455 /// ``` 456 pub fn output( 457 mut self, 458 fieldname: &str, 459 size: usize, 460 emode: crate::EvalMode, 461 ) -> crate::Result<Self> { 462 let name_c = CString::new(fieldname).expect("CString::new failed"); 463 let idx = self.trampoline_data.number_outputs; 464 self.trampoline_data.output_sizes[idx] = size; 465 self.trampoline_data.number_outputs += 1; 466 let (size, emode) = ( 467 i32::try_from(size).unwrap(), 468 emode as bind_ceed::CeedEvalMode, 469 ); 470 let ierr = unsafe { 471 bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode) 472 }; 473 self.qf_core.ceed.check_error(ierr)?; 474 Ok(self) 475 } 476 } 477 478 // ----------------------------------------------------------------------------- 479 // QFunction 480 // ----------------------------------------------------------------------------- 481 impl<'a> QFunctionByName<'a> { 482 // Constructor 483 pub fn create(ceed: &'a crate::Ceed, name: &str) -> crate::Result<Self> { 484 let name_c = CString::new(name).expect("CString::new failed"); 485 let mut ptr = std::ptr::null_mut(); 486 let ierr = unsafe { 487 bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr) 488 }; 489 ceed.check_error(ierr)?; 490 Ok(Self { 491 qf_core: QFunctionCore { ceed, ptr }, 492 }) 493 } 494 495 /// Apply the action of a QFunction 496 /// 497 /// * `Q` - The number of quadrature points 498 /// * `input` - Array of input Vectors 499 /// * `output` - Array of output Vectors 500 /// 501 /// ``` 502 /// # use libceed::prelude::*; 503 /// # let ceed = libceed::Ceed::default_init(); 504 /// const Q: usize = 8; 505 /// let qf_build = ceed.q_function_interior_by_name("Mass1DBuild").unwrap(); 506 /// let qf_mass = ceed.q_function_interior_by_name("MassApply").unwrap(); 507 /// 508 /// let mut j = [0.; Q]; 509 /// let mut w = [0.; Q]; 510 /// let mut u = [0.; Q]; 511 
/// let mut v = [0.; Q]; 512 /// 513 /// for i in 0..Q { 514 /// let x = 2. * (i as Scalar) / ((Q as Scalar) - 1.) - 1.; 515 /// j[i] = 1.; 516 /// w[i] = 1. - x * x; 517 /// u[i] = 2. + 3. * x + 5. * x * x; 518 /// v[i] = w[i] * u[i]; 519 /// } 520 /// 521 /// let jj = ceed.vector_from_slice(&j).unwrap(); 522 /// let ww = ceed.vector_from_slice(&w).unwrap(); 523 /// let uu = ceed.vector_from_slice(&u).unwrap(); 524 /// let mut vv = ceed.vector(Q).unwrap(); 525 /// vv.set_value(0.0); 526 /// let mut qdata = ceed.vector(Q).unwrap(); 527 /// qdata.set_value(0.0); 528 /// 529 /// { 530 /// let mut input = vec![jj, ww]; 531 /// let mut output = vec![qdata]; 532 /// qf_build.apply(Q, &input, &output).unwrap(); 533 /// qdata = output.remove(0); 534 /// } 535 /// 536 /// { 537 /// let mut input = vec![qdata, uu]; 538 /// let mut output = vec![vv]; 539 /// qf_mass.apply(Q, &input, &output).unwrap(); 540 /// vv = output.remove(0); 541 /// } 542 /// 543 /// vv.view() 544 /// .iter() 545 /// .zip(v.iter()) 546 /// .for_each(|(computed, actual)| { 547 /// assert_eq!( 548 /// *computed, *actual, 549 /// "Incorrect value in QFunction application" 550 /// ); 551 /// }); 552 /// ``` 553 pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> { 554 self.qf_core.apply(Q, u, v) 555 } 556 } 557 558 // ----------------------------------------------------------------------------- 559