xref: /libCEED/julia/LibCEED.jl/src/UserQFunction.jl (revision bc251d84045f19174d64c067f78a1766a73c1ffc)
"""
    UserQFunction{F,K}

Compiled form of a user-defined Q-function, as returned by `generate_user_qfunction`.

# Fields
- `f::F`: generated Julia function implementing the Q-function on the CPU.
- `fptr`: C-callable function pointer to `f`.
- `kf::K`: generated kernel function for CUDA backends, or `nothing` otherwise.
- `cuf`: pointer produced by `mk_cufunction` for CUDA backends, or `nothing` otherwise.
"""
struct UserQFunction{F,K}
    f::F
    fptr::Ptr{Nothing}
    kf::K
    cuf::Union{Ptr{Nothing},Nothing}
end
7
"""
    extract_context(ptr, T)

Reinterpret the address `ptr` as a `Ptr{T}` and load the value of type `T` stored there.
"""
@inline extract_context(ptr, ::Type{T}) where {T} = unsafe_load(Ptr{T}(ptr))
11
"""
    extract_array(ptr, idx, dims)

Load the `idx`-th pointer from the pointer list `ptr`, interpret it as a `Ptr{CeedScalar}`,
and wrap it as an `UnsafeArray` of size `dims`.
"""
@inline function extract_array(ptr, idx, dims)
    data = Ptr{CeedScalar}(unsafe_load(ptr, idx))
    return UnsafeArray(data, dims)
end
15
"""
    generate_user_qfunction(ceed, def_module, qf_name, constants, array_names, ctx,
                            dims_in, dims_out, body)

Generate and compile the Julia code implementing a user Q-function, returning a
`UserQFunction`.

The CPU implementation is evaluated in `def_module`, so that `body`, the constant values and
the context type are resolved in that module's scope. It takes
`(ctx_ptr::Ptr{Cvoid}, Q::CeedInt, in::Ptr{Ptr{CeedScalar}}, out::Ptr{Ptr{CeedScalar}})`,
evaluates `body` once for each Q-point in `1:Q`, and returns `CeedInt(0)`. A C-callable
function pointer to it is created as well.

Inside `body`:
- non-scalar inputs are bound to statically sized `SArray`s,
- scalar inputs are bound to their value at the current Q-point,
- outputs are bound to views into the output buffer, so assignments write the results.

# Arguments
- `ceed`: queried with `iscuda`/`getresource` to decide whether a CUDA kernel is also
  generated.
- `def_module`: module in which the generated functions are evaluated.
- `qf_name`: base name from which the generated function names are derived.
- `constants`: collection of `(name, value)` pairs bound as local variables before `body`.
- `array_names`: names of the input arrays followed by the names of the output arrays.
- `ctx`: `nothing`, or a named tuple `(name=..., type=...)` describing the context struct.
- `dims_in`, `dims_out`: for each input/output array, its per-Q-point dimensions (empty
  for scalar fields).
- `body`: expression evaluated for every Q-point.

# Throws
For CUDA backends, an error is raised in any of these cases:
- the resource is `/gpu/cuda/gen`;
- CUDA.jl is not loaded;
- no valid CUDA installation is found.
"""
function generate_user_qfunction(
    ceed,
    def_module,
    qf_name,
    constants,
    array_names,
    ctx,
    dims_in,
    dims_out,
    body,
)
    # Fresh names for the generated function's own variables, so they cannot collide with
    # names used in `body`.
    idx = gensym(:i)
    Q = gensym(:Q)
    ctx_ptr = gensym(:ctx_ptr)
    in_ptr = gensym(:in_ptr)
    out_ptr = gensym(:out_ptr)

    # NOTE: the generated code is evaluated in `def_module`, which is not guaranteed to have
    # LibCEED's helpers and types in scope (e.g. when LibCEED was loaded with `import` or a
    # selective `using`). They are therefore interpolated as values, not referenced by name.

    const_assignments = Expr[:($(c[1]) = $(c[2])) for c ∈ constants]

    narrays = length(array_names)
    n_in = length(dims_in)
    arrays = Vector{Expr}(undef, narrays)
    array_views = Vector{Expr}(undef, narrays)
    for (i, arr_name) ∈ enumerate(array_names)
        is_input = i <= n_in
        # Index of this array within the input (or output) pointer list.
        i_inout = is_input ? i : i - n_in
        dims = is_input ? dims_in[i] : dims_out[i_inout]
        ptr = is_input ? in_ptr : out_ptr
        arr_name_gen = gensym(arr_name)
        # The whole array has size (Q, dims...), with the Q-point index first.
        arrays[i] =
            :($arr_name_gen = $(extract_array)($ptr, $i_inout, (Int($Q), $(dims...))))
        nd = length(dims)
        # arr[idx, :, ..., :] -- the data belonging to Q-point `idx`.
        slice = Expr(:ref, arr_name_gen, idx, fill(:(:), nd)...)
        if !is_input
            # Outputs are views so that assignments in `body` write to the output buffer.
            array_views[i] = :($arr_name = @view $slice)
        elseif nd == 0
            # Scalar inputs: bind the value directly.
            array_views[i] = :($arr_name = $slice)
        else
            # Non-scalar inputs: statically sized, so operations on them can specialize on
            # the (compile-time) dimensions.
            SA = SArray{Tuple{dims...}}
            array_views[i] = :($arr_name = $SA(@view $slice))
        end
    end

    if isnothing(ctx)
        ctx_assignment = nothing
    else
        ctx_assignment = :($(ctx.name) = $(extract_context)($ctx_ptr, $(ctx.type)))
    end

    qf1 = gensym(qf_name)
    f = Core.eval(
        def_module,
        quote
            @inline function $qf1(
                $ctx_ptr::Ptr{Cvoid},
                $Q::$CeedInt,
                $in_ptr::Ptr{Ptr{$CeedScalar}},
                $out_ptr::Ptr{Ptr{$CeedScalar}},
            )
                $(const_assignments...)
                $ctx_assignment
                $(arrays...)
                @inbounds @simd for $idx = 1:$Q
                    $(array_views...)
                    $body
                end
                $(CeedInt(0))
            end
        end,
    )
    # Build the lowered form of `@cfunction` directly, so that the run-time function object
    # `f` is embedded as a constant and a plain `Ptr{Cvoid}` is returned.
    f_qn = QuoteNode(f)
    rt = :CeedInt
    at = :(Core.svec(Ptr{Cvoid}, CeedInt, Ptr{Ptr{CeedScalar}}, Ptr{Ptr{CeedScalar}}))
    fptr = eval(Expr(:cfunction, Ptr{Cvoid}, f_qn, rt, at, QuoteNode(:ccall)))

    # COV_EXCL_START
    if iscuda(ceed)
        getresource(ceed) == "/gpu/cuda/gen" && error(
            string(
                "/gpu/cuda/gen is not compatible with user Q-functions defined with ",
                "libCEED.jl.\nPlease use a different backend, for example: /gpu/cuda/shared ",
                "or /gpu/cuda/ref",
            ),
        )
        if isdefined(@__MODULE__, :CUDA)
            !has_cuda() && error("No valid CUDA installation found")
            # Kernel variant: receives the per-Q-point arrays directly as arguments.
            qf2 = gensym(qf_name)
            kf = Core.eval(
                def_module,
                quote
                    @inline function $qf2($ctx_ptr::Ptr{Cvoid}, $(array_names...))
                        $(const_assignments...)
                        $ctx_assignment
                        $body
                        nothing
                    end
                end,
            )
            cuf = mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
        else
            error(
                string(
                    "User Q-functions with CUDA backends require the CUDA.jl package to be ",
                    "loaded.\nThe libCEED backend is: $(getresource(ceed))\n",
                    "Please ensure that the CUDA.jl package is installed and loaded.",
                ),
            )
        end
    else
        kf = nothing
        cuf = nothing
    end
    # COV_EXCL_STOP

    return UserQFunction(f, fptr, kf, cuf)
