xref: /libCEED/julia/LibCEED.jl/src/Basis.jl (revision ca94c3ddc8f82b7d93a79f9e4812e99b8be840ff)
abstract type AbstractBasis end

"""
    BasisNone()

Singleton [`AbstractBasis`](@ref) standing in for libCEED's `CEED_BASIS_NONE`; indexing it
with `[]` yields that C handle.
"""
struct BasisNone <: AbstractBasis end

function Base.getindex(::BasisNone)
    return C.CEED_BASIS_NONE[]
end
10
"""
    Basis

Owning wrapper around a libCEED `CeedBasis` handle, representing a finite element basis.
The handle is released (via `destroy`) when the Julia object is garbage collected.
Construct one with any of:

- [`create_tensor_h1_lagrange_basis`](@ref)
- [`create_tensor_h1_basis`](@ref)
- [`create_h1_basis`](@ref)
- [`create_hdiv_basis`](@ref)
- [`create_hcurl_basis`](@ref)
"""
mutable struct Basis <: AbstractBasis
    ref::RefValue{C.CeedBasis}
    function Basis(ref)
        basis = new(ref)
        # Tie the lifetime of the C-side basis to this Julia object.
        finalizer(destroy, basis)
        return basis
    end
end

# Release the C-side basis; registered as the finalizer of every `Basis`.
destroy(b::Basis) = C.CeedBasisDestroy(b.ref) # COV_EXCL_LINE

# `b[]` exposes the raw `CeedBasis` handle for passing to the C API.
Base.getindex(b::Basis) = b.ref[]

function Base.show(io::IO, ::MIME"text/plain", b::Basis)
    return ceed_show(io, b, C.CeedBasisView)
end
37
@doc raw"""
    create_tensor_h1_lagrange_basis(c::Ceed, dim, ncomp, p, q, quad_mode::QuadMode)

Create a tensor-product Lagrange basis.

# Arguments:
- `c`:         A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `dim`:       Topological dimension of element.
- `ncomp`:     Number of field components (1 for scalar fields).
- `p`:         Number of Gauss-Lobatto nodes in one dimension. The polynomial degree of
               the resulting $Q_k$ element is $k=p-1$.
- `q`:         Number of quadrature points in one dimension.
- `quad_mode`: Distribution of the $q$ quadrature points (affects order of accuracy for
               the quadrature).
"""
function create_tensor_h1_lagrange_basis(c::Ceed, dim, ncomp, p, q, quad_mode::QuadMode)
    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateTensorH1Lagrange(c[], dim, ncomp, p, q, quad_mode, ref)
    return Basis(ref)
