xref: /libCEED/julia/LibCEED.jl/src/UserQFunction.jl (revision 97769a3d2c45e9067c16da7ef18281dfae6857b7)
"""
    UserQFunction{F,K}

Container for the compiled pieces of a user-defined Q-function:

- `f`: the generated Julia function,
- `fptr`: a C-callable function pointer to `f`,
- `kf`: the GPU-side function (or `nothing` when none was generated),
- `cuf`: a pointer used for the CUDA backend (or `nothing` when none was generated).
"""
struct UserQFunction{F,K}
    f::F                            # generated Julia Q-function
    fptr::Ptr{Cvoid}                # C function pointer to `f`
    kf::K                           # GPU-side function, or `nothing`
    cuf::Union{Ptr{Cvoid},Nothing}  # CUDA function pointer, or `nothing`
end
744554ea0SWill Pazner
# Reinterpret the untyped pointer `ptr` as pointing to a value of type `T` and load it.
@inline extract_context(ptr, ::Type{T}) where {T} = unsafe_load(Ptr{T}(ptr))
1144554ea0SWill Pazner
# Load the `idx`-th pointer from the pointer table `ptr`, reinterpret it as a
# `Ptr{CeedScalar}`, and wrap it as an `UnsafeArray` of size `dims`.
@inline function extract_array(ptr, idx, dims)
    raw = unsafe_load(ptr, idx)
    data = Ptr{CeedScalar}(raw)
    return UnsafeArray(data, dims)
end
1544554ea0SWill Pazner
# Generate the Julia implementation(s) of a user Q-function and return them bundled in a
# `UserQFunction`.
#
# Arguments (as supplied by the call expression that `meta_user_qfunction` builds):
#   ceed        - the CEED object; decides between the CPU-only and CUDA code paths
#   def_module  - module into which the generated functions are `Core.eval`-ed
#   qf_name     - base name passed to `gensym` for the generated function names
#   constants   - collection of `(name, value)` pairs; each becomes a `name = value`
#                 statement at the top of the generated function(s)
#   array_names - array names: the inputs first (one per entry of `dims_in`), then the
#                 outputs (one per entry of `dims_out`)
#   ctx         - `nothing`, or an object with `name` and `type` fields describing the
#                 context-struct binding
#   dims_in     - for each input array, its per-point dimensions (splatted into the array
#                 size and into the `SArray` size parameter, so these should be integers)
#   dims_out    - for each output array, its per-point dimensions
#   body        - user code for one quadrature point; on the CPU path it is spliced into
#                 the loop over points
#
# Returns `UserQFunction(f, fptr, kf, cuf)`: `f` is the generated CPU function, `fptr` a
# C function pointer to `f`, and `kf`/`cuf` the GPU-side function and the value returned
# by `mk_cufunction` (both `nothing` for non-CUDA backends).
#
# Errors if the backend is `/gpu/cuda/gen`, if a CUDA backend is used without CUDA.jl
# loaded, or if CUDA.jl is loaded but `has_cuda()` is false.
function generate_user_qfunction(
    ceed,
    def_module,
    qf_name,
    constants,
    array_names,
    ctx,
    dims_in,
    dims_out,
    body,
)
    # Fresh names for the generated function's arguments and loop index, so they cannot
    # clash with names used in `body`.
    idx = gensym(:i)
    Q = gensym(:Q)
    ctx_ptr = gensym(:ctx_ptr)
    in_ptr = gensym(:in_ptr)
    out_ptr = gensym(:out_ptr)

    # `name = value` statements for the user-supplied compile-time constants.
    const_assignments = Vector{Expr}(undef, length(constants))
    for (i, c) ∈ enumerate(constants)
        const_assignments[i] = :($(c[1]) = $(c[2]))
    end

    # For each array build two statements:
    #   arrays[i]      - wraps the raw pointer as an array of size (Q, dims...)
    #   array_views[i] - binds the user's name to the data at quadrature point `idx`
    narrays = length(array_names)
    arrays = Vector{Expr}(undef, narrays)
    array_views = Vector{Expr}(undef, narrays)
    n_in = length(dims_in)
    for (i, arr_name) ∈ enumerate(array_names)
        # Position within the input or the output pointer table, respectively.
        i_inout = (i <= n_in) ? i : i - n_in
        dims = (i <= n_in) ? dims_in[i] : dims_out[i-n_in]
        ptr = (i <= n_in) ? in_ptr : out_ptr
        arr_name_gen = gensym(arr_name)
        arrays[i] = :($arr_name_gen = extract_array($ptr, $i_inout, (Int($Q), $(dims...))))
        ndims = length(dims)
        # `arr[idx, :, ..., :]` with one `:` per per-point dimension.
        slice = Expr(:ref, arr_name_gen, idx, (:(:) for i = 1:ndims)...)
        if i <= n_in
            if ndims == 0
                # Scalar input: bind the value at this point directly.
                array_views[i] = :($arr_name = $slice)
            else
                # Non-scalar input: copy the slice into an `SArray` whose static size is
                # given by `dims`.
                S = Tuple{dims...}
                array_views[i] = :($arr_name = LibCEED.SArray{$S}(@view $slice))
            end
        else
            # Outputs are views, so in-place assignments in `body` write to the output.
            array_views[i] = :($arr_name = @view $slice)
        end
    end

    # Load the context struct from the context pointer only if one was requested.
    if isnothing(ctx)
        ctx_assignment = nothing
    else
        ctx_assignment = :($(ctx.name) = extract_context($ctx_ptr, $(ctx.type)))
    end

    # CPU implementation: loops over the Q quadrature points and returns CeedInt(0). Its
    # argument types match the `at` signature used for the C function pointer below.
    qf1 = gensym(qf_name)
    f = Core.eval(
        def_module,
        quote
            @inline function $qf1(
                $ctx_ptr::Ptr{Cvoid},
                $Q::CeedInt,
                $in_ptr::Ptr{Ptr{CeedScalar}},
                $out_ptr::Ptr{Ptr{CeedScalar}},
            )
                $(const_assignments...)
                $ctx_assignment
                $(arrays...)
                @inbounds @simd for $idx = 1:$Q
                    $(array_views...)
                    $body
                end
                CeedInt(0)
            end
        end,
    )
    # Build the expression form that `@cfunction` expands to, embedding the already
    # created function object `f` via a QuoteNode.
    f_qn = QuoteNode(f)
    rt = :CeedInt
    at = :(Core.svec(Ptr{Cvoid}, CeedInt, Ptr{Ptr{CeedScalar}}, Ptr{Ptr{CeedScalar}}))
    fptr = eval(Expr(:cfunction, Ptr{Cvoid}, f_qn, rt, at, QuoteNode(:ccall)))

    # COV_EXCL_START
    if iscuda(ceed)
        getresource(ceed) == "/gpu/cuda/gen" && error(
            string(
                "/gpu/cuda/gen is not compatible with user Q-functions defined with ",
                "libCEED.jl.\nPlease use a different backend, for example: /gpu/cuda/shared ",
                "or /gpu/cuda/ref",
            ),
        )
        if cuda_is_loaded
            !has_cuda() && error("No valid CUDA installation found")
            # GPU-side variant: receives the arrays directly as arguments instead of
            # pointer tables, and returns `nothing`.
            qf2 = gensym(qf_name)
            kf = Core.eval(
                def_module,
                quote
                    @inline function $qf2($ctx_ptr::Ptr{Cvoid}, $(array_names...))
                        $(const_assignments...)
                        $ctx_assignment
                        $body
                        nothing
                    end
                end,
            )
            cuf = mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
        else
            error(
                string(
                    "User Q-functions with CUDA backends require the CUDA.jl package to be ",
                    "loaded.\nThe libCEED backend is: $(getresource(ceed))\n",
                    "Please ensure that the CUDA.jl package is installed and loaded.",
                ),
            )
        end
    else
        kf = nothing
        cuf = nothing
    end
    # COV_EXCL_STOP

    UserQFunction(f, fptr, kf, cuf)
