xref: /libCEED/julia/LibCEED.jl/src/UserQFunction.jl (revision 3a739e1923fc978b8248ffcb40b5d2f55443c9d9)
"""
    UserQFunction{F,K}

Bundle of the artifacts produced for one user-defined Q-function.

# Fields
- `f`: the Julia function implementing the Q-function.
- `fptr`: raw C-callable pointer associated with `f`.
- `kf`: kernel function used with CUDA backends, or `nothing` otherwise.
- `cuf`: raw pointer to the CUDA function object, or `nothing` when CUDA is not in use.
"""
struct UserQFunction{F,K}
    f::F
    fptr::Ptr{Cvoid}
    kf::K
    cuf::Union{Ptr{Cvoid},Nothing}
end
7
"""
    extract_context(ptr, ::Type{T})

Reinterpret `ptr` as a `Ptr{T}` and load the value of type `T` stored at that address.

The caller is responsible for `ptr` pointing at valid memory holding a `T`; no checks are
performed.
"""
@inline function extract_context(ptr, ::Type{T}) where {T}
    typed_ptr = Ptr{T}(ptr)
    return unsafe_load(typed_ptr)
end
11
"""
    extract_array(ptr, idx, dims)

Load the `idx`-th pointer stored at `ptr`, reinterpret it as a `Ptr{CeedScalar}`, and wrap
it in an `UnsafeArray` with dimensions `dims`.

No validity or bounds checks are performed on `ptr` or `idx`.
"""
@inline function extract_array(ptr, idx, dims)
    data_ptr = Ptr{CeedScalar}(unsafe_load(ptr, idx))
    return UnsafeArray(data_ptr, dims)
end
15
"""
    generate_user_qfunction(ceed, def_module, qf_name, constants, array_names, ctx,
                            dims_in, dims_out, body)

Define, in `def_module`, the function implementing a user Q-function and return a
`UserQFunction` holding it together with its C-callable pointer.

# Arguments
- `ceed`: object queried with `iscuda` and `getresource` to select the code path.
- `def_module`: module in which the generated functions are evaluated.
- `qf_name`: symbol used as the base of the generated function names.
- `constants`: collection of `(name, value)` pairs; each becomes a `name = value` statement
  at the top of the generated function.
- `array_names`: names of the arrays visible in `body`; the first `length(dims_in)` entries
  are treated as inputs, the rest as outputs.
- `ctx`: `nothing`, or an object with `name` and `type` properties describing the context.
- `dims_in`, `dims_out`: per-array dimensions (excluding the leading Q-point dimension).
- `body`: expression executed once per Q-point.

When `iscuda(ceed)` is true a second kernel function is generated as well and passed to
`mk_cufunction`; otherwise the `kf` and `cuf` fields of the result are `nothing`.

Throws an error for the `/gpu/cuda/gen` resource, when a CUDA backend is requested without
CUDA.jl being loaded, or when no valid CUDA installation is found.
"""
function generate_user_qfunction(
    ceed,
    def_module,
    qf_name,
    constants,
    array_names,
    ctx,
    dims_in,
    dims_out,
    body,
)
    # Fresh names for the loop index and the arguments of the generated function, so that
    # they cannot clash with names used in `body`.
    q_idx = gensym(:i)
    n_q = gensym(:Q)
    ctx_arg = gensym(:ctx_ptr)
    inputs_arg = gensym(:in_ptr)
    outputs_arg = gensym(:out_ptr)

    # One `name = value` statement per entry of `constants`.
    const_assignments = Expr[:($(c[1]) = $(c[2])) for c ∈ constants]

    n_in = length(dims_in)
    arrays = Expr[]       # statements wrapping each raw field pointer in an array
    array_views = Expr[]  # per-Q-point bindings of the user-visible array names
    for (i, arr_name) ∈ enumerate(array_names)
        is_input = i <= n_in
        if is_input
            field_idx, dims, fields_arg = i, dims_in[i], inputs_arg
        else
            field_idx = i - n_in
            dims = dims_out[field_idx]
            fields_arg = outputs_arg
        end

        buffer = gensym(arr_name)
        full_dims = :((Int($n_q), $(dims...)))
        push!(arrays, :($buffer = extract_array($fields_arg, $field_idx, $full_dims)))

        # `buffer[q_idx, :, :, ...]` with one colon per trailing dimension.
        rank = length(dims)
        point = Expr(:ref, buffer, q_idx, fill(:(:), rank)...)
        point_view = :(@view $point)
        if !is_input
            # Outputs are bound to a view.
            push!(array_views, :($arr_name = $point_view))
        elseif rank == 0
            # Scalar inputs are bound to the indexed value directly.
            push!(array_views, :($arr_name = $point))
        else
            # Non-scalar inputs are converted to a statically sized array.
            S = Tuple{dims...}
            push!(array_views, :($arr_name = LibCEED.SArray{$S}($point_view)))
        end
    end

    ctx_assignment =
        isnothing(ctx) ? nothing : :($(ctx.name) = extract_context($ctx_arg, $(ctx.type)))

    host_name = gensym(qf_name)
    f = Core.eval(
        def_module,
        quote
            @inline function $host_name(
                $ctx_arg::Ptr{Cvoid},
                $n_q::CeedInt,
                $inputs_arg::Ptr{Ptr{CeedScalar}},
                $outputs_arg::Ptr{Ptr{CeedScalar}},
            )
                $(const_assignments...)
                $ctx_assignment
                $(arrays...)
                @inbounds @simd for $q_idx = 1:$n_q
                    $(array_views...)
                    $body
                end
                CeedInt(0)
            end
        end,
    )

    # Build the `:cfunction` expression by hand and evaluate it to get a pointer for `f`.
    cfunction_expr = Expr(
        :cfunction,
        Ptr{Cvoid},
        QuoteNode(f),
        :CeedInt,
        :(Core.svec(Ptr{Cvoid}, CeedInt, Ptr{Ptr{CeedScalar}}, Ptr{Ptr{CeedScalar}})),
        QuoteNode(:ccall),
    )
    fptr = eval(cfunction_expr)

    # COV_EXCL_START
    if !iscuda(ceed)
        return UserQFunction(f, fptr, nothing, nothing)
    end

    if getresource(ceed) == "/gpu/cuda/gen"
        error(string(
            "/gpu/cuda/gen is not compatible with user Q-functions defined with ",
            "libCEED.jl.\nPlease use a different backend, for example: /gpu/cuda/shared ",
            "or /gpu/cuda/ref",
        ))
    end

    if !cuda_is_loaded
        error(string(
            "User Q-functions with CUDA backends require the CUDA.jl package to be ",
            "loaded.\nThe libCEED backend is: $(getresource(ceed))\n",
            "Please ensure that the CUDA.jl package is installed and loaded.",
        ))
    end

    has_cuda() || error("No valid CUDA installation found")

    # Kernel variant: receives the arrays directly as arguments and runs `body` once.
    kernel_name = gensym(qf_name)
    kf = Core.eval(
        def_module,
        quote
            @inline function $kernel_name($ctx_arg::Ptr{Cvoid}, $(array_names...))
                $(const_assignments...)
                $ctx_assignment
                $body
                nothing
            end
        end,
    )
    cuf = mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    # COV_EXCL_STOP

    return UserQFunction(f, fptr, kf, cuf)
