xref: /libCEED/examples/rust/mesh/src/lib.rs (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 use libceed::{Ceed, ElemRestriction, Vector};
9 use std::convert::TryInto;
10 
11 // ----------------------------------------------------------------------------
12 // Determine problem size in each dimension from size and dimenison
13 // ----------------------------------------------------------------------------
/// Choose the number of elements in each dimension of a Cartesian mesh for a
/// target problem size.
///
/// Uses the approximation `problem_size ~ num_elem * solution_degree^dim`.
/// It rounds `num_elem` down to a power of two `2^s` and spreads the `s`
/// factors of two as evenly as possible over the first `dim` dimensions.
/// Earlier dimensions receive any extra factor. Entries beyond `dim` are 1,
/// so the product of all three entries equals the total element count.
///
/// # Panics
///
/// Panics if `dim` is not in `1..=3` or if `solution_degree` is 0.
pub fn cartesian_mesh_size(dim: usize, solution_degree: usize, problem_size: i64) -> [usize; 3] {
    assert!(
        (1..=3).contains(&dim),
        "mesh dimension must be 1, 2, or 3, got {}",
        dim
    );
    assert!(solution_degree > 0, "solution degree must be at least 1");

    // Use the approximate formula:
    //    problem_size ~ num_elem * solution_degree^dim
    let degree_pow_dim: i64 = solution_degree
        .pow(dim as u32) // dim <= 3, so the cast cannot truncate
        .try_into()
        .expect("solution_degree^dim does not fit in i64");
    let mut num_elem = problem_size / degree_pow_dim;

    // Find s with num_elem / 2 < 2^s <= num_elem (s = 0 when num_elem <= 1).
    let mut s = 0;
    while num_elem > 1 {
        num_elem /= 2;
        s += 1;
    }

    // Each active dimension gets s / dim factors of two, and the first s % dim
    // dimensions get one more. Inactive dimensions keep a single element.
    let mut xyz = [1; 3];
    for (d, n) in xyz.iter_mut().enumerate().take(dim) {
        *n = 1 << (s / dim + usize::from(d < s % dim));
    }
    xyz
}
40 
41 // ----------------------------------------------------------------------------
42 // Build element restriction objects for the mesh
43 // ----------------------------------------------------------------------------
build_cartesian_restriction( ceed: &Ceed, dim: usize, num_xyz: [usize; 3], degree: usize, num_comp: usize, num_qpts: usize, ) -> libceed::Result<(ElemRestriction, ElemRestriction)>44 pub fn build_cartesian_restriction(
45     ceed: &Ceed,
46     dim: usize,
47     num_xyz: [usize; 3],
48     degree: usize,
49     num_comp: usize,
50     num_qpts: usize,
51 ) -> libceed::Result<(ElemRestriction, ElemRestriction)> {
52     let p = degree + 1;
53     let num_nodes = p.pow(dim as u32); // number of nodes per element
54     let elem_qpts = num_qpts.pow(dim as u32); // number of quadrature pts per element
55 
56     // Problem dimensions
57     let mut num_d = [0; 3];
58     let mut num_elem = 1;
59     let mut scalar_size = 1;
60     for d in 0..dim {
61         num_elem *= num_xyz[d];
62         num_d[d] = num_xyz[d] * (p - 1) + 1;
63         scalar_size *= num_d[d];
64     }
65 
66     // elem:         0             1                 n-1
67     //        |---*-...-*---|---*-...-*---|- ... -|--...--|
68     // nodes: 0   1    p-1  p  p+1       2*p             n*p
69     let mut elem_nodes = vec![0; num_elem * num_nodes];
70     for e in 0..num_elem {
71         let mut e_xyz = [1; 3];
72         let mut re = e;
73         for d in 0..dim {
74             e_xyz[d] = re % num_xyz[d];
75             re /= num_xyz[d];
76         }
77         let loc_offset = e * num_nodes;
78         for loc_nodes in 0..num_nodes {
79             let mut global_nodes = 0;
80             let mut global_nodes_stride = 1;
81             let mut r_nodes = loc_nodes;
82             for d in 0..dim {
83                 global_nodes += (e_xyz[d] * (p - 1) + r_nodes % p) * global_nodes_stride;
84                 global_nodes_stride *= num_d[d];
85                 r_nodes /= p;
86             }
87             elem_nodes[loc_offset + loc_nodes] = global_nodes as i32;
88         }
89     }
90 
91     // Mesh/solution data restriction
92     let rstr = ceed.elem_restriction(
93         num_elem,
94         num_nodes,
95         num_comp,
96         scalar_size,
97         num_comp * scalar_size,
98         libceed::MemType::Host,
99         &elem_nodes,
100     )?;
101 
102     // Quadrature data restriction
103     let rstr_qdata = ceed.strided_elem_restriction(
104         num_elem,
105         elem_qpts,
106         num_comp,
107         num_comp * elem_qpts * num_elem,
108         libceed::CEED_STRIDES_BACKEND,
109     )?;
110     Ok((rstr, rstr_qdata))
111 }
112 
113 // ----------------------------------------------------------------------------
114 // Set mesh coordinates
115 // ----------------------------------------------------------------------------
cartesian_mesh_coords( ceed: &Ceed, dim: usize, num_xyz: [usize; 3], mesh_degree: usize, mesh_size: usize, ) -> libceed::Result<Vector>116 pub fn cartesian_mesh_coords(
117     ceed: &Ceed,
118     dim: usize,
119     num_xyz: [usize; 3],
120     mesh_degree: usize,
121     mesh_size: usize,
122 ) -> libceed::Result<Vector> {
123     let p = mesh_degree + 1;
124     let mut num_d = [0; 3];
125     let mut scalar_size = 1;
126     for d in 0..dim {
127         num_d[d] = num_xyz[d] * (p - 1) + 1;
128         scalar_size *= num_d[d];
129     }
130 
131     // Lobatto points
132     let lobatto_basis =
133         ceed.basis_tensor_H1_Lagrange(1, 1, 2, p, libceed::QuadMode::GaussLobatto)?;
134     let nodes_corners = ceed.vector_from_slice(&[0.0, 1.0])?;
135     let mut nodes_full = ceed.vector(p)?;
136     lobatto_basis.apply(
137         1,
138         libceed::TransposeMode::NoTranspose,
139         libceed::EvalMode::Interp,
140         &nodes_corners,
141         &mut nodes_full,
142     )?;
143 
144     // Coordinates for mesh
145     let mut mesh_coords = ceed.vector(mesh_size)?;
146     mesh_coords.set_value(0.0)?;
147     {
148         let mut coords = mesh_coords.view_mut()?;
149         let nodes = nodes_full.view()?;
150         for gs_nodes in 0..scalar_size {
151             let mut r_nodes = gs_nodes;
152             for d in 0..dim {
153                 let d_1d = r_nodes % num_d[d];
154                 coords[gs_nodes + scalar_size * d] = ((d_1d / (p - 1)) as libceed::Scalar
155                     + nodes[d_1d % (p - 1)])
156                     / num_xyz[d] as libceed::Scalar;
157                 r_nodes /= num_d[d];
158             }
159         }
160     }
161     Ok(mesh_coords)
162 }
163 
164 // ----------------------------------------------------------------------------
165