xref: /libCEED/julia/LibCEED.jl/src/UserQFunction.jl (revision 44554ea01e90fce366fc2a203c44be15754a38d6)
# Bundle of the artifacts produced for a single user-defined Q-function.
#
# Fields, as populated by `generate_user_qfunction`:
# - `f`:    the generated Julia function.
# - `fptr`: C-callable pointer to `f`.
# - `kf`:   kernel function generated for CUDA backends, `nothing` otherwise.
# - `cuf`:  value returned by `mk_cufunction` for CUDA backends, `nothing` otherwise.
struct UserQFunction{F,K}
    f::F
    fptr::Ptr{Cvoid}
    kf::K
    cuf::Union{Nothing,Ptr{Cvoid}}
end
7*44554ea0SWill Pazner
# Reinterpret `ptr` as a pointer to `T` and load the value stored there.
# The caller is responsible for `ptr` referring to valid memory holding a `T`.
@inline function extract_context(ptr, ::Type{T}) where {T}
    typed_ptr = Ptr{T}(ptr)
    return unsafe_load(typed_ptr)
end
11*44554ea0SWill Pazner
# Load the `idx`-th pointer stored at `ptr`, reinterpret it as a pointer to
# `CeedScalar`, and wrap it (without copying) in an `UnsafeArray` of size `dims`.
@inline function extract_array(ptr, idx, dims)
    data_ptr = Ptr{CeedScalar}(unsafe_load(ptr, idx))
    return UnsafeArray(data_ptr, dims)
end
15*44554ea0SWill Pazner
# Generate the implementation(s) of a user Q-function and wrap them in a `UserQFunction`.
#
# A function with the C-compatible signature
# `(ctx_ptr::Ptr{Cvoid}, Q::CeedInt, in::Ptr{Ptr{CeedScalar}}, out::Ptr{Ptr{CeedScalar}})`
# is evaluated into `def_module`. It binds the constants, optionally loads the context,
# wraps each input/output pointer as an array, and runs `body` once per Q-point.
# A C-callable pointer to that function is then created.
#
# For CUDA backends a second function taking `(ctx_ptr, arrays...)` is generated and handed
# to `mk_cufunction`.
#
# Arguments:
# - `ceed`:        queried for the backend through `iscuda` and `getresource`.
# - `def_module`:  module in which the generated functions are evaluated.
# - `qf_name`:     base name used for the `gensym`-ed function names.
# - `constants`:   collection of `(name, value)` pairs assigned at the top of the function.
# - `array_names`: names of all arrays, inputs first, then outputs.
# - `ctx`:         `nothing`, or an object with `name` and `type` properties.
# - `dims_in`, `dims_out`: per-array dimensions (excluding the leading Q dimension).
# - `body`:        expression executed for every Q-point.
function generate_user_qfunction(
    ceed,
    def_module,
    qf_name,
    constants,
    array_names,
    ctx,
    dims_in,
    dims_out,
    body,
)
    # Unique names for the generated function's arguments and its loop index.
    idx = gensym(:i)
    Q = gensym(:Q)
    ctx_ptr = gensym(:ctx_ptr)
    in_ptr = gensym(:in_ptr)
    out_ptr = gensym(:out_ptr)

    # `name = value` statements for the compile-time constants.
    const_assignments = Expr[:($(c[1]) = $(c[2])) for c in constants]

    n_in = length(dims_in)
    arrays = Expr[]
    array_views = Expr[]
    for (i, arr_name) in enumerate(array_names)
        is_input = i <= n_in
        if is_input
            local_idx, dims, ptr = i, dims_in[i], in_ptr
        else
            local_idx, dims, ptr = i - n_in, dims_out[i-n_in], out_ptr
        end

        # Whole array of size (Q, dims...), bound once outside the Q-point loop.
        storage = gensym(arr_name)
        push!(
            arrays,
            :($storage = extract_array($ptr, $local_idx, (Int($Q), $(dims...)))),
        )

        # Per-Q-point slice `storage[idx, :, :, ...]`, bound to the user-visible name.
        rank = length(dims)
        slice = Expr(:ref, storage, idx, fill(:(:), rank)...)
        view_rhs = if !is_input
            # Outputs stay views so that assignments write through to the array.
            :(@view $slice)
        elseif rank == 0
            # Scalar inputs are read directly.
            slice
        else
            # Non-scalar inputs are converted to statically sized arrays.
            :(LibCEED.SArray{$(Tuple{dims...})}(@view $slice))
        end
        push!(array_views, :($arr_name = $view_rhs))
    end

    ctx_assignment =
        isnothing(ctx) ? nothing : :($(ctx.name) = extract_context($ctx_ptr, $(ctx.type)))

    cpu_name = gensym(qf_name)
    f = Core.eval(
        def_module,
        quote
            @inline function $cpu_name(
                $ctx_ptr::Ptr{Cvoid},
                $Q::CeedInt,
                $in_ptr::Ptr{Ptr{CeedScalar}},
                $out_ptr::Ptr{Ptr{CeedScalar}},
            )
                $(const_assignments...)
                $ctx_assignment
                $(arrays...)
                @inbounds @simd for $idx = 1:$Q
                    $(array_views...)
                    $body
                end
                CeedInt(0)
            end
        end,
    )

    # Build the C-callable pointer; `f` is only known at run time, so the `:cfunction`
    # expression is assembled by hand and evaluated.
    arg_types = :(Core.svec(Ptr{Cvoid}, CeedInt, Ptr{Ptr{CeedScalar}}, Ptr{Ptr{CeedScalar}}))
    cfunction_expr =
        Expr(:cfunction, Ptr{Cvoid}, QuoteNode(f), :CeedInt, arg_types, QuoteNode(:ccall))
    fptr = eval(cfunction_expr)

    # COV_EXCL_START
    if !iscuda(ceed)
        kf = nothing
        cuf = nothing
    else
        getresource(ceed) == "/gpu/cuda/gen" && error(string(
            "/gpu/cuda/gen is not compatible with user Q-functions defined with ",
            "libCEED.jl.\nPlease use a different backend, for example: /gpu/cuda/shared ",
            "or /gpu/cuda/ref",
        ))
        cuda_is_loaded || error(string(
            "User Q-functions with CUDA backends require the CUDA.jl package to be ",
            "loaded.\nThe libCEED backend is: $(getresource(ceed))\n",
            "Please ensure that the CUDA.jl package is installed and loaded.",
        ))
        has_cuda() || error("No valid CUDA installation found")

        # Kernel variant: receives the arrays directly instead of raw pointers.
        kernel_name = gensym(qf_name)
        kf = Core.eval(
            def_module,
            quote
                @inline function $kernel_name($ctx_ptr::Ptr{Cvoid}, $(array_names...))
                    $(const_assignments...)
                    $ctx_assignment
                    $body
                    nothing
                end
            end,
        )
        cuf = mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    end
    # COV_EXCL_STOP

    return UserQFunction(f, fptr, kf, cuf)