end
131
"""
    meta_user_qfunction(ceed, def_module, qf, args)

Build, without evaluating it, the expression that calls `generate_user_qfunction` for the
Q-function named `qf`.

`args` is a list of argument expressions followed by the body expression (its last
element). Each argument expression is one of:
- `name = value`: a constant; `value` is escaped so it is evaluated in the caller's scope.
- `(name, :in | :out, evalmode, dims...)`: an input or output array; the `dims` are
  escaped and collected into an `Int[...]` expression.
- `name::Type`: the context; stored as the named tuple `(name=..., type=...)`.

`ceed` and `def_module` are spliced into the returned call unchanged.

Throws an error for any other kind of argument, for an array specification with fewer than
three entries or a non-quoted in/out tag, and for a tag other than `:in` or `:out`.
"""
function meta_user_qfunction(ceed, def_module, qf, args)
    qf_name = Meta.quot(qf)

    ctx = nothing
    constants = Expr[]
    dims_in = Expr[]
    dims_out = Expr[]
    names_in = Symbol[]
    names_out = Symbol[]

    for a ∈ args[1:end-1]
        if Meta.isexpr(a, :(=))
            # Constant: pass the name quoted and the value evaluated in caller scope.
            a1 = Meta.quot(a.args[1])
            a2 = esc(a.args[2])
            push!(constants, :(($a1, $a2)))
        elseif Meta.isexpr(a, :tuple)
            # Validate the shape up front so malformed specifications fail with a clear
            # message instead of an obscure indexing or field-access error.
            if length(a.args) < 3 || !(a.args[2] isa QuoteNode)
                error(string(
                    "Array specification must be of the form ",
                    "(name, :in|:out, EvalMode, dims...). Given $a.",
                ))
            end
            arr_name = a.args[1]
            inout = a.args[2].value
            # Everything after the eval mode is a dimension (none for scalar fields).
            dims_expr = :(Int[$(esc.(a.args[4:end])...)])
            if inout == :in
                push!(dims_in, dims_expr)
                push!(names_in, arr_name)
            elseif inout == :out
                push!(dims_out, dims_expr)
                push!(names_out, arr_name)
            else
                error("Array specification must be either :in or :out. Given $inout.")
            end
        elseif Meta.isexpr(a, :(::))
            ctx = (name=a.args[1], type=a.args[2])
        else
            error("Bad argument to @user_qfunction")
        end
    end

    # The body is quoted so that it reaches `generate_user_qfunction` as an expression.
    body = Meta.quot(args[end])

    return :(generate_user_qfunction(
        $ceed,
        $def_module,
        $qf_name,
        [$(constants...)],
        $([names_in; names_out]),
        $ctx,
        [$(dims_in...)],
        [$(dims_out...)],
        $body,
    ))
