xref: /libCEED/julia/LibCEED.jl/src/Cuda.jl (revision 60801d19602b94955220fc3cc63a65b52bc34d1e)
# COV_EXCL_START
244554ea0SWill Paznerusing .CUDA, Cassette
344554ea0SWill Pazner
# Names of the one-argument math functions that are redirected from `Base` to
# `CUDA` inside `CeedCudaContext`. Each entry is used both as `Base.<name>` and
# as `CUDA.<name>`, so it has to exist in both modules.
#! format: off
const cudafuns = (
    # trigonometric
    :cos, :cospi, :sin, :sinpi, :tan,
    # inverse trigonometric
    :acos, :asin, :atan,
    # hyperbolic
    :cosh, :sinh, :tanh,
    # inverse hyperbolic
    :acosh, :asinh, :atanh,
    # logarithms
    :log, :log10, :log1p, :log2,
    # exponentials
    :exp, :exp2, :exp10, :expm1, :ldexp,
    # magnitude
    :abs,
    # roots
    :sqrt, :cbrt,
    # rounding
    :ceil, :floor,
)
#! format: on
1744554ea0SWill Pazner
# Cassette context type for running Q-functions on CUDA devices. The
# `Cassette.overdub` methods in this file dispatch on it.
Cassette.@context CeedCudaContext
1944554ea0SWill Pazner
# Treat `Core.kwfunc` as a primitive in `CeedCudaContext`: call it directly.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) = Core.kwfunc(f)
# Treat `Core.apply_type` as a primitive in `CeedCudaContext`: forward all
# arguments to it unchanged.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) =
    Core.apply_type(args...)
# Treat `StaticArrays.Size` applied to an array *type* as a primitive in
# `CeedCudaContext`: call it directly on the type.
@inline Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N} = StaticArrays.Size(x)
3344554ea0SWill Pazner
# For every name in `cudafuns`, define an `overdub` method so that, inside
# `CeedCudaContext`, a call `Base.<name>(x)` with a `Float32` or `Float64`
# argument is replaced by `CUDA.<name>(x)`.
# NOTE(review): `Base.ldexp` takes two arguments, so the one-argument method
# generated for `:ldexp` presumably never matches a real call -- confirm.
for fname in cudafuns
    @eval @inline Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$fname),
        x::Union{Float32,Float64},
    ) = CUDA.$fname(x)
end
4344554ea0SWill Pazner
# Hand the memory of the CUDA array `arr` to the libCEED vector `v`.
#
# The address returned by `pointer(arr)` is converted to an integer and then to
# a `Ptr{CeedScalar}`, which is what `C.CeedVectorSetArray` is called with.
# When `cmode` is `USE_POINTER`, `arr` is also stored in `v.arr` (presumably so
# the array stays alive while libCEED uses its memory -- confirm).
#
# Returns `arr` when `cmode == USE_POINTER` and `nothing` otherwise.
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    device_addr = UInt64(pointer(arr))
    C.CeedVectorSetArray(v[], mtype, cmode, Ptr{CeedScalar}(device_addr))
    return cmode == USE_POINTER ? (v.arr = arr) : nothing
end
5144554ea0SWill Pazner
# Kernel argument bundling the field addresses of a Q-function. Each entry is
# an integer that the generated kernel reinterprets as a device pointer to
# `CeedScalar` data.
# NOTE(review): the fixed length 16 presumably mirrors libCEED's maximum number
# of fields per Q-function -- confirm.
struct FieldsCuda
    # Addresses of the input field data, indexed by input field number
    inputs::NTuple{16,Int}
    # Addresses of the output field data, indexed by output field number
    outputs::NTuple{16,Int}