end
135
"""
    meta_user_qfunction(ceed, def_module, qf, args)

Parse the elements of a user Q-function definition and return an expression calling
`generate_user_qfunction` with them.

`args` holds every element of the definition tuple after `ceed`:
- compile-time constants written as `name=value`;
- an optional context declaration written as `name::Type`;
- array specifications of the form `(name, :in|:out, EvalMode, dims...)`;
- as its last element, the body of the Q-function.

User-supplied values (constant values and array dimensions) are escaped, so they are
evaluated in the caller's scope. Input array names precede output array names in the
generated call.

# Throws
- `ErrorException` if an array is tagged with something other than `:in` or `:out`.
- `ErrorException` if an argument is not one of the forms above.
"""
function meta_user_qfunction(ceed, def_module, qf, args)
    qf_name = Meta.quot(qf)

    ctx = nothing
    constants = Expr[]
    dims_in = Expr[]
    dims_out = Expr[]
    names_in = Symbol[]
    names_out = Symbol[]

    # The last element is the body; everything before it is a declaration.
    for a ∈ args[1:end-1]
        if Meta.isexpr(a, :(=))
            # Compile-time constant `name = value`: keep the name, escape the value.
            a1 = Meta.quot(a.args[1])
            a2 = esc(a.args[2])
            push!(constants, :(($a1, $a2)))
        elseif Meta.isexpr(a, :tuple)
            # Array specification `(name, :in|:out, EvalMode, dims...)`. The evaluation mode
            # is not needed here; only the per-Q-point dimensions are recorded.
            arr_name = a.args[1]
            inout = a.args[2].value
            dims_expr = :(Int[$(esc.(a.args[4:end])...)])
            if inout == :in
                push!(dims_in, dims_expr)
                push!(names_in, arr_name)
            elseif inout == :out
                push!(dims_out, dims_expr)
                push!(names_out, arr_name)
            else
                error("Array specification must be either :in or :out. Given $inout.")
            end
        elseif Meta.isexpr(a, :(::))
            # Context declaration `name::Type`.
            ctx = (name=a.args[1], type=a.args[2])
        else
            error("Bad argument to @user_qfunction")
        end
    end

    body = Meta.quot(args[end])

    return :(generate_user_qfunction(
        $ceed,
        $def_module,
        $qf_name,
        [$(constants...)],
        $([names_in; names_out]),
        $ctx,
        [$(dims_in...)],
        [$(dims_out...)],
        $body,
    ))