end
58
@doc raw"""
    create_tensor_h1_basis(c::Ceed, dim, ncomp, p, q, interp1d, grad1d, qref1d, qweight1d)

Create a tensor-product basis for $H^1$ discretizations.

# Arguments:
- `c`:         A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `dim`:       Topological dimension.
- `ncomp`:     Number of field components (1 for scalar fields).
- `p`:         Number of nodes in one dimension.
- `q`:         Number of quadrature points in one dimension.
- `interp1d`:  Matrix of size `(q, p)` expressing the values of nodal basis functions at
               quadrature points.
- `grad1d`:    Matrix of size `(q, p)` expressing derivatives of nodal basis functions at
               quadrature points.
- `qref1d`:    Array of length `q` holding the locations of quadrature points on the 1D
               reference element $[-1, 1]$.
- `qweight1d`: Array of length `q` holding the quadrature weights on the reference element.

Arguments whose sizes do not match `p` and `q` are rejected with an `AssertionError`.
"""
function create_tensor_h1_basis(
    c::Ceed,
    dim,
    ncomp,
    p,
    q,
    interp1d::AbstractArray{CeedScalar},
    grad1d::AbstractArray{CeedScalar},
    qref1d::AbstractArray{CeedScalar},
    qweight1d::AbstractArray{CeedScalar},
)
    @assert size(interp1d) == (q, p) "size(interp1d) = $(size(interp1d)), expected (q, p) = $((q, p))"
    @assert size(grad1d) == (q, p) "size(grad1d) = $(size(grad1d)), expected (q, p) = $((q, p))"
    @assert length(qref1d) == q "length(qref1d) = $(length(qref1d)), expected q = $q"
    @assert length(qweight1d) == q "length(qweight1d) = $(length(qweight1d)), expected q = $q"

    # libCEED expects row-major (q, p) data; the transpose of a column-major (q, p) Julia
    # matrix has exactly that memory layout.
    interp1d_rowmajor = collect(interp1d')
    grad1d_rowmajor = collect(grad1d')

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateTensorH1(
        c[],
        dim,
        ncomp,
        p,
        q,
        interp1d_rowmajor,
        grad1d_rowmajor,
        qref1d,
        qweight1d,
        ref,
    )
    return Basis(ref)
end
113
@doc raw"""
    create_h1_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, grad, qref, qweight)

Create a non tensor-product basis for $H^1$ discretizations.

Below, `dim = getdimension(topo)`.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Matrix of size `(nqpts, nnodes)` expressing the values of nodal basis functions
             at quadrature points.
- `grad`:    Array of size `(dim, nqpts, nnodes)` expressing derivatives of nodal basis
             functions at quadrature points.
- `qref`:    Matrix of size `(dim, nqpts)` holding the locations of quadrature points on the
             reference element of `topo`.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

Arguments whose sizes do not match are rejected with an `AssertionError`.
"""
function create_h1_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp::AbstractArray{CeedScalar},
    grad::AbstractArray{CeedScalar},
    qref::AbstractArray{CeedScalar},
    qweight::AbstractArray{CeedScalar},
)
    dim = getdimension(topo)
    @assert size(interp) == (nqpts, nnodes) "size(interp) = $(size(interp)), expected $((nqpts, nnodes))"
    @assert size(grad) == (dim, nqpts, nnodes) "size(grad) = $(size(grad)), expected $((dim, nqpts, nnodes))"
    @assert size(qref) == (dim, nqpts) "size(qref) = $(size(qref)), expected $((dim, nqpts))"
    @assert length(qweight) == nqpts "length(qweight) = $(length(qweight)), expected $nqpts"

    # Reversing the axes of a column-major Julia array gives the row-major layout of the
    # original shape, which is what libCEED expects.
    interp_rowmajor = collect(interp')
    grad_rowmajor = permutedims(grad, [3, 2, 1])
    qref_rowmajor = collect(qref')

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateH1(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        grad_rowmajor,
        qref_rowmajor,
        qweight,
        ref,
    )
    return Basis(ref)
end
171
@doc raw"""
    create_hdiv_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, div, qref, qweight)

Create a non tensor-product basis for $H(div)$ discretizations.

Below, `dim = getdimension(topo)`.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Array of size `(dim, nqpts, nnodes)` expressing the values of the
             (vector-valued) basis functions at quadrature points.
- `div`:     Matrix of size `(nqpts, nnodes)` expressing the divergence of basis functions
             at quadrature points.
- `qref`:    Matrix of size `(dim, nqpts)` holding the locations of quadrature points on the
             reference element of `topo`.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

Arguments whose sizes do not match are rejected with an `AssertionError`.
"""
function create_hdiv_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp::AbstractArray{CeedScalar},
    div::AbstractArray{CeedScalar},
    qref::AbstractArray{CeedScalar},
    qweight::AbstractArray{CeedScalar},
)
    dim = getdimension(topo)
    @assert size(interp) == (dim, nqpts, nnodes) "size(interp) = $(size(interp)), expected $((dim, nqpts, nnodes))"
    @assert size(div) == (nqpts, nnodes) "size(div) = $(size(div)), expected $((nqpts, nnodes))"
    @assert size(qref) == (dim, nqpts) "size(qref) = $(size(qref)), expected $((dim, nqpts))"
    @assert length(qweight) == nqpts "length(qweight) = $(length(qweight)), expected $nqpts"

    # Reversing the axes of a column-major Julia array gives the row-major layout of the
    # original shape, which is what libCEED expects.
    interp_rowmajor = permutedims(interp, [3, 2, 1])
    div_rowmajor = collect(div')
    qref_rowmajor = collect(qref')

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateHdiv(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        div_rowmajor,
        qref_rowmajor,
        qweight,
        ref,
    )
    return Basis(ref)
end
229
@doc raw"""
    create_hcurl_basis(c::Ceed, topo::Topology, ncomp, nnodes, nqpts, interp, curl, qref, qweight)

Create a non tensor-product basis for $H(curl)$ discretizations.

Below, `dim = getdimension(topo)` and `curlcomp = 1` if `dim < 3`, otherwise
`curlcomp = dim`.

# Arguments:
- `c`:       A [`Ceed`](@ref) object where the [`Basis`](@ref) will be created.
- `topo`:    [`Topology`](@ref) of element, e.g. hypercube, simplex, etc.
- `ncomp`:   Number of field components (1 for scalar fields).
- `nnodes`:  Total number of nodes.
- `nqpts`:   Total number of quadrature points.
- `interp`:  Array of size `(dim, nqpts, nnodes)` expressing the values of the
             (vector-valued) basis functions at quadrature points.
- `curl`:    Array of size `(curlcomp, nqpts, nnodes)` expressing the curl of basis
             functions at quadrature points.
- `qref`:    Matrix of size `(dim, nqpts)` holding the locations of quadrature points on the
             reference element of `topo`.
- `qweight`: Array of length `nqpts` holding the quadrature weights on the reference
             element.

Arguments whose sizes do not match are rejected with an `AssertionError`.
"""
function create_hcurl_basis(
    c::Ceed,
    topo::Topology,
    ncomp,
    nnodes,
    nqpts,
    interp::AbstractArray{CeedScalar},
    curl::AbstractArray{CeedScalar},
    qref::AbstractArray{CeedScalar},
    qweight::AbstractArray{CeedScalar},
)
    dim = getdimension(topo)
    # The curl is a scalar in 1D/2D and a dim-vector in 3D.
    curlcomp = dim < 3 ? 1 : dim
    @assert size(interp) == (dim, nqpts, nnodes) "size(interp) = $(size(interp)), expected $((dim, nqpts, nnodes))"
    @assert size(curl) == (curlcomp, nqpts, nnodes) "size(curl) = $(size(curl)), expected $((curlcomp, nqpts, nnodes))"
    @assert size(qref) == (dim, nqpts) "size(qref) = $(size(qref)), expected $((dim, nqpts))"
    @assert length(qweight) == nqpts "length(qweight) = $(length(qweight)), expected $nqpts"

    # Reversing the axes of a column-major Julia array gives the row-major layout of the
    # original shape, which is what libCEED expects.
    interp_rowmajor = permutedims(interp, [3, 2, 1])
    curl_rowmajor = permutedims(curl, [3, 2, 1])
    qref_rowmajor = collect(qref')

    ref = Ref{C.CeedBasis}()
    C.CeedBasisCreateHcurl(
        c[],
        topo,
        ncomp,
        nnodes,
        nqpts,
        interp_rowmajor,
        curl_rowmajor,
        qref_rowmajor,
        qweight,
        ref,
    )
    return Basis(ref)
end
288
"""
    apply!(b::Basis, nelem, tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)

Apply basis evaluation from nodes to quadrature points or vice versa, storing the result in
the [`CeedVector`](@ref) `v`.

`nelem` specifies the number of elements to apply the basis evaluation to; the backend
specifies the element ordering in `CeedElemRestrictionCreateBlocked()`.

Set `tmode` to `CEED_NOTRANSPOSE` to evaluate from nodes to quadrature points, or to
`CEED_TRANSPOSE` to apply the transpose, mapping from quadrature points to nodes.

Set the [`EvalMode`](@ref) `emode` to:
- `CEED_EVAL_NONE` to use values directly,
- `CEED_EVAL_INTERP` to use interpolated values,
- `CEED_EVAL_GRAD` to use gradients,
- `CEED_EVAL_DIV` to use divergences (H(div) bases),
- `CEED_EVAL_CURL` to use curls (H(curl) bases),
- `CEED_EVAL_WEIGHT` to use quadrature weights.
"""
function apply!(
    b::Basis,
    nelem,
    tmode::TransposeMode,
    emode::EvalMode,
    u::AbstractCeedVector,
    v::AbstractCeedVector,
)
    return C.CeedBasisApply(b[], nelem, tmode, emode, u[], v[])
end
317
"""
    apply(b::Basis, u::AbstractVector; nelem=1, tmode=NOTRANSPOSE, emode=EVAL_INTERP)

Performs the same function as the above-defined [`apply!`](@ref apply!(b::Basis, nelem,
tmode::TransposeMode, emode::EvalMode, u::AbstractCeedVector, v::AbstractCeedVector)), but
automatically converts from Julia arrays to [`CeedVector`](@ref) for convenience.

The result is returned in a newly allocated array sized for all `nelem` elements
(see libCEED's `CeedBasisApply`), with `ncomp = getnumcomponents(b)`:

- `tmode == TRANSPOSE`: `nelem * ncomp * getnumnodes(b)`;
- `emode == EVAL_WEIGHT`: `nelem * getnumqpts(b)` (one weight per quadrature point);
- otherwise: `nelem * ncomp * qcomp * getnumqpts(b)`, where `qcomp` is the number of
  quadrature components libCEED reports for `emode`.
"""
function apply(b::Basis, u::AbstractVector; nelem=1, tmode=NOTRANSPOSE, emode=EVAL_INTERP)
    ceed_ref = Ref{C.Ceed}()
    ccall((:CeedBasisGetCeed, C.libceed), Cint, (C.CeedBasis, Ptr{C.Ceed}), b[], ceed_ref)
    c = Ceed(ceed_ref)

    u_vec = CeedVector(c, u)

    qcomp = Ref{CeedInt}()
    C.CeedBasisGetNumQuadratureComponents(b[], emode, qcomp)
    # The output must hold every element and every field component, not just one element
    # of a scalar field.
    ncomp = getnumcomponents(b)
    len_v = if tmode == TRANSPOSE
        nelem * ncomp * getnumnodes(b)
    elseif emode == C.CEED_EVAL_WEIGHT
        # Quadrature weights do not depend on the number of field components.
        nelem * getnumqpts(b)
    else
        nelem * ncomp * qcomp[] * getnumqpts(b)
    end

    v_vec = CeedVector(c, len_v)

    apply!(b, nelem, tmode, emode, u_vec, v_vec)
    return Vector(v_vec)
end
343
"""
    getdimension(b::Basis)

Query the spatial dimension of the [`Basis`](@ref) `b`.
"""
function getdimension(b::Basis)
    out = Ref{CeedInt}()
    C.CeedBasisGetDimension(b[], out)
    return out[]
end
354
"""
    getdimension(t::Topology)

Query the spatial dimension of the [`Topology`](@ref) `t`, obtained by discarding the low
16 bits of its integer encoding.
"""
getdimension(t::Topology) = Int(t) >> 16
363
"""
    gettopology(b::Basis)

Query the element [`Topology`](@ref) of the [`Basis`](@ref) `b`.
"""
function gettopology(b::Basis)
    result = Ref{Topology}()
    C.CeedBasisGetTopology(b[], result)
    return result[]
end
374
"""
    getnumcomponents(b::Basis)

Query how many field components the [`Basis`](@ref) `b` has.
"""
function getnumcomponents(b::Basis)
    count = Ref{CeedInt}()
    C.CeedBasisGetNumComponents(b[], count)
    return count[]
end
385
"""
    getnumnodes(b::Basis)

Query the total number of nodes of the [`Basis`](@ref) `b`.
"""
function getnumnodes(b::Basis)
    count = Ref{CeedInt}()
    C.CeedBasisGetNumNodes(b[], count)
    return count[]
end
396
"""
    getnumnodes1d(b::Basis)

Return the number of 1D nodes of the given (tensor-product) [`Basis`](@ref).
"""
function getnumnodes1d(b::Basis)
    nnodes1d = Ref{CeedInt}()
    C.CeedBasisGetNumNodes1D(b[], nnodes1d)
    return nnodes1d[]
end
407
"""
    getnumqpts(b::Basis)

Query the total number of quadrature points of the [`Basis`](@ref) `b`.
"""
function getnumqpts(b::Basis)
    count = Ref{CeedInt}()
    C.CeedBasisGetNumQuadraturePoints(b[], count)
    return count[]
end
418
"""
    getnumqpts1d(b::Basis)

Query the number of 1D quadrature points of the tensor-product [`Basis`](@ref) `b`.
"""
function getnumqpts1d(b::Basis)
    count = Ref{CeedInt}()
    C.CeedBasisGetNumQuadraturePoints1D(b[], count)
    return count[]
end
429
"""
    getqref(b::Basis)

Get the reference coordinates of the quadrature points of the given [`Basis`](@ref),
copied out of libCEED.

For a tensor-product basis the result is a vector of the `getnumqpts1d(b)` 1D quadrature
points. Otherwise it is a matrix of size `(getdimension(b), getnumqpts(b))`, the same
layout as the `qref` argument of the non-tensor basis constructors.
"""
function getqref(b::Basis)
    istensor = Ref{Bool}()
    C.CeedBasisIsTensor(b[], istensor)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetQRef(b[], ref)
    if istensor[]
        return copy(unsafe_wrap(Array, ref[], getnumqpts1d(b)))
    else
        # The data is row-major (dim, nqpts): wrap it as (nqpts, dim) and transpose.
        return copy(unsafe_wrap(Array, ref[], (getnumqpts(b), getdimension(b)))')
    end
end
446
"""
    getqweights(b::Basis)

Get the quadrature weights of the given [`Basis`](@ref), copied out of libCEED.

For a tensor-product basis the result is a vector of the `getnumqpts1d(b)` 1D weights;
otherwise it is a vector of length `getnumqpts(b)`.
"""
function getqweights(b::Basis)
    istensor = Ref{Bool}()
    C.CeedBasisIsTensor(b[], istensor)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetQWeights(b[], ref)
    nweights = istensor[] ? getnumqpts1d(b) : getnumqpts(b)
    return copy(unsafe_wrap(Array, ref[], nweights))
end
460
@doc raw"""
    getinterp(b::Basis)

Return a Julia copy of the interpolation matrix of the [`Basis`](@ref) `b`. A scalar
($H^1$) basis yields a matrix of size `(getnumqpts(b), getnumnodes(b))`; a vector-valued
$H(div)$ or $H(curl)$ basis yields an array of size
`(getdimension(b), getnumqpts(b), getnumnodes(b))`.
"""
function getinterp(b::Basis)
    data = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetInterp(b[], data)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    ncomp_q = Ref{CeedInt}()
    C.CeedBasisGetNumQuadratureComponents(b[], C.CEED_EVAL_INTERP, ncomp_q)
    # libCEED data is row-major: wrap it with reversed dimensions, then reverse the axes.
    if ncomp_q[] == 1
        return collect(unsafe_wrap(Array, data[], (nnodes, nqpts))')
    end
    return permutedims(unsafe_wrap(Array, data[], (nnodes, nqpts, ncomp_q[])), (3, 2, 1))
end
481
"""
    getinterp1d(b::Basis)

Return a Julia copy of the 1D interpolation matrix of the [`Basis`](@ref) `b`, of size
`(getnumqpts1d(b), getnumnodes1d(b))`. `b` must be a tensor-product basis; otherwise this
function fails.
"""
function getinterp1d(b::Basis)
    data = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetInterp1D(b[], data)
    nqpts1d, nnodes1d = getnumqpts1d(b), getnumnodes1d(b)
    # Row-major (q, p) data reads as a column-major (p, q) matrix; transpose it back.
    rowmajor = unsafe_wrap(Array, data[], (nnodes1d, nqpts1d))
    return collect(rowmajor')
end
496
"""
    getgrad(b::Basis)

Return a Julia copy of the gradient matrix of the [`Basis`](@ref) `b`, as an array of size
`(getdimension(b), getnumqpts(b), getnumnodes(b))`.
"""
function getgrad(b::Basis)
    data = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetGrad(b[], data)
    dim = getdimension(b)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    # Row-major (dim, q, p) data reads as column-major (p, q, dim); reverse the axes.
    rowmajor = unsafe_wrap(Array, data[], (nnodes, nqpts, dim))
    return permutedims(rowmajor, (3, 2, 1))
end
511
"""
    getgrad1d(b::Basis)

Get the 1D derivative matrix of the given tensor-product [`Basis`](@ref). Returns a matrix
of size `(getnumqpts1d(b), getnumnodes1d(b))`.
"""
function getgrad1d(b::Basis)
    ref = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetGrad1D(b[], ref)
    q = getnumqpts1d(b)
    p = getnumnodes1d(b)
    # Row-major (q, p) data reads as a column-major (p, q) matrix; transpose it back.
    return collect(unsafe_wrap(Array, ref[], (p, q))')
end
525
"""
    getdiv(b::Basis)

Return a Julia copy of the divergence matrix of the [`Basis`](@ref) `b`, as a matrix of
size `(getnumqpts(b), getnumnodes(b))`.
"""
function getdiv(b::Basis)
    data = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetDiv(b[], data)
    nqpts, nnodes = getnumqpts(b), getnumnodes(b)
    # Row-major (q, p) data reads as a column-major (p, q) matrix; transpose it back.
    rowmajor = unsafe_wrap(Array, data[], (nnodes, nqpts))
    return collect(rowmajor')
end
539
"""
    getcurl(b::Basis)

Return a Julia copy of the curl matrix of the [`Basis`](@ref) `b`, as an array of size
`(curlcomp, getnumqpts(b), getnumnodes(b))`, where `curlcomp = 1` if
`getdimension(b) < 3` and `curlcomp = getdimension(b)` otherwise.
"""
function getcurl(b::Basis)
    data = Ref{Ptr{CeedScalar}}()
    C.CeedBasisGetCurl(b[], data)
    nqpts = getnumqpts(b)
    nnodes = getnumnodes(b)
    ncomp_q = Ref{CeedInt}()
    C.CeedBasisGetNumQuadratureComponents(b[], C.CEED_EVAL_CURL, ncomp_q)
    # Row-major (curlcomp, q, p) data reads as column-major (p, q, curlcomp).
    rowmajor = unsafe_wrap(Array, data[], (nnodes, nqpts, ncomp_q[]))
    return permutedims(rowmajor, (3, 2, 1))
end
556