end
131*44554ea0SWill Pazner
# Translate the macro-level description of a Q-function into an expression that calls
# `generate_user_qfunction`.
#
# Arguments:
# - `ceed`, `def_module`: spliced unchanged into the generated call.
# - `qf`:   name of the Q-function; passed on quoted.
# - `args`: specification entries followed by the body as the last element. Each
#   specification is one of
#     * `name = value`                      -- a constant; `value` is escaped,
#     * `(name, :in|:out, EvalMode, dims...)` -- an array; `dims` are escaped,
#     * `name::Type`                        -- the context binding.
#
# Returns the `Expr` of the call. Throws an `ErrorException` when `args` is empty or a
# specification is malformed.
function meta_user_qfunction(ceed, def_module, qf, args)
    # Without at least a body, `args[end]` below would raise an opaque BoundsError.
    isempty(args) && error("Bad argument to @user_qfunction")

    qf_name = Meta.quot(qf)

    ctx = nothing
    constants = Expr[]
    dims_in = Expr[]
    dims_out = Expr[]
    names_in = Symbol[]
    names_out = Symbol[]

    for a ∈ args[1:end-1]
        if Meta.isexpr(a, :(=))
            a1 = Meta.quot(a.args[1])
            a2 = esc(a.args[2])
            push!(constants, :(($a1, $a2)))
        elseif Meta.isexpr(a, :tuple)
            if length(a.args) < 3
                error(string(
                    "Array specification must be of form ",
                    "`(name, :in|:out, EvalMode, dims...)`.",
                ))
            end
            arr_name = a.args[1]
            # `:in`/`:out` arrive as QuoteNodes; anything else is reported by the
            # `else` branch below instead of failing on a missing `value` field.
            tag = a.args[2]
            inout = tag isa QuoteNode ? tag.value : tag
            # Dimensions are evaluated in the caller's scope when the call is executed.
            dims_expr = :(Int[$(esc.(a.args[4:end])...)])
            if inout == :in
                push!(dims_in, dims_expr)
                push!(names_in, arr_name)
            elseif inout == :out
                push!(dims_out, dims_expr)
                push!(names_out, arr_name)
            else
                error("Array specification must be either :in or :out. Given $inout.")
            end
        elseif Meta.isexpr(a, :(::))
            ctx = (name=a.args[1], type=a.args[2])
        else
            error("Bad argument to @user_qfunction")
        end
    end

    body = Meta.quot(args[end])

    return :(generate_user_qfunction(
        $ceed,
        $def_module,
        $qf_name,
        [$(constants...)],
        $([names_in; names_out]),
        $ctx,
        [$(dims_in...)],
        [$(dims_out...)],
        $body,
    ))
