xref: /libCEED/julia/LibCEED.jl/src/Cuda.jl (revision 3d8e882215d238700cdceb37404f76ca7fa24eaa)
1# COV_EXCL_START
2using .CUDA, Cassette
3
# Names of the `Base` math functions that receive a `CeedCudaContext` overdub which
# forwards `Float32`/`Float64` arguments to the `CUDA` function of the same name.
#! format: off
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1, :ldexp,
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)
#! format: on
17
# Cassette context under which user Q-functions are overdubbed inside the generated CUDA
# kernels (instantiated in `generate_kernel`).
Cassette.@context CeedCudaContext
19
# Under `CeedCudaContext`, `Core.kwfunc` is invoked directly on `f` rather than being
# traced further by Cassette.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) = Core.kwfunc(f)
# Under `CeedCudaContext`, `Core.apply_type` is invoked directly with the original
# arguments rather than being traced further by Cassette.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) =
    Core.apply_type(args...)
# Under `CeedCudaContext`, `StaticArrays.Size` applied to an array *type* is invoked
# directly rather than being traced further by Cassette.
@inline Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N} = StaticArrays.Size(x)
33
# For every name in `cudafuns`, define an overdub that replaces the `Base` function by the
# `CUDA` function of the same name when called with a single `Float32`/`Float64` argument.
# NOTE(review): only one-argument calls are intercepted; a two-argument call such as
# `ldexp(x, n)` would not match these methods — confirm this is intended.
for fname in cudafuns
    @eval @inline Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$fname),
        x::Union{Float32,Float64},
    ) = CUDA.$fname(x)
end
43
# Hand the memory of the `CuArray` `arr` to the CeedVector `v` through
# `C.CeedVectorSetArray`, using memory type `mtype` and copy mode `cmode`.
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    # Convert the array's pointer to a `Ptr{CeedScalar}` by way of its integer address.
    addr = UInt64(pointer(arr))
    C.CeedVectorSetArray(v[], mtype, cmode, Ptr{CeedScalar}(addr))
    # With USE_POINTER the array is also stored on the vector (presumably so that it
    # stays alive while libCEED refers to its memory — confirm).
    if cmode == USE_POINTER
        v.arr = arr
    end
end
51
# Third argument of the generated CUDA kernels (see the argument tuple type built in
# `mk_cufunction`). The kernel reinterprets each entry it uses as a device pointer to
# `CeedScalar` data.
# NOTE(review): 16 presumably mirrors libCEED's maximum number of Q-function fields, and
# the layout presumably has to match what libCEED passes at launch — confirm.
struct FieldsCuda
    inputs::NTuple{16,Int}   # addresses of the input field arrays
    outputs::NTuple{16,Int}  # addresses of the output field arrays
end
56
# Build (without evaluating) the expression that defines a CUDA kernel wrapping the user
# Q-function `kf`.
#
# Arguments:
# - `qf_name`:  base name handed to `gensym` to name the generated kernel function.
# - `kf`:       the Q-function, spliced as-is into the overdubbed call.
# - `dims_in`:  one entry per input field giving its dimensions; an entry of length zero
#               marks a scalar input.
# - `dims_out`: one entry per output field giving its dimensions.
#
# Returns a `quote` block defining `kernel(ctx_ptr, Q, fields)`, where `fields` provides
# `inputs`/`outputs` addresses that are reinterpreted as device pointers to `CeedScalar`.
function generate_kernel(qf_name, kf, dims_in, dims_out)
    n_in = length(dims_in)
    n_out = length(dims_out)

    # Number of scalar entries each field holds per point `q`
    len_in = [prod(dims_in[i]) for i = 1:n_in]
    len_out = [prod(dims_out[i]) for i = 1:n_out]

    # Kernel-local variable names, one per field
    in_syms = [Symbol("rqi$i") for i = 1:n_in]
    out_syms = [Symbol("rqo$i") for i = 1:n_out]

    decl_in = Expr[]                     # declarations of the input locals
    elem_in = Union{Symbol,Expr}[]       # assignment target for the j-th loaded entry
    call_args = Union{Symbol,Expr}[]     # arguments forwarded to `kf`
    for i = 1:n_in
        sym = in_syms[i]
        dims = dims_in[i]
        if length(dims) != 0
            # Array-valued input: filled entry by entry into an MArray and handed to `kf`
            # converted to an SArray.
            push!(decl_in, :($sym = LibCEED.MArray{Tuple{$(dims...)},CeedScalar}(undef)))
            push!(elem_in, :($(sym)[j]))
            push!(call_args, :(LibCEED.SArray{Tuple{$(dims...)},CeedScalar}($sym)))
        else
            # Scalar input: a plain local that is assigned the loaded value directly.
            push!(decl_in, :(local $sym))
            push!(elem_in, sym)
            push!(call_args, sym)
        end
    end
    # Outputs follow the inputs and are passed as mutable arrays for `kf` to fill.
    append!(call_args, out_syms)

    decl_out = map(1:n_out) do i
        :($(out_syms[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},CeedScalar}(undef))
    end

    dev_ptr_t = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # Per-field loops copying the entries for point `q` from device memory into the
    # locals; entry `j` of point `q` sits at (1-based) offset `q + (j - 1)*Q`.
    loads = map(1:n_in) do i
        :(
            for j = 1:$(len_in[i])
                $(elem_in[i]) = unsafe_load(
                    reinterpret($dev_ptr_t, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    a,
                )
            end
        )
    end

    # Per-field loops writing the results back, using the same offsets as the loads.
    stores = map(1:n_out) do i
        :(
            for j = 1:$(len_out[i])
                unsafe_store!(
                    reinterpret($dev_ptr_t, fields.outputs[$i]),
                    $(out_syms[i])[j],
                    q + (j - 1)*Q,
                    a,
                )
            end
        )
    end

    kernel_name = gensym(qf_name)
    return quote
        function $kernel_name(ctx_ptr, Q, fields)
            gd = LibCEED.gridDim()
            bi = LibCEED.blockIdx()
            bd = LibCEED.blockDim()
            ti = LibCEED.threadIdx()

            # Total number of threads in the grid: stride of the loop over points
            inc = bd.x*gd.x

            $(decl_in...)
            $(decl_out...)

            # Alignment for data read/write
            a = Val($(Base.datatype_alignment(CeedScalar)))

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            # Each thread starts at its global (1-based) index and advances by `inc`
            for q = (ti.x+(bi.x-1)*bd.x):inc:Q
                $(loads...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(call_args...))
                $(stores...)
            end
            return
        end
    end
end
146
# Generate the kernel for the Q-function `kf`, evaluate it in `def_module`, compile it
# with `cufunction` (at most 64 registers per thread) and return the `handle` of the
# compiled function. The `ceed` argument is not used by this function's body.
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    kernel_expr = generate_kernel(qf_name, kf, dims_in, dims_out)
    kernel = Core.eval(def_module, kernel_expr)
    # Argument types of the generated kernel: (ctx_ptr, Q, fields)
    arg_types = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    compiled = cufunction(kernel, arg_types; maxregs=64)
    return compiled.fun.handle
end
153# COV_EXCL_STOP
154