end
13544554ea0SWill Pazner
# Build the expression that calls `generate_user_qfunction` for a user Q-function macro.
#
# `ceed` and `def_module` are interpolated as-is into the call; `qf` is the user's name for
# the Q-function (passed quoted). `args` lists the macro entries after `ceed`: every entry
# except the last must have one of these forms
#
#   name=value                          compile-time constant; `value` is escaped, so it is
#                                       evaluated in the caller's scope
#   (name, :in|:out, EvalMode, dims...) input/output array; `dims` are escaped and collected
#                                       into an `Int[...]` vector (EvalMode is not used here)
#   name::Type                          context struct, bound to `name` in the body
#
# and the last entry is the Q-function body, passed along quoted. Array names are passed
# inputs first, then outputs. Any other entry form raises an error.
function meta_user_qfunction(ceed, def_module, qf, args)
    qf_name = Meta.quot(qf)

    ctx = nothing
    constants = Expr[]
    dims_in = Expr[]
    dims_out = Expr[]
    names_in = Symbol[]
    names_out = Symbol[]

    for a ∈ args[1:end-1]
        if Meta.isexpr(a, :(=))
            name = Meta.quot(a.args[1])
            value = esc(a.args[2])
            push!(constants, :(($name, $value)))
        elseif Meta.isexpr(a, :tuple)
            arr_name = a.args[1]
            inout_spec = a.args[2]
            # Quoted tags such as `:in` parse as QuoteNodes; reject anything else (e.g. a
            # bare identifier) with a clear message rather than a field-access error.
            if !(inout_spec isa QuoteNode)
                error("Array specification must be either :in or :out. Given $inout_spec.")
            end
            inout = inout_spec.value
            dims_expr = :(Int[$(esc.(a.args[4:end])...)])
            if inout == :in
                push!(dims_in, dims_expr)
                push!(names_in, arr_name)
            elseif inout == :out
                push!(dims_out, dims_expr)
                push!(names_out, arr_name)
            else
                error("Array specification must be either :in or :out. Given $inout.")
            end
        elseif Meta.isexpr(a, :(::))
            ctx = (name=a.args[1], type=a.args[2])
        else
            error("Bad argument to @user_qfunction")
        end
    end

    body = Meta.quot(args[end])

    return :(generate_user_qfunction(
        $ceed,
        $def_module,
        $qf_name,
        [$(constants...)],
        $([names_in; names_out]),
        $ctx,
        [$(dims_in...)],
        [$(dims_out...)],
        $body,
    ))
