xref: /libCEED/julia/LibCEED.jl/src/Basis.jl (revision 3a739e1923fc978b8248ffcb40b5d2f55443c9d9)
"""
    AbstractBasis

Supertype of all basis objects in this module: [`Basis`](@ref) and
[`BasisCollocated`](@ref).
"""
abstract type AbstractBasis end

"""
    BasisCollocated()

Returns the singleton object corresponding to libCEED's `CEED_BASIS_COLLOCATED`.
"""
struct BasisCollocated <: AbstractBasis end

# Indexing with `[]` yields the raw libCEED handle, mirroring `Basis`.
function Base.getindex(::BasisCollocated)
    return C.CEED_BASIS_COLLOCATED[]
end
10
"""
    Basis

Wraps a `CeedBasis` object, representing a finite element basis. A `Basis` object can be
created using one of:

- [`create_tensor_h1_lagrange_basis`](@ref)
- [`create_tensor_h1_basis`](@ref)
- [`create_h1_basis`](@ref)

A finalizer calls `destroy` on the wrapper, releasing the `CeedBasis` through
`CeedBasisDestroy`.
"""
mutable struct Basis <: AbstractBasis
    ref::RefValue{C.CeedBasis}
    function Basis(ref)
        basis = new(ref)
        # Tie the lifetime of the libCEED handle to the Julia wrapper.
        finalizer(destroy, basis)
        return basis
    end
end
destroy(b::Basis) = C.CeedBasisDestroy(b.ref) # COV_EXCL_LINE

# `b[]` unwraps to the raw `CeedBasis` handle for passing to the C API.
function Base.getindex(b::Basis)
    return b.ref[]
end

function Base.show(io::IO, ::MIME"text/plain", b::Basis)
    return ceed_show(io, b, C.CeedBasisView)
end
35
@doc raw"""
    create_tensor_h1_lagrange_basis(c, dim, ncomp, p, q, quad_mode)

Create a tensor-product Lagrange basis (wraps `CeedBasisCreateTensorH1Lagrange`).

# Arguments:
- `c`:         A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `dim`:       Topological dimension of element.
- `ncomp`:     Number of field components (1 for scalar fields).
- `p`:         Number of Gauss-Lobatto nodes in one dimension.  The polynomial degree of
               the resulting $Q_k$ element is $k=p-1$.
- `q`:         Number of quadrature points in one dimension.
- `quad_mode`: `QuadMode` giving the distribution of the $q$ quadrature points (affects
               order of accuracy for the quadrature).
"""
function create_tensor_h1_lagrange_basis(c::Ceed, dim, ncomp, p, q, quad_mode::QuadMode)
    basis_ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateTensorH1Lagrange(c[], dim, ncomp, p, q, quad_mode, basis_ref)
    return Basis(basis_ref)
end
56
@doc raw"""
    create_tensor_h1_basis(c::Ceed, dim, ncomp, p, q, interp1d, grad1d, qref1d, qweight1d)

Create a tensor-product basis for $H^1$ discretizations.

The 1D matrices are given in Julia's natural (column-major) layout; they are converted to
the row-major layout expected by libCEED before being passed to `CeedBasisCreateTensorH1`.

# Arguments:
- `c`:         A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `dim`:       Topological dimension.
- `ncomp`:     Number of field components (1 for scalar fields).
- `p`:         Number of nodes in one dimension.
- `q`:         Number of quadrature points in one dimension.
- `interp1d`:  Matrix of size `(q, p)` expressing the values of nodal basis functions at
               quadrature points.
- `grad1d`:    Matrix of size `(q, p)` expressing derivatives of nodal basis functions at
               quadrature points.
- `qref1d`:    Array of length `q` holding the locations of quadrature points on the 1D
               reference element $[-1, 1]$.
- `qweight1d`: Array of length `q` holding the quadrature weights on the reference element.

The sizes listed above are checked with `@assert`; a mismatch raises an `AssertionError`
whose message reports the actual and expected sizes.
"""
function create_tensor_h1_basis(
    c::Ceed,
    dim,
    ncomp,
    p,
    q,
    interp1d,
    grad1d,
    qref1d,
    qweight1d,
)
    @assert(
        size(interp1d) == (q, p),
        "interp1d has size $(size(interp1d)), expected $((q, p))"
    )
    @assert(size(grad1d) == (q, p), "grad1d has size $(size(grad1d)), expected $((q, p))")
    @assert(length(qref1d) == q, "qref1d has length $(length(qref1d)), expected $q")
    @assert(
        length(qweight1d) == q,
        "qweight1d has length $(length(qweight1d)), expected $q"
    )

    # Materializing the transpose of a column-major Julia matrix gives exactly the memory
    # layout of the row-major matrix libCEED expects. `transpose` (not `'`) avoids
    # conjugation.
    interp1d_rowmajor = collect(transpose(interp1d))
    grad1d_rowmajor = collect(transpose(grad1d))

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateTensorH1(
        c[],
        dim,
        ncomp,
        p,
        q,
        interp1d_rowmajor,
        grad1d_rowmajor,
        qref1d,
        qweight1d,
        ref,
    )
    return Basis(ref)
