xref: /libCEED/julia/LibCEED.jl/src/Basis.jl (revision 11b88dda510d0aa70e79dc59ad165e2a5539c3c3)
# Common supertype of the basis wrappers defined in this file (`BasisCollocated`, `Basis`).
abstract type AbstractBasis end
2
"""
    BasisCollocated()

Singleton object standing in for libCEED's `CEED_BASIS_COLLOCATED`. Dereferencing it with
`[]` yields the value stored in `C.CEED_BASIS_COLLOCATED`.
"""
struct BasisCollocated <: AbstractBasis end
function Base.getindex(::BasisCollocated)
    return C.CEED_BASIS_COLLOCATED[]
end
10
"""
    Basis

A finite element basis, held as a reference to a libCEED `CeedBasis` object. Instances are
obtained from one of the following constructors:

- [`create_tensor_h1_lagrange_basis`](@ref)
- [`create_tensor_h1_basis`](@ref)
- [`create_h1_basis`](@ref)
- [`create_hdiv_basis`](@ref)
- [`create_hcurl_basis`](@ref)
"""
mutable struct Basis <: AbstractBasis
    ref::RefValue{C.CeedBasis}
    function Basis(ref)
        basis = new(ref)
        # Hand the wrapped `CeedBasis` back to libCEED once the wrapper is collected.
        # Debugging aid (enable inside a finalizer to trace collection):
        # ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring), "Finalizing %s.\n", repr(x))
        finalizer(destroy, basis)
        return basis
    end
end
destroy(b::Basis) = C.CeedBasisDestroy(b.ref) # COV_EXCL_LINE
function Base.getindex(b::Basis)
    return b.ref[]
end
function Base.show(io::IO, ::MIME"text/plain", b::Basis)
    return ceed_show(io, b, C.CeedBasisView)
end
37
@doc raw"""
    create_tensor_h1_lagrange_basis(c, dim, ncomp, p, q, quad_mode)

Create a tensor-product Lagrange basis and return it wrapped in a [`Basis`](@ref).

# Arguments:
- `c`:         The [`Ceed`](@ref) object on which the [`Basis`](@ref) is created.
- `dim`:       Topological dimension of the element.
- `ncomp`:     Number of field components (1 for scalar fields).
- `p`:         Number of Gauss-Lobatto nodes in one dimension. The resulting $Q_k$ element
               has polynomial degree $k=p-1$.
- `q`:         Number of quadrature points in one dimension.
- `quad_mode`: Distribution of the $q$ quadrature points (affects the order of accuracy of
               the quadrature).
"""
function create_tensor_h1_lagrange_basis(c::Ceed, dim, ncomp, p, q, quad_mode::QuadMode)
    basis_ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateTensorH1Lagrange(c[], dim, ncomp, p, q, quad_mode, basis_ref)
    return Basis(basis_ref)