end
186
187"""
188    @interior_qf name=def
189
190Creates a user-defined interior (volumetric) Q-function, and assigns it to a variable named
191`name`. The definition of the Q-function is given as:
192```
193@interior_qf user_qf=(
194    ceed::CEED,
195    [const1=val1, const2=val2, ...],
196    [ctx::ContextType],
197    (I1, :in, EvalMode, dims...),
198    (I2, :in, EvalMode, dims...),
199    (O1, :out, EvalMode, dims...),
200    body
201)
202```
203The definitions of form `const=val` are used for definitions which will be compile-time
204constants in the Q-function. For example, if `dim` is a variable set to the dimension of the
205problem, then `dim=dim` will make `dim` available in the body of the Q-function as a
206compile-time constant.
207
If the user wants to provide a context struct to the Q-function, that can be achieved by
optionally including `ctx::ContextType`, where `ContextType` is the type of the context
struct, and `ctx` is the name to which it will be bound in the body of the Q-function.
211
212This is followed by the definition of the input and output arrays, which take the form
213`(arr_name, (:in|:out), EvalMode, dims...)`. Each array will be bound to a variable named
214`arr_name`. Input arrays should be tagged with :in, and output arrays with :out. An
215`EvalMode` should be specified, followed by the dimensions of the array. If the array
216consists of scalars (one number per Q-point) then `dims` should be omitted.
217
218# Examples
219
220- Q-function to compute the "Q-data" for the mass operator, which is given by the quadrature
221  weight times the Jacobian determinant. The mesh Jacobian (the gradient of the nodal mesh
222  points) and the quadrature weights are given as input arrays, and the Q-data is the output
223  array. `dim` is given as a compile-time constant, and so the array `J` is statically
224  sized, and therefore `det(J)` will automatically dispatch to an optimized implementation
225  for the given dimension.
226```
227@interior_qf build_qfunc = (
228    ceed, dim=dim,
229    (J, :in, EVAL_GRAD, dim, dim),
230    (w, :in, EVAL_WEIGHT),
231    (qdata, :out, EVAL_NONE),
232    qdata[] = w*det(J)
233)
234```
235"""
macro interior_qf(args)
    Meta.isexpr(args, :(=)) ||
        error("@interior_qf must be of form `qf = (body)`") # COV_EXCL_LINE

    # Left-hand side: the variable the Q-function is assigned to (in caller scope).
    qf = args.args[1]
    user_qf = esc(qf)
    # Right-hand side: the argument list; its first entry is the `ceed` object.
    qf_args = args.args[2].args
    ceed = esc(qf_args[1])

    # Build one `add_input!`/`add_output!` call per array specification. The field size
    # is the product of the dimensions given after the eval mode.
    add_inputs = Expr[]
    add_outputs = Expr[]
    for spec ∈ qf_args
        Meta.isexpr(spec, :tuple) || continue
        field_name = String(spec.args[1])
        inout = spec.args[2].value
        evalmode = spec.args[3]
        dims = Expr[esc(:(Int($d))) for d ∈ spec.args[4:end]]
        sz_expr = :(prod(($(dims...),)))
        if inout == :in
            push!(add_inputs, :(add_input!($user_qf, $field_name, $sz_expr, $evalmode)))
        elseif inout == :out
            push!(add_outputs, :(add_output!($user_qf, $field_name, $sz_expr, $evalmode)))
        end
    end

    # Everything after `ceed` (constants, context, arrays, body) describes the function.
    gen_user_qf = meta_user_qfunction(ceed, __module__, qf, qf_args[2:end])

    return quote
        $user_qf = create_interior_qfunction($ceed, $gen_user_qf)
        $(add_inputs...)
        $(add_outputs...)
    end
end
279