end
111
@doc raw"""
    create_h1_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, grad, qref, qweight)

Create a non tensor-product basis for $H^1$ discretizations.

Matrices and tensors are given in Julia's natural (column-major) layout; they are converted
to the row-major layout expected by libCEED before being passed to `CeedBasisCreateH1`. The
dimension `dim` used below is obtained from `topo` via `getdimension`.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Matrix of size `(nqpts, nnodes)` expressing the values of nodal basis functions
             at quadrature points.
- `grad`:    Array of size `(dim, nqpts, nnodes)` expressing derivatives of nodal basis
             functions at quadrature points.
- `qref`:    Array of length `nqpts` holding the locations of quadrature points on the
             reference element $[-1, 1]$.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

The sizes listed above are checked with `@assert`; a mismatch raises an `AssertionError`
whose message reports the actual and expected sizes.
"""
function create_h1_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp,
    grad,
    qref,
    qweight,
)
    dim = getdimension(topo)
    @assert(
        size(interp) == (nqpts, nnodes),
        "interp has size $(size(interp)), expected $((nqpts, nnodes))"
    )
    @assert(
        size(grad) == (dim, nqpts, nnodes),
        "grad has size $(size(grad)), expected $((dim, nqpts, nnodes))"
    )
    # NOTE(review): for dim > 1 each quadrature point has `dim` coordinates; confirm
    # whether libCEED expects `nqpts` or `dim * nqpts` reference coordinates here.
    @assert(length(qref) == nqpts, "qref has length $(length(qref)), expected $nqpts")
    @assert(
        length(qweight) == nqpts,
        "qweight has length $(length(qweight)), expected $nqpts"
    )

    # Transposing a column-major matrix (and reversing all axes of a column-major tensor)
    # yields the memory layout of the corresponding row-major array.
    interp_rowmajor = collect(transpose(interp))
    grad_rowmajor = permutedims(grad, (3, 2, 1))

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateH1(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        grad_rowmajor,
        qref,
        qweight,
        ref,
    )
    return Basis(ref)
end
168
"""
    apply!(b::Basis, nelem, tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)

Apply basis evaluation from nodes to quadrature points or vice versa, storing the result in
the [`CeedVector`](@ref) `v`. This is a thin wrapper around `CeedBasisApply`.

`nelem` is the number of elements to apply the basis evaluation to; the element ordering is
the one described for `CeedElemRestrictionCreateBlocked()`.

The transpose mode `tmode` selects the direction:
- `CEED_NOTRANSPOSE` evaluates from nodes to quadrature points,
- `CEED_TRANSPOSE` applies the transpose, mapping from quadrature points to nodes.

The [`EvalMode`](@ref) `emode` selects what is evaluated:
- `CEED_EVAL_NONE` to use values directly,
- `CEED_EVAL_INTERP` to use interpolated values,
- `CEED_EVAL_GRAD` to use gradients,
- `CEED_EVAL_WEIGHT` to use quadrature weights.
"""
function apply!(
    b::Basis,
    nelem,
    tmode::TransposeMode,
    emode::EvalMode,
    u::AbstractCeedVector,
    v::AbstractCeedVector,
)
    # Every Julia wrapper is unwrapped to its raw libCEED handle for the C call.
    return C.CeedBasisApply(b[], nelem, tmode, emode, u[], v[])