end
58
@doc raw"""
    create_tensor_h1_basis(c::Ceed, dim, ncomp, p, q, interp1d, grad1d, qref1d, qweight1d)

Create a tensor-product basis for $H^1$ discretizations.

# Arguments:
- `c`:         A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `dim`:       Topological dimension.
- `ncomp`:     Number of field components (1 for scalar fields).
- `p`:         Number of nodes in one dimension.
- `q`:         Number of quadrature points in one dimension.
- `interp1d`:  Matrix of size `(q, p)` expressing the values of nodal basis functions at
               quadrature points.
- `grad1d`:    Matrix of size `(q, p)` expressing derivatives of nodal basis functions at
               quadrature points.
- `qref1d`:    Array of length `q` holding the locations of quadrature points on the 1D
               reference element $[-1, 1]$.
- `qweight1d`: Array of length `q` holding the quadrature weights on the reference element.

An `AssertionError` is raised when the array sizes do not match `p` and `q`.
"""
function create_tensor_h1_basis(
    c::Ceed,
    dim,
    ncomp,
    p,
    q,
    interp1d::AbstractArray{CeedScalar},
    grad1d::AbstractArray{CeedScalar},
    qref1d::AbstractArray{CeedScalar},
    qweight1d::AbstractArray{CeedScalar},
)
    # Both 1D operators map `p` nodal values to `q` quadrature values.
    @assert size(interp1d) == (q, p)
    @assert size(grad1d) == (q, p)
    @assert length(qref1d) == q
    @assert length(qweight1d) == q

    # Convert from Julia matrices (column-major) to row-major format
    interp1d_rowmajor = collect(interp1d')
    grad1d_rowmajor = collect(grad1d')

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateTensorH1(
        c[],
        dim,
        ncomp,
        p,
        q,
        interp1d_rowmajor,
        grad1d_rowmajor,
        qref1d,
        qweight1d,
        ref,
    )
    return Basis(ref)
end
113
@doc raw"""
    create_h1_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, grad, qref, qweight)

Create a non tensor-product basis for $H^1$ discretizations.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Matrix of size `(nqpts, nnodes)` expressing the values of nodal basis functions
             at quadrature points.
- `grad`:    Array of size `(dim, nqpts, nnodes)` expressing derivatives of nodal basis
             functions at quadrature points.
- `qref`:    Array of length `nqpts*dim` holding the locations of quadrature points on the
             reference element.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

Here `dim = getdimension(topo)`. An `AssertionError` is raised when the array sizes do not
match `dim`, `nnodes` and `nqpts`.
"""
function create_h1_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp::AbstractArray{CeedScalar},
    grad::AbstractArray{CeedScalar},
    qref::AbstractArray{CeedScalar},
    qweight::AbstractArray{CeedScalar},
)
    dim = getdimension(topo)
    @assert size(interp) == (nqpts, nnodes)
    @assert size(grad) == (dim, nqpts, nnodes)
    # Each quadrature point has `dim` reference coordinates; this is also the number of
    # entries `getqref` reads back for a non-tensor basis.
    @assert length(qref) == nqpts*dim
    @assert length(qweight) == nqpts

    # Convert from Julia matrices and tensors (column-major) to row-major format
    interp_rowmajor = collect(interp')
    grad_rowmajor = permutedims(grad, [3, 2, 1])

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateH1(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        grad_rowmajor,
        qref,
        qweight,
        ref,
    )
    return Basis(ref)
end
170
@doc raw"""
    create_hdiv_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, div, qref, qweight)

Create a non tensor-product basis for H(div) discretizations.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Array of size `(dim, nqpts, nnodes)` expressing the values of basis functions
             at quadrature points.
- `div`:     Matrix of size `(nqpts, nnodes)` expressing divergence of basis functions at
             quadrature points.
- `qref`:    Array of length `nqpts*dim` holding the locations of quadrature points on the
             reference element.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

Here `dim = getdimension(topo)`. An `AssertionError` is raised when the array sizes do not
match `dim`, `nnodes` and `nqpts`.
"""
function create_hdiv_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp::AbstractArray{CeedScalar},
    div::AbstractArray{CeedScalar},
    qref::AbstractArray{CeedScalar},
    qweight::AbstractArray{CeedScalar},
)
    dim = getdimension(topo)
    @assert size(interp) == (dim, nqpts, nnodes)
    @assert size(div) == (nqpts, nnodes)
    # Each quadrature point has `dim` reference coordinates; this is also the number of
    # entries `getqref` reads back for a non-tensor basis.
    @assert length(qref) == nqpts*dim
    @assert length(qweight) == nqpts

    # Convert from Julia matrices and tensors (column-major) to row-major format
    interp_rowmajor = permutedims(interp, [3, 2, 1])
    div_rowmajor = collect(div')

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateHdiv(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        div_rowmajor,
        qref,
        qweight,
        ref,
    )
    return Basis(ref)
end
227
@doc raw"""
    create_hcurl_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, curl, qref, qweight)

Create a non tensor-product basis for H(curl) discretizations.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Array of size `(dim, nqpts, nnodes)` expressing the values of basis functions
             at quadrature points.
- `curl`:    Array of size `(curlcomp, nqpts, nnodes)` expressing curl of basis functions
             at quadrature points, where `curlcomp` is 1 if `dim < 3` and `dim` otherwise.
- `qref`:    Array of length `nqpts*dim` holding the locations of quadrature points on the
             reference element.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

Here `dim = getdimension(topo)`. An `AssertionError` is raised when the array sizes do not
match `dim`, `nnodes` and `nqpts`.
"""
function create_hcurl_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp::AbstractArray{CeedScalar},
    curl::AbstractArray{CeedScalar},
    qref::AbstractArray{CeedScalar},
    qweight::AbstractArray{CeedScalar},
)
    dim = getdimension(topo)
    curlcomp = dim < 3 ? 1 : dim
    @assert size(interp) == (dim, nqpts, nnodes)
    @assert size(curl) == (curlcomp, nqpts, nnodes)
    # Each quadrature point has `dim` reference coordinates; this is also the number of
    # entries `getqref` reads back for a non-tensor basis.
    @assert length(qref) == nqpts*dim
    @assert length(qweight) == nqpts

    # Convert from Julia matrices and tensors (column-major) to row-major format
    interp_rowmajor = permutedims(interp, [3, 2, 1])
    curl_rowmajor = permutedims(curl, [3, 2, 1])

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateHcurl(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        curl_rowmajor,
        qref,
        qweight,
        ref,
    )
    return Basis(ref)
end
285
"""
    apply!(b::Basis, nelem, tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)

Evaluate the basis `b` on the input `u`, going from nodes to quadrature points or the other
way around, and write the result into the [`CeedVector`](@ref) `v`.

`nelem` is the number of elements the evaluation is applied to; the backend will specify
the ordering in CeedElemRestrictionCreateBlocked()

`tmode` selects the direction: `CEED_NOTRANSPOSE` evaluates from nodes to quadrature
points, `CEED_TRANSPOSE` applies the transpose and maps from quadrature points to nodes.

The [`EvalMode`](@ref) `emode` selects what is evaluated:
- `CEED_EVAL_NONE` to use values directly,
- `CEED_EVAL_INTERP` to use interpolated values,
- `CEED_EVAL_GRAD` to use gradients,
- `CEED_EVAL_WEIGHT` to use quadrature weights.
"""
function apply!(
    b::Basis,
    nelem,
    tmode::TransposeMode,
    emode::EvalMode,
    u::AbstractCeedVector,
    v::AbstractCeedVector,
)
    return C.CeedBasisApply(b[], nelem, tmode, emode, u[], v[])
end
314
"""
    apply(b::Basis, u::AbstractVector; nelem=1, tmode=NOTRANSPOSE, emode=EVAL_INTERP)

Performs the same function as the above-defined [`apply!`](@ref apply!(b::Basis, nelem,
tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)), but
automatically convert from Julia arrays to [`CeedVector`](@ref) for convenience.

The result is returned in a newly allocated array. Its length is
`nelem*getnumcomponents(b)` times the per-component output size: `getnumnodes(b)` when
`tmode == TRANSPOSE`, otherwise `getnumqpts(b)` (times `getdimension(b)` when
`emode == EVAL_GRAD`).
"""
function apply(b::Basis, u::AbstractVector; nelem=1, tmode=NOTRANSPOSE, emode=EVAL_INTERP)
    # Create the temporary vectors on the same Ceed the basis belongs to.
    ceed_ref = Ref{C.Ceed}()
    ccall((:CeedBasisGetCeed, C.libceed), Cint, (C.CeedBasis, Ptr{C.Ceed}), b[], ceed_ref)
    c = Ceed(ceed_ref)

    u_vec = CeedVector(c, u)

    # Output entries per element and per component. In transpose mode the output is nodal,
    # so the gradient's `dim` factor applies to the input rather than to the output.
    if tmode == TRANSPOSE
        len_per_comp = getnumnodes(b)
    elseif emode == EVAL_GRAD
        len_per_comp = getnumqpts(b)*getdimension(b)
    else
        len_per_comp = getnumqpts(b)
    end
    len_v = nelem*getnumcomponents(b)*len_per_comp

    v_vec = CeedVector(c, len_v)

    apply!(b, nelem, tmode, emode, u_vec, v_vec)
    return Vector(v_vec)
end
341
"""
    getdimension(b::Basis)

Query libCEED for the spatial dimension of the given [`Basis`](@ref) and return it.
"""
function getdimension(b::Basis)
    result = Ref{CeedInt}()
    C.CeedBasisGetDimension(b[], result)
    return result[]
end
352
"""
    getdimension(t::Topology)

Return the spatial dimension of the given [`Topology`](@ref), obtained by shifting the
integer value of `t` right by 16 bits.
"""
getdimension(t::Topology) = Int(t) >> 16
361
"""
    gettopology(b::Basis)

Query libCEED for the [`Topology`](@ref) of the given [`Basis`](@ref) and return it.
"""
function gettopology(b::Basis)
    result = Ref{Topology}()
    C.CeedBasisGetTopology(b[], result)
    return result[]
end
372
"""
    getnumcomponents(b::Basis)

Query libCEED for the number of components of the given [`Basis`](@ref) and return it.
"""
function getnumcomponents(b::Basis)
    result = Ref{CeedInt}()
    C.CeedBasisGetNumComponents(b[], result)
    return result[]
end
383
"""
    getnumnodes(b::Basis)

Query libCEED for the number of nodes of the given [`Basis`](@ref) and return it.
"""
function getnumnodes(b::Basis)
    result = Ref{CeedInt}()
    C.CeedBasisGetNumNodes(b[], result)
    return result[]
end
394
"""
    getnumnodes1d(b::Basis)

Query libCEED for the number of 1D nodes of the given (tensor-product) [`Basis`](@ref) and
return it.
"""
function getnumnodes1d(b::Basis)
    result = Ref{CeedInt}()
    C.CeedBasisGetNumNodes1D(b[], result)
    return result[]
end
405
"""
    getnumqpts(b::Basis)

Query libCEED for the number of quadrature points of the given [`Basis`](@ref) and return
it.
"""
function getnumqpts(b::Basis)
    result = Ref{CeedInt}()
    C.CeedBasisGetNumQuadraturePoints(b[], result)
    return result[]
end
416
"""
    getnumqpts1d(b::Basis)

Query libCEED for the number of 1D quadrature points of the given (tensor-product)
[`Basis`](@ref) and return it.
"""
function getnumqpts1d(b::Basis)
    result = Ref{CeedInt}()
    C.CeedBasisGetNumQuadraturePoints1D(b[], result)
    return result[]
end
427
"""
    getqref(b::Basis)

Get the reference coordinates of the quadrature points of the given [`Basis`](@ref) as a
freshly copied array.

For a tensor-product basis the array has `getnumqpts1d(b)` entries; otherwise it has
`getnumqpts(b)*getdimension(b)` entries.
"""
function getqref(b::Basis)
    tensor_flag = Ref{Bool}()
    C.CeedBasisIsTensor(b[], tensor_flag)
    qref_ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetQRef(b[], qref_ptr)
    if tensor_flag[]
        nentries = getnumqpts1d(b)
    else
        nentries = getnumqpts(b)*getdimension(b)
    end
    # Copy so that the returned array does not alias the memory behind the pointer.
    borrowed = unsafe_wrap(Array, qref_ptr[], nentries)
    return copy(borrowed)
end
447
"""
    getqweights(b::Basis)

Get the quadrature weights of the quadrature points of the given [`Basis`](@ref) as a
freshly copied array.

For a tensor-product basis the array has `getnumqpts1d(b)` entries; otherwise it has
`getnumqpts(b)` entries.
"""
function getqweights(b::Basis)
    istensor = Ref{Bool}()
    C.CeedBasisIsTensor(b[], istensor)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetQWeights(b[], ref)
    nweights = istensor[] ? getnumqpts1d(b) : getnumqpts(b)
    # Copy so that the returned array does not alias the memory behind the pointer.
    return copy(unsafe_wrap(Array, ref[], nweights))
end
461
@doc raw"""
    getinterp(b::Basis)

Get the interpolation matrix of the given [`Basis`](@ref). The result is a matrix of size
`(getnumqpts(b), getnumnodes(b))` for a given $H^1$ basis, or an array of size
`(getdimension(b), getnumqpts(b), getnumnodes(b))` for a given vector $H(div)$ or
$H(curl)$ basis.
"""
function getinterp(b::Basis)
    data_ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetInterp(b[], data_ptr)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    ncomp_q = Ref{CeedInt}()
    C.CeedBasisGetNumQuadratureComponents(b[], C.CEED_EVAL_INTERP, ncomp_q)
    # The wrapped data has the node index running fastest; reversing the dimension order
    # produces the Julia-side layout documented above.
    if ncomp_q[] != 1
        wrapped = unsafe_wrap(Array, data_ptr[], (nnodes, nqpts, ncomp_q[]))
        return permutedims(wrapped, [3, 2, 1])
    end
    wrapped = unsafe_wrap(Array, data_ptr[], (nnodes, nqpts))
    return collect(wrapped')
end
482
"""
    getinterp1d(b::Basis)

Get the 1D interpolation matrix of the given [`Basis`](@ref), as a matrix of size
`(getnumqpts1d(b), getnumnodes1d(b))`. `b` must be a tensor-product basis, otherwise this
function will fail.
"""
function getinterp1d(b::Basis)
    data_ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetInterp1D(b[], data_ptr)
    nqpts1d = getnumqpts1d(b)
    nnodes1d = getnumnodes1d(b)
    # Wrap with the node index running fastest, then transpose into a new matrix.
    wrapped = unsafe_wrap(Array, data_ptr[], (nnodes1d, nqpts1d))
    return collect(wrapped')
end
497
"""
    getgrad(b::Basis)

Get the gradient matrix of the given [`Basis`](@ref), as an array of size
`(getdimension(b), getnumqpts(b), getnumnodes(b))`.
"""
function getgrad(b::Basis)
    data_ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetGrad(b[], data_ptr)
    ndim = getdimension(b)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    # Wrap with the node index running fastest, then reverse the dimension order.
    wrapped = unsafe_wrap(Array, data_ptr[], (nnodes, nqpts, ndim))
    return permutedims(wrapped, [3, 2, 1])
end
512
"""
    getgrad1d(b::Basis)

Get the 1D derivative matrix of the given [`Basis`](@ref). Returns a matrix of size
`(getnumqpts1d(b), getnumnodes1d(b))`.
"""
function getgrad1d(b::Basis)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetGrad1D(b[], ref)
    q = getnumqpts1d(b)
    p = getnumnodes1d(b)
    # Wrap with the node index running fastest, then transpose into a new matrix.
    return collect(unsafe_wrap(Array, ref[], (p, q))')
end
526
"""
    getdiv(b::Basis)

Get the divergence matrix of the given [`Basis`](@ref), as a matrix of size
`(getnumqpts(b), getnumnodes(b))`.
"""
function getdiv(b::Basis)
    data_ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetDiv(b[], data_ptr)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    # Wrap with the node index running fastest, then transpose into a new matrix.
    wrapped = unsafe_wrap(Array, data_ptr[], (nnodes, nqpts))
    return collect(wrapped')
end
540
"""
    getcurl(b::Basis)

Get the curl matrix of the given [`Basis`](@ref), as an array of size
`(curlcomp, getnumqpts(b), getnumnodes(b))`, where `curlcomp` is the number of quadrature
components libCEED reports for `CEED_EVAL_CURL` (`curlcomp = 1 if getdimension(b) < 3 else
getdimension(b)`).
"""
function getcurl(b::Basis)
    data_ptr = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetCurl(b[], data_ptr)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    curlcomp = Ref{CeedInt}()
    C.CeedBasisGetNumQuadratureComponents(b[], C.CEED_EVAL_CURL, curlcomp)
    # Wrap with the node index running fastest, then reverse the dimension order.
    wrapped = unsafe_wrap(Array, data_ptr[], (nnodes, nqpts, curlcomp[]))
    return permutedims(wrapped, [3, 2, 1])
end
557