xref: /libCEED/julia/LibCEED.jl/src/Cuda.jl (revision 44554ea01e90fce366fc2a203c44be15754a38d6)
1*44554ea0SWill Pazner# COV_EXCL_START
2*44554ea0SWill Paznerusing .CUDA, Cassette
3*44554ea0SWill Pazner
# Records that this CUDA-specific file has been loaded.
# NOTE(review): presumably read elsewhere in LibCEED to decide whether CUDA support is
# available, and presumably initialized to `false` before this file runs — confirm.
cuda_is_loaded = true
5*44554ea0SWill Pazner
6*44554ea0SWill Pazner#! format: off
# Names of the `Base` math functions that are redirected to the same-named function in
# the `CUDA` module when a Q-function runs under `CeedCudaContext`.
const cudafuns = (
    # trigonometric
    :cos, :cospi, :sin, :sinpi, :tan,
    # inverse trigonometric
    :acos, :asin, :atan,
    # hyperbolic
    :cosh, :sinh, :tanh,
    # inverse hyperbolic
    :acosh, :asinh, :atanh,
    # logarithms
    :log, :log10, :log1p, :log2,
    # exponentials
    :exp, :exp2, :exp10, :expm1, :ldexp,
    # magnitude
    :abs,
    # roots
    :sqrt, :cbrt,
    # rounding
    :ceil, :floor,
)
18*44554ea0SWill Pazner#! format: on
19*44554ea0SWill Pazner
# Cassette context type. The `Cassette.overdub` methods defined in this file dispatch on
# it, and the kernels produced by `generate_kernel` run the user Q-function under it.
Cassette.@context CeedCudaContext
21*44554ea0SWill Pazner
# Under `CeedCudaContext`, evaluate `Core.kwfunc(f)` directly instead of overdubbing into
# it.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) = Core.kwfunc(f)
# Under `CeedCudaContext`, evaluate `Core.apply_type(args...)` directly instead of
# overdubbing into it.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) =
    Core.apply_type(args...)
# Under `CeedCudaContext`, evaluate `StaticArrays.Size` of an array *type* directly
# instead of overdubbing into it.
@inline Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N} = StaticArrays.Size(x)
35*44554ea0SWill Pazner
# For every name in `cudafuns`, define an overdub method so that a call to the `Base`
# function with a single `Float32` or `Float64` argument is evaluated by the function of
# the same name in the `CUDA` module.
for fname in cudafuns
    @eval @inline Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$fname),
        x::Union{Float32,Float64},
    ) = CUDA.$fname(x)
end
45*44554ea0SWill Pazner
"""
    setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)

Set the array used by `v` to the device array `arr` by calling `C.CeedVectorSetArray`
with the device address of `arr` converted to a `Ptr{CeedScalar}`.

When `cmode == USE_POINTER`, `arr` is additionally stored in `v.arr` so that the vector
holds a reference to the array whose memory it now points at.

The address is passed through unchanged; the element type of `arr` is not checked against
`CeedScalar`.
"""
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    # The foreign call only sees a raw address, so keep `arr` rooted until it returns;
    # otherwise nothing in this function is guaranteed to reference `arr` during the call.
    GC.@preserve arr begin
        ptr = Ptr{CeedScalar}(UInt64(pointer(arr)))
        C.CeedVectorSetArray(v[], mtype, cmode, ptr)
    end
    # Left as the final expression so the return value matches the original behavior
    # (`arr` when the pointer is adopted, `nothing` otherwise).
    if cmode == USE_POINTER
        v.arr = arr
    end
end
53*44554ea0SWill Pazner
# Argument bundle passed to the kernels produced by `generate_kernel`.
#
# Each entry is an integer that the generated kernel reinterprets as a device pointer to
# `CeedScalar` data. The kernel reads entry `i` of `inputs` for its `i`-th input field and
# writes through entry `i` of `outputs` for its `i`-th output field.
# NOTE(review): the fixed length 16 presumably mirrors libCEED's maximum number of
# Q-function fields — confirm.
struct FieldsCuda
    inputs::NTuple{16,Int}
    outputs::NTuple{16,Int}