end
197
"""
    apply(b::Basis, u::AbstractVector; nelem=1, tmode=NOTRANSPOSE, emode=EVAL_INTERP)

Performs the same function as the above-defined [`apply!`](@ref apply!(b::Basis, nelem,
tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)), but
automatically converts from Julia arrays to [`CeedVector`](@ref) for convenience.

The result is returned in a newly allocated array whose length is determined from the
basis (`ncomp = getnumcomponents(b)`, `nnodes = getnumnodes(b)`, `nqpts = getnumqpts(b)`,
`dim = getdimension(b)`):
- `emode == EVAL_WEIGHT`: `nelem * nqpts`
- `tmode == TRANSPOSE`: `nelem * ncomp * nnodes`
- `emode == EVAL_GRAD`: `nelem * ncomp * nqpts * dim`
- otherwise: `nelem * ncomp * nqpts`
"""
function apply(b::Basis, u::AbstractVector; nelem=1, tmode=NOTRANSPOSE, emode=EVAL_INTERP)
    ceed_ref = Ref{C.Ceed}()
    ccall((:CeedBasisGetCeed, C.libceed), Cint, (C.CeedBasis, Ptr{C.Ceed}), b[], ceed_ref)
    c = Ceed(ceed_ref)

    u_vec = CeedVector(c, u)

    # The output must hold every element and every component; previously only a single
    # element/component was allocated, which is too short whenever nelem > 1 or ncomp > 1.
    len_v = if emode == EVAL_WEIGHT
        # libCEED's weight output has one entry per quadrature point per element.
        nelem * getnumqpts(b)
    elseif tmode == TRANSPOSE
        # The transpose maps back to nodal values, so no extra `dim` factor even for
        # gradients (the `dim` factor belongs to the input in that case).
        nelem * getnumcomponents(b) * getnumnodes(b)
    elseif emode == EVAL_GRAD
        nelem * getnumcomponents(b) * getnumqpts(b) * getdimension(b)
    else
        nelem * getnumcomponents(b) * getnumqpts(b)
    end

    v_vec = CeedVector(c, len_v)

    apply!(b, nelem, tmode, emode, u_vec, v_vec)
    return Vector(v_vec)
end
224
"""
    getdimension(b::Basis)

Return the spatial dimension of the given [`Basis`](@ref), as reported by
`CeedBasisGetDimension`.
"""
function getdimension(b::Basis)
    out = Ref{CeedInt}()
    C.CeedBasisGetDimension(b[], out)
    return out[]
end
235
"""
    getdimension(t::Topology)

Return the spatial dimension of the given [`Topology`](@ref). The dimension is encoded in
the high bits (bit 16 and above) of the topology's integer value.
"""
getdimension(t::Topology) = Int(t) >> 16
244
"""
    gettopology(b::Basis)

Return the [`Topology`](@ref) of the given [`Basis`](@ref), as reported by
`CeedBasisGetTopology`.
"""
function gettopology(b::Basis)
    out = Ref{Topology}()
    C.CeedBasisGetTopology(b[], out)
    return out[]
end
255
"""
    getnumcomponents(b::Basis)

Return the number of field components of the given [`Basis`](@ref), as reported by
`CeedBasisGetNumComponents`.
"""
function getnumcomponents(b::Basis)
    out = Ref{CeedInt}()
    C.CeedBasisGetNumComponents(b[], out)
    return out[]
end
266
"""
    getnumnodes(b::Basis)

Return the total number of nodes of the given [`Basis`](@ref), as reported by
`CeedBasisGetNumNodes`.
"""
function getnumnodes(b::Basis)
    out = Ref{CeedInt}()
    C.CeedBasisGetNumNodes(b[], out)
    return out[]
end
277
"""
    getnumnodes1d(b::Basis)

Return the number of 1D nodes of the given (tensor-product) [`Basis`](@ref), as reported
by `CeedBasisGetNumNodes1D`.
"""
function getnumnodes1d(b::Basis)
    nnodes1d = Ref{CeedInt}()
    C.CeedBasisGetNumNodes1D(b[], nnodes1d)
    return nnodes1d[]
end
288
"""
    getnumqpts(b::Basis)

Return the total number of quadrature points of the given [`Basis`](@ref), as reported by
`CeedBasisGetNumQuadraturePoints`.
"""
function getnumqpts(b::Basis)
    out = Ref{CeedInt}()
    C.CeedBasisGetNumQuadraturePoints(b[], out)
    return out[]
