xref: /libCEED/julia/LibCEED.jl/src/Cuda.jl (revision ec3da8bcb94d9f0073544b37b5081a06981a86f7)
1# COV_EXCL_START
2using .CUDA, Cassette
3
# Set to `true` as soon as this file is evaluated.
# NOTE(review): presumably read elsewhere to check whether the CUDA-dependent
# definitions below are available — confirm against the code that reads this flag.
cuda_is_loaded = true
5
6#! format: off
# Names of the `Base` math functions for which a `Cassette.overdub` method is
# generated below; each generated method forwards a single `Float32`/`Float64`
# argument to the function of the same name in `CUDA`.
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1, :ldexp,
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)
18#! format: on
19
# Cassette context type on which the `Cassette.overdub` methods in this file dispatch;
# generated kernels run the user function under an instance of this context.
Cassette.@context CeedCudaContext
21
# Under `CeedCudaContext`, evaluate `Core.kwfunc(f)` directly instead of overdubbing
# the call.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) = Core.kwfunc(f)
# Under `CeedCudaContext`, evaluate `Core.apply_type(args...)` directly instead of
# overdubbing the call.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) =
    Core.apply_type(args...)
# Under `CeedCudaContext`, evaluate `StaticArrays.Size` of an array *type* (with a
# known number of dimensions `N`) directly instead of overdubbing the call.
@inline Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N} = StaticArrays.Size(x)
35
# For every name in `cudafuns`, define an overdub method that replaces a call to
# `Base.<name>(x)` on a `Float32`/`Float64` argument with `CUDA.<name>(x)`.
# NOTE(review): `ldexp` is normally called with two arguments; the one-argument
# method generated here would not intercept that form — confirm whether a
# two-argument method is needed.
for fname in cudafuns
    @eval @inline Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$fname),
        x::Union{Float32,Float64},
    ) = CUDA.$fname(x)
end
45
# Hand the memory of the `CuArray` `arr` to `C.CeedVectorSetArray` for the vector `v`,
# using memory type `mtype` and copy mode `cmode`.
# NOTE(review): the element type of `arr` is not checked before its pointer is
# treated as a pointer to `CeedScalar` — confirm callers always pass matching arrays.
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    # Convert the array's pointer to a `Ptr{CeedScalar}` by way of its integer address
    address = UInt64(pointer(arr))
    C.CeedVectorSetArray(v[], mtype, cmode, Ptr{CeedScalar}(address))
    # With `USE_POINTER`, store `arr` on the vector (presumably so the array stays
    # reachable while the vector refers to its memory)
    if cmode == USE_POINTER
        v.arr = arr
    end
end
53
# Fixed-size bundle of field addresses passed to the generated kernels as their
# `fields` argument. Each entry is an integer that the kernel reinterprets as a
# device pointer before reading from (`inputs`) or writing to (`outputs`) it.
struct FieldsCuda
    # One slot per input field, 16 slots in total
    inputs::NTuple{16,Int}
    # One slot per output field, 16 slots in total
    outputs::NTuple{16,Int}