end
190
"""
    @interior_qf name=def

Creates a user-defined interior (volumetric) Q-function, and assigns it to a variable named
`name`. The definition of the Q-function is given as:
```
@interior_qf user_qf=(
    ceed::CEED,
    [const1=val1, const2=val2, ...],
    [ctx::ContextType],
    (I1, :in, EvalMode, dims...),
    (I2, :in, EvalMode, dims...),
    (O1, :out, EvalMode, dims...),
    body
)
```
The definitions of form `const=val` are used for definitions which will be compile-time
constants in the Q-function. For example, if `dim` is a variable set to the dimension of the
problem, then `dim=dim` will make `dim` available in the body of the Q-function as a
compile-time constant.

If the user wants to provide a context struct to the Q-function, that can be achieved by
optionally including `ctx::ContextType`, where `ContextType` is the type of the context
struct, and `ctx` is the name to which it will be bound in the body of the Q-function.

This is followed by the definition of the input and output arrays, which take the form
`(arr_name, (:in|:out), EvalMode, dims...)`. Each array will be bound to a variable named
`arr_name`. Input arrays should be tagged with :in, and output arrays with :out. An
`EvalMode` should be specified, followed by the dimensions of the array. If the array
consists of scalars (one number per Q-point) then `dims` should be omitted.

Each array is also registered as a field of the new Q-function with `add_input!` or
`add_output!`. The field uses the array's name, its `EvalMode`, and a size equal to the
product of `dims` (1 for scalar arrays).

# Examples

- Q-function to compute the "Q-data" for the mass operator, which is given by the quadrature
  weight times the Jacobian determinant. The mesh Jacobian (the gradient of the nodal mesh
  points) and the quadrature weights are given as input arrays, and the Q-data is the output
  array. `dim` is given as a compile-time constant, and so the array `J` is statically
  sized, and therefore `det(J)` will automatically dispatch to an optimized implementation
  for the given dimension.
```
@interior_qf build_qfunc = (
    ceed, dim=dim,
    (J, :in, EVAL_GRAD, dim, dim),
    (w, :in, EVAL_WEIGHT),
    (qdata, :out, EVAL_NONE),
    qdata[] = w*det(J)
)
```
"""
macro interior_qf(args)
    # Expect `qf = (ceed, ..., body)`: an assignment whose right-hand side is a tuple that
    # contains at least the `ceed` object and the body. Checking the tuple shape here gives
    # a clear error instead of an opaque failure further down.
    valid =
        Meta.isexpr(args, :(=)) &&
        Meta.isexpr(args.args[2], :tuple) &&
        length(args.args[2].args) >= 2
    if !valid
        error("@interior_qf must be of form `qf = (body)`") # COV_EXCL_LINE
    end

    qf = args.args[1]
    user_qf = esc(qf)
    args = args.args[2].args
    ceed = esc(args[1])

    # Register every array specification `(name, :in|:out, EvalMode, dims...)` as an
    # input/output field whose size is the product of its dimensions (1 for scalars).
    fields_in = Expr[]
    fields_out = Expr[]
    for a ∈ args
        Meta.isexpr(a, :tuple) || continue
        field_name = String(a.args[1])
        inout = a.args[2].value
        evalmode = a.args[3]
        dims = [esc(:(Int($d))) for d ∈ a.args[4:end]]
        sz_expr = :(prod(($(dims...),)))
        if inout == :in
            push!(fields_in, :(add_input!($user_qf, $field_name, $sz_expr, $evalmode)))
        elseif inout == :out
            push!(fields_out, :(add_output!($user_qf, $field_name, $sz_expr, $evalmode)))
        end
    end

    gen_user_qf = meta_user_qfunction(ceed, __module__, qf, args[2:end])

    return quote
        $user_qf = create_interior_qfunction($ceed, $gen_user_qf)
        $(fields_in...)
        $(fields_out...)
    end
end
283