end
299
"""
    getnumqpts1d(b::Basis)

Return the number of 1D quadrature points of the given (tensor-product) [`Basis`](@ref), as
reported by `CeedBasisGetNumQuadraturePoints1D`.
"""
function getnumqpts1d(b::Basis)
    out = Ref{CeedInt}()
    C.CeedBasisGetNumQuadraturePoints1D(b[], out)
    return out[]
end
310
"""
    getqref(b::Basis)

Get the reference coordinates of the quadrature points of the given [`Basis`](@ref). The
result is a Julia-owned copy containing `getnumqpts(b)` values.
"""
function getqref(b::Basis)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetQRef(b[], ptr)
    # NOTE(review): for tensor-product bases libCEED may store only the 1D points
    # (`getnumqpts1d(b)` values); confirm that reading `getnumqpts(b)` entries is in bounds.
    npts = getnumqpts(b)
    # Copy so the returned array does not alias memory owned by the CeedBasis.
    return copy(unsafe_wrap(Array, ptr[], npts))
end
322
"""
    getqweights(b::Basis)

Get the quadrature weights of the quadrature points of the given [`Basis`](@ref). The
result is a Julia-owned copy containing `getnumqpts(b)` values.
"""
function getqweights(b::Basis)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetQWeights(b[], ref)
    # NOTE(review): for tensor-product bases libCEED may store only the 1D weights
    # (`getnumqpts1d(b)` values); confirm that reading `getnumqpts(b)` entries is in bounds.
    # Copy so the returned array does not alias memory owned by the CeedBasis.
    return copy(unsafe_wrap(Array, ref[], getnumqpts(b)))
end
334
"""
    getinterp(b::Basis)

Get the interpolation matrix of the given [`Basis`](@ref). Returns a newly allocated matrix
of size `(getnumqpts(b), getnumnodes(b))`.
"""
function getinterp(b::Basis)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetInterp(b[], ptr)
    nqpts, nnodes = getnumqpts(b), getnumnodes(b)
    # libCEED's row-major (nqpts, nnodes) data reads as a column-major (nnodes, nqpts)
    # array; collecting its transpose gives a Julia-owned (nqpts, nnodes) matrix.
    rowmajor = unsafe_wrap(Array, ptr[], (nnodes, nqpts))
    return collect(transpose(rowmajor))
end
348
"""
    getinterp1d(b::Basis)

Get the 1D interpolation matrix of the given [`Basis`](@ref). `b` must be a tensor-product
basis, otherwise this function will fail. Returns a newly allocated matrix of size
`(getnumqpts1d(b), getnumnodes1d(b))`.
"""
function getinterp1d(b::Basis)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetInterp1D(b[], ptr)
    nqpts1d, nnodes1d = getnumqpts1d(b), getnumnodes1d(b)
    # Row-major (q, p) data viewed column-major is (p, q); transpose back and copy.
    rowmajor = unsafe_wrap(Array, ptr[], (nnodes1d, nqpts1d))
    return collect(transpose(rowmajor))
end
363
"""
    getgrad(b::Basis)

Get the gradient matrix of the given [`Basis`](@ref). Returns a newly allocated tensor of
size `(getdimension(b), getnumqpts(b), getnumnodes(b))`.
"""
function getgrad(b::Basis)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetGrad(b[], ref)
    dim = getdimension(b)
    q = getnumqpts(b)
    p = getnumnodes(b)
    # Row-major (dim, q, p) data viewed column-major is (p, q, dim); reversing the axes
    # restores the documented layout and copies out of libCEED-owned memory.
    return permutedims(unsafe_wrap(Array, ref[], (p, q, dim)), (3, 2, 1))
end
378
"""
    getgrad1d(b::Basis)

Get the 1D derivative matrix of the given [`Basis`](@ref). `b` must be a tensor-product
basis. Returns a newly allocated matrix of size `(getnumqpts1d(b), getnumnodes1d(b))`.
"""
function getgrad1d(b::Basis)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetGrad1D(b[], ref)
    q = getnumqpts1d(b)
    p = getnumnodes1d(b)
    # Row-major (q, p) data viewed column-major is (p, q); transpose back and copy.
    return collect(transpose(unsafe_wrap(Array, ref[], (p, q))))
end
392