end
58
# Build the expression that defines a CUDA kernel applying the function `kf` to every
# point index `q` in `1:Q`.
#
# Arguments:
# - `qf_name`: base name of the kernel; the generated function is named
#   `gensym(qf_name)`.
# - `kf`: interpolated into the generated code and called, through
#   `Cassette.overdub` with a `CeedCudaContext`, as `kf(ctx_ptr, inputs..., outputs...)`.
# - `dims_in`, `dims_out`: one collection of dimensions per input/output field. An
#   input with no dimensions is passed to `kf` as a scalar; any other input is passed
#   as an `SArray`, and every output as an `MArray`, of those dimensions.
#
# Returns a `quote` block containing the definition of `kernel(ctx_ptr, Q, fields)`,
# where `fields` provides the `inputs`/`outputs` addresses of the fields.
function generate_kernel(qf_name, kf, dims_in, dims_out)
    ninputs = length(dims_in)
    noutputs = length(dims_out)

    # Number of scalar components of each field (1 for an empty tuple of dimensions)
    input_sz = prod.(dims_in)
    output_sz = prod.(dims_out)

    # Names of the per-thread variables holding one point's worth of each field
    f_ins = [Symbol("rqi$i") for i = 1:ninputs]
    f_outs = [Symbol("rqo$i") for i = 1:noutputs]

    # Element type of the per-thread arrays. This is `CeedScalar`, the same type that
    # is loaded from and stored to the device pointers below, rather than a hard-coded
    # `Float64`. It is interpolated as a type (not as a symbol) so that the generated
    # code does not depend on the names visible where it is evaluated.
    scalar_type = CeedScalar

    args = Vector{Union{Symbol,Expr}}(undef, ninputs + noutputs)
    def_ins = Vector{Expr}(undef, ninputs)
    f_ins_j = Vector{Union{Symbol,Expr}}(undef, ninputs)
    for i = 1:ninputs
        if length(dims_in[i]) == 0
            # Scalar input: a plain local variable, assigned directly by the load
            def_ins[i] = :(local $(f_ins[i]))
            f_ins_j[i] = f_ins[i]
            args[i] = f_ins[i]
        else
            # Array input: filled component by component into an `MArray`, then
            # handed to `kf` as an immutable `SArray`
            def_ins[i] = :(
                $(f_ins[i]) =
                    LibCEED.MArray{Tuple{$(dims_in[i]...)},$scalar_type}(undef)
            )
            f_ins_j[i] = :($(f_ins[i])[j])
            args[i] =
                :(LibCEED.SArray{Tuple{$(dims_in[i]...)},$scalar_type}($(f_ins[i])))
        end
    end
    for i = 1:noutputs
        # Outputs are passed as the mutable arrays themselves so `kf` can write to them
        args[ninputs+i] = f_outs[i]
    end

    def_outs = [
        :($(f_outs[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},$scalar_type}(undef))
        for i = 1:noutputs
    ]

    device_ptr_type = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # Component `j` of point `q` is located at (1-based) index `q + (j - 1)*Q`
    read_quads_in = [
        :(
            for j = 1:$(input_sz[i])
                $(f_ins_j[i]) = unsafe_load(
                    reinterpret($device_ptr_type, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:ninputs
    ]

    write_quads_out = [
        :(
            for j = 1:$(output_sz[i])
                unsafe_store!(
                    reinterpret($device_ptr_type, fields.outputs[$i]),
                    $(f_outs[i])[j],
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:noutputs
    ]

    qf = gensym(qf_name)
    return quote
        function $qf(ctx_ptr, Q, fields)
            gd = LibCEED.gridDim()
            bi = LibCEED.blockIdx()
            bd = LibCEED.blockDim()
            ti = LibCEED.threadIdx()

            # Total number of threads in the grid (x direction)
            inc = bd.x*gd.x

            $(def_ins...)
            $(def_outs...)

            # Alignment for data read/write
            a = Val($(Base.datatype_alignment(CeedScalar)))

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            # Each thread starts at its global index and advances by the total
            # number of threads until all `Q` points are processed
            for q = (ti.x+(bi.x-1)*bd.x):inc:Q
                $(read_quads_in...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(args...))
                $(write_quads_out...)
            end
            return
        end
    end
end
148
# Generate the kernel for `kf` (see `generate_kernel`), define it in `def_module`,
# compile it with `cufunction`, and return the handle of the compiled function.
# The `ceed` argument is not used by this function.
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    kernel_expr = generate_kernel(qf_name, kf, dims_in, dims_out)
    kernel = Core.eval(def_module, kernel_expr)
    # Argument types the kernel is compiled for: (ctx_ptr, Q, fields)
    arg_types = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    compiled = cufunction(kernel, arg_types; maxregs=64)
    return compiled.fun.handle
end
155# COV_EXCL_STOP
156