end
5644554ea0SWill Pazner
# Build the expression for a CUDA kernel that applies the Q-function `kf` at
# every quadrature point. The expression is returned unevaluated.
#
# Arguments:
#   qf_name  - base name; the kernel function is named `gensym(qf_name)`.
#   kf       - Q-function; it is interpolated into the generated code and
#              invoked through `Cassette.overdub` with `ctx_ptr`, then one
#              argument per input field, then one per output field.
#   dims_in  - one entry per input field giving that field's dimensions at a
#              single quadrature point; an empty entry means a scalar field.
#   dims_out - the same for the output fields.
#
# The generated function takes `(ctx_ptr, Q, fields)`. For quadrature point `q`
# and component `j` of field `i`, it reads from `fields.inputs[i]` and writes to
# `fields.outputs[i]` at the 1-based offset `q + (j - 1)*Q`.
function generate_kernel(qf_name, kf, dims_in, dims_out)
    nin = length(dims_in)
    nout = length(dims_out)

    # Number of scalars per quadrature point in each field
    ncomp_in = prod.(dims_in)
    ncomp_out = prod.(dims_out)

    # Names of the per-thread locals holding one point's worth of each field
    in_syms = map(i -> Symbol("rqi", i), 1:nin)
    out_syms = map(i -> Symbol("rqo", i), 1:nout)

    in_decls = Expr[]                    # declarations of the input locals
    in_elems = Union{Symbol,Expr}[]      # assignment target for component `j`
    call_args = Union{Symbol,Expr}[]     # arguments passed to `kf`
    for i = 1:nin
        sym = in_syms[i]
        dims = dims_in[i]
        if isempty(dims)
            # Scalar field: a plain local that is passed to `kf` directly
            push!(in_decls, :(local $sym))
            push!(in_elems, sym)
            push!(call_args, sym)
        else
            # Array field: filled through an `MArray`, passed on as an `SArray`
            buf_type = :(LibCEED.MArray{Tuple{$(dims...)},CeedScalar})
            arg_type = :(LibCEED.SArray{Tuple{$(dims...)},CeedScalar})
            push!(in_decls, :($sym = $buf_type(undef)))
            push!(in_elems, :($(sym)[j]))
            push!(call_args, :($arg_type($sym)))
        end
    end
    # Output locals follow the inputs in the call to `kf`
    append!(call_args, out_syms)

    out_decls = map(1:nout) do i
        :($(out_syms[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},CeedScalar}(undef))
    end

    dev_ptr = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # One loop per input field, copying point `q` from device memory into the local
    loads = map(1:nin) do i
        :(
            for j = 1:$(ncomp_in[i])
                $(in_elems[i]) = unsafe_load(
                    reinterpret($dev_ptr, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    a,
                )
            end
        )
    end

    # One loop per output field, copying the local back to device memory
    stores = map(1:nout) do i
        :(
            for j = 1:$(ncomp_out[i])
                unsafe_store!(
                    reinterpret($dev_ptr, fields.outputs[$i]),
                    $(out_syms[i])[j],
                    q + (j - 1)*Q,
                    a,
                )
            end
        )
    end

    kernel_name = gensym(qf_name)
    # Computed here, at generation time, and spliced in as a literal
    align = Base.datatype_alignment(CeedScalar)

    return quote
        function $kernel_name(ctx_ptr, Q, fields)
            grid_dim = LibCEED.gridDim()
            block_idx = LibCEED.blockIdx()
            block_dim = LibCEED.blockDim()
            thread_idx = LibCEED.threadIdx()

            # Step between the points handled by one thread
            stride = block_dim.x*grid_dim.x

            $(in_decls...)
            $(out_decls...)

            # Alignment for data read/write
            a = Val($align)

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            # First point handled by this thread; later ones are `stride` apart
            first_q = thread_idx.x + (block_idx.x - 1)*block_dim.x
            for q = first_q:stride:Q
                $(loads...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(call_args...))
                $(stores...)
            end
            return
        end
    end
end
14644554ea0SWill Pazner
# Generate the CUDA kernel for the Q-function `kf`, define it in `def_module`,
# compile it with `cufunction`, and return `.fun.handle` of the compiled kernel.
#
# The kernel is compiled for the argument types
# `(Ptr{Nothing}, Int32, FieldsCuda)`, matching the `(ctx_ptr, Q, fields)`
# signature produced by `generate_kernel`, with `maxregs=64`.
# `ceed` is accepted but not used in this function.
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    kernel_expr = generate_kernel(qf_name, kf, dims_in, dims_out)
    kernel = Core.eval(def_module, kernel_expr)
    arg_types = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    compiled = cufunction(kernel, arg_types; maxregs=64)
    return compiled.fun.handle
end
# COV_EXCL_STOP
154