end
186*44554ea0SWill Pazner
"""
    @interior_qf name=def

Creates a user-defined interior (volumetric) Q-function, and assigns it to a variable named
`name`. The definition of the Q-function is given as:
```
@interior_qf user_qf=(
    ceed::CEED,
    [const1=val1, const2=val2, ...],
    [ctx::ContextType],
    (I1, :in, EvalMode, dims...),
    (I2, :in, EvalMode, dims...),
    (O1, :out, EvalMode, dims...),
    body
)
```
The definitions of form `const=val` are used for definitions which will be compile-time
constants in the Q-function. For example, if `dim` is a variable set to the dimension of the
problem, then `dim=dim` will make `dim` available in the body of the Q-function as a
compile-time constant.

If the user wants to provide a context struct to the Q-function, that can be achieved by
optionally including `ctx::ContextType`, where `ContextType` is the type of the context
struct, and `ctx` is the name to which it will be bound in the body of the Q-function.

This is followed by the definition of the input and output arrays, which take the form
`(arr_name, (:in|:out), EvalMode, dims...)`. Each array will be bound to a variable named
`arr_name`. Input arrays should be tagged with :in, and output arrays with :out. An
`EvalMode` should be specified, followed by the dimensions of the array. If the array
consists of scalars (one number per Q-point) then `dims` should be omitted.

# Examples

- Q-function to compute the "Q-data" for the mass operator, which is given by the quadrature
  weight times the Jacobian determinant. The mesh Jacobian (the gradient of the nodal mesh
  points) and the quadrature weights are given as input arrays, and the Q-data is the output
  array. `dim` is given as a compile-time constant, and so the array `J` is statically
  sized, and therefore `det(J)` will automatically dispatch to an optimized implementation
  for the given dimension.
```
@interior_qf build_qfunc = (
    ceed, dim=dim,
    (J, :in, EVAL_GRAD, dim, dim),
    (w, :in, EVAL_WEIGHT),
    (qdata, :out, EVAL_NONE),
    qdata[] = w*det(J)
)
```
"""
macro interior_qf(args)
    if !Meta.isexpr(args, :(=))
        error("@interior_qf must be of form `qf = (body)`") # COV_EXCL_LINE
    end

    qf = args.args[1]
    user_qf = esc(qf)
    definition = args.args[2]
    # The right-hand side must be a parenthesized list holding at least the `ceed` object
    # and the body; reject anything else here with a usage message rather than failing
    # later on an unrelated field access or index.
    if !Meta.isexpr(definition, :tuple) || length(definition.args) < 2
        # COV_EXCL_START
        error("@interior_qf must be of form `qf = (ceed, [specs...], body)`")
        # COV_EXCL_STOP
    end
    qf_args = definition.args
    ceed = esc(qf_args[1])

    # Calculate field sizes. Only the entries between `ceed` and the body are
    # specifications, matching what `meta_user_qfunction` parses.
    fields_in = Expr[]
    fields_out = Expr[]
    for a ∈ qf_args[2:end-1]
        Meta.isexpr(a, :tuple) || continue
        if length(a.args) < 3
            error(string(
                "Array specification must be of form ",
                "`(name, :in|:out, EvalMode, dims...)`.",
            ))
        end
        field_name = String(a.args[1])
        tag = a.args[2]
        inout = tag isa QuoteNode ? tag.value : tag
        evalmode = a.args[3]
        # Each dimension is evaluated in the caller's scope and converted to `Int`.
        dims = Expr[esc(:(Int($d))) for d ∈ a.args[4:end]]
        # Field size is the product of the dimensions (1 for scalar fields).
        sz_expr = :(prod(($(dims...),)))
        if inout == :in
            push!(fields_in, :(add_input!($user_qf, $field_name, $sz_expr, $evalmode)))
        elseif inout == :out
            push!(
                fields_out,
                :(add_output!($user_qf, $field_name, $sz_expr, $evalmode)),
            )
        end
        # Any other tag is reported by `meta_user_qfunction` below.
    end

    gen_user_qf = meta_user_qfunction(ceed, __module__, qf, qf_args[2:end])

    quote
        $user_qf = create_interior_qfunction($ceed, $gen_user_qf)
        $(fields_in...)
        $(fields_out...)
    end
end
279