end
58*44554ea0SWill Pazner
"""
    generate_kernel(qf_name, kf, dims_in, dims_out)

Build, without evaluating it, the expression that defines a CUDA kernel wrapping the
Q-function `kf`.

# Arguments
- `qf_name`: base name handed to `gensym` to name the generated kernel function.
- `kf`: the Q-function, interpolated into the generated code and invoked through
  `Cassette.overdub` under a `CeedCudaContext`.
- `dims_in`, `dims_out`: indexable collections holding, per field, the dimensions of the
  value at one quadrature point. An input with empty dimensions is passed to `kf` as a
  scalar; any other input is passed as an `SArray`. Every output is passed as an `MArray`.

# Returns
A `quote` block defining a function `(ctx_ptr, Q, fields)`. In that function each thread
starts at point `threadIdx + (blockIdx - 1) * blockDim` and advances by
`blockDim * gridDim` up to `Q`. Component `j` of point `q` of a field is read from or
written to linear index `q + (j - 1) * Q` of that field's device array.
"""
function generate_kernel(qf_name, kf, dims_in, dims_out)
    ninputs = length(dims_in)
    noutputs = length(dims_out)

    # Number of scalar components per quadrature point for each field.
    input_sz = prod.(dims_in)
    output_sz = prod.(dims_out)

    # One scalar type for device pointers, alignment and the per-point buffers. The
    # buffers previously hard-coded Float64, which only agrees with the CeedScalar
    # loads/stores below when the two types coincide. The type is interpolated as a value
    # (as `device_ptr_type` is), so it need not be nameable in the defining module.
    T = CeedScalar
    device_ptr_type = Core.LLVMPtr{T,LibCEED.AS.Global}

    f_ins = [Symbol("rqi", i) for i = 1:ninputs]
    f_outs = [Symbol("rqo", i) for i = 1:noutputs]

    # args:    argument expressions passed to `kf` (inputs first, then outputs)
    # def_ins: declarations of the per-thread input variables
    # f_ins_j: assignment target used when loading component `j` of an input
    args = Vector{Union{Symbol,Expr}}(undef, ninputs + noutputs)
    def_ins = Vector{Expr}(undef, ninputs)
    f_ins_j = Vector{Union{Symbol,Expr}}(undef, ninputs)
    for i = 1:ninputs
        name = f_ins[i]
        dims = dims_in[i]
        if isempty(dims)
            # Scalar input: a plain local that is assigned by the single-iteration load
            # loop and handed to `kf` as is.
            def_ins[i] = :(local $(name))
            f_ins_j[i] = name
            args[i] = name
        else
            # Array input: filled component by component into a mutable buffer, then
            # converted to an immutable static array for the call.
            def_ins[i] = :($(name) = LibCEED.MArray{Tuple{$(dims...)},$T}(undef))
            f_ins_j[i] = :($(name)[j])
            args[i] = :(LibCEED.SArray{Tuple{$(dims...)},$T}($(name)))
        end
    end
    for i = 1:noutputs
        args[ninputs+i] = f_outs[i]
    end

    # Outputs are always mutable buffers that `kf` fills in place.
    def_outs = [
        :($(f_outs[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},$T}(undef))
        for i = 1:noutputs
    ]

    # Load every component of every input field for the current point `q`.
    read_quads_in = [
        :(
            for j = 1:$(input_sz[i])
                $(f_ins_j[i]) = unsafe_load(
                    reinterpret($device_ptr_type, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:ninputs
    ]

    # Store every component of every output field for the current point `q`.
    write_quads_out = [
        :(
            for j = 1:$(output_sz[i])
                unsafe_store!(
                    reinterpret($device_ptr_type, fields.outputs[$i]),
                    $(f_outs[i])[j],
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:noutputs
    ]

    qf = gensym(qf_name)
    quote
        function $qf(ctx_ptr, Q, fields)
            gd = LibCEED.gridDim()
            bi = LibCEED.blockIdx()
            bd = LibCEED.blockDim()
            ti = LibCEED.threadIdx()

            inc = bd.x*gd.x

            $(def_ins...)
            $(def_outs...)

            # Alignment for data read/write
            a = Val($(Base.datatype_alignment(T)))

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            for q = (ti.x+(bi.x-1)*bd.x):inc:Q
                $(read_quads_in...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(args...))
                $(write_quads_out...)
            end
            return
        end
    end
end
148*44554ea0SWill Pazner
"""
    mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)

Generate the kernel for the Q-function `kf` with `generate_kernel`, evaluate its
definition in `def_module`, compile it with `cufunction` for the argument types
`(Ptr{Nothing}, Int32, FieldsCuda)` using `maxregs=64`, and return the `handle` of the
compiled function.

`ceed` is accepted but not used by this function.
"""
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    kernel_expr = generate_kernel(qf_name, kf, dims_in, dims_out)
    kernel = Core.eval(def_module, kernel_expr)

    # Matches the generated kernel's parameters: (ctx_ptr, Q, fields).
    arg_types = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    compiled = cufunction(kernel, arg_types; maxregs=64)
    return compiled.fun.handle
end
155*44554ea0SWill Pazner# COV_EXCL_STOP
156