end
19044554ea0SWill Pazner
"""
    @interior_qf name=def

Create a user-defined interior (volumetric) Q-function and assign it to a variable named
`name`. The definition of the Q-function is given as:
```
@interior_qf user_qf=(
    ceed::CEED,
    [const1=val1, const2=val2, ...],
    [ctx::ContextType],
    (I1, :in, EvalMode, dims...),
    (I2, :in, EvalMode, dims...),
    (O1, :out, EvalMode, dims...),
    body
)
```
Definitions of the form `const=val` introduce compile-time constants in the Q-function.
For example, if `dim` is a variable set to the dimension of the problem, then `dim=dim`
will make `dim` available in the body of the Q-function as a compile-time constant.

If the user wants to provide a context struct to the Q-function, that can be achieved by
optionally including `ctx::ContextType`, where `ContextType` is the type of the context
struct, and `ctx` is the name to which it will be bound in the body of the Q-function.

This is followed by the definition of the input and output arrays, which take the form
`(arr_name, (:in|:out), EvalMode, dims...)`. Each array will be bound to a variable named
`arr_name`. Input arrays should be tagged with `:in`, and output arrays with `:out`. An
`EvalMode` should be specified, followed by the dimensions of the array. If the array
consists of scalars (one number per Q-point) then `dims` should be omitted.

The final entry, `body`, is the code evaluated for a single quadrature point. Output
arrays must be written in place (e.g. `qdata[] = ...` or `v .= ...`); rebinding an output
name does not write to the output.

# Examples

- Q-function to compute the "Q-data" for the mass operator, which is given by the quadrature
  weight times the Jacobian determinant. The mesh Jacobian (the gradient of the nodal mesh
  points) and the quadrature weights are given as input arrays, and the Q-data is the output
  array. `dim` is given as a compile-time constant, and so the array `J` is statically
  sized, and therefore `det(J)` will automatically dispatch to an optimized implementation
  for the given dimension.
```
@interior_qf build_qfunc = (
    ceed, dim=dim,
    (J, :in, EVAL_GRAD, dim, dim),
    (w, :in, EVAL_WEIGHT),
    (qdata, :out, EVAL_NONE),
    qdata[] = w*det(J)
)
```
"""
macro interior_qf(args)
    # The argument must be an assignment `name = (ceed, entries..., body)`.
    if !Meta.isexpr(args, :(=)) ||
       !Meta.isexpr(args.args[2], :tuple) ||
       length(args.args[2].args) < 2
        error("@interior_qf must be of form `qf = (body)`") # COV_EXCL_LINE
    end

    qf = args.args[1]
    user_qf = esc(qf)
    entries = args.args[2].args
    ceed = esc(entries[1])

    # Register every array field with its total per-point size: the product of its dims,
    # which is 1 for scalar fields (`prod(())`).
    fields_in = Expr[]
    fields_out = Expr[]
    for a ∈ entries
        Meta.isexpr(a, :tuple) || continue
        if length(a.args) < 3
            error(
                "Array specification $a must be of the form ",
                "(name, :in|:out, EvalMode, dims...)",
            )
        end
        spec = a.args[2]
        # Only quoted tags register fields here; any other tag is rejected with an error
        # by `meta_user_qfunction` below.
        spec isa QuoteNode || continue
        inout = spec.value
        field_name = String(a.args[1])
        # The EvalMode is interpolated without `esc`, so (macro hygiene) a name such as
        # `EVAL_GRAD` is resolved in the module that defines this macro.
        evalmode = a.args[3]
        dims = Expr[esc(:(Int($d))) for d ∈ a.args[4:end]]
        sz_expr = :(prod(($(dims...),)))
        if inout == :in
            push!(fields_in, :(add_input!($user_qf, $field_name, $sz_expr, $evalmode)))
        elseif inout == :out
            push!(fields_out, :(add_output!($user_qf, $field_name, $sz_expr, $evalmode)))
        end
    end

    gen_user_qf = meta_user_qfunction(ceed, __module__, qf, entries[2:end])

    return quote
        $user_qf = create_interior_qfunction($ceed, $gen_user_qf)
        $(fields_in...)
        $(fields_out...)
    end
end
283