# COV_EXCL_START using .CUDA, Cassette cuda_is_loaded = true #! format: off const cudafuns = ( :cos, :cospi, :sin, :sinpi, :tan, :acos, :asin, :atan, :cosh, :sinh, :tanh, :acosh, :asinh, :atanh, :log, :log10, :log1p, :log2, :exp, :exp2, :exp10, :expm1, :ldexp, :abs, :sqrt, :cbrt, :ceil, :floor, ) #! format: on Cassette.@context CeedCudaContext @inline function Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) return Core.kwfunc(f) end @inline function Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) return Core.apply_type(args...) end @inline function Cassette.overdub( ::CeedCudaContext, ::typeof(StaticArrays.Size), x::Type{<:AbstractArray{<:Any,N}}, ) where {N} return StaticArrays.Size(x) end for f in cudafuns @eval @inline function Cassette.overdub( ::CeedCudaContext, ::typeof(Base.$f), x::Union{Float32,Float64}, ) return CUDA.$f(x) end end function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray) ptr = Ptr{CeedScalar}(UInt64(pointer(arr))) C.CeedVectorSetArray(v[], mtype, cmode, ptr) if cmode == USE_POINTER v.arr = arr end end struct FieldsCuda inputs::NTuple{16,Int} outputs::NTuple{16,Int} end function generate_kernel(qf_name, kf, dims_in, dims_out) ninputs = length(dims_in) noutputs = length(dims_out) input_sz = prod.(dims_in) output_sz = prod.(dims_out) f_ins = [Symbol("rqi$i") for i = 1:ninputs] f_outs = [Symbol("rqo$i") for i = 1:noutputs] args = Vector{Union{Symbol,Expr}}(undef, ninputs + noutputs) def_ins = Vector{Expr}(undef, ninputs) f_ins_j = Vector{Union{Symbol,Expr}}(undef, ninputs) for i = 1:ninputs if length(dims_in[i]) == 0 def_ins[i] = :(local $(f_ins[i])) f_ins_j[i] = f_ins[i] args[i] = f_ins[i] else def_ins[i] = :($(f_ins[i]) = LibCEED.MArray{Tuple{$(dims_in[i]...)},Float64}(undef)) f_ins_j[i] = :($(f_ins[i])[j]) args[i] = :(LibCEED.SArray{Tuple{$(dims_in[i]...)},Float64}($(f_ins[i]))) end end for i = 1:noutputs args[ninputs+i] = f_outs[i] end def_outs = [ :($(f_outs[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},Float64}(undef)) for i = 1:noutputs ] device_ptr_type = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global} read_quads_in = [ :( for j = 1:$(input_sz[i]) $(f_ins_j[i]) = unsafe_load( reinterpret($device_ptr_type, fields.inputs[$i]), q + (j - 1)*Q, a, ) end ) for i = 1:ninputs ] write_quads_out = [ :( for j = 1:$(output_sz[i]) unsafe_store!( reinterpret($device_ptr_type, fields.outputs[$i]), $(f_outs[i])[j], q + (j - 1)*Q, a, ) end ) for i = 1:noutputs ] qf = gensym(qf_name) quote function $qf(ctx_ptr, Q, fields) gd = LibCEED.gridDim() bi = LibCEED.blockIdx() bd = LibCEED.blockDim() ti = LibCEED.threadIdx() inc = bd.x*gd.x $(def_ins...) $(def_outs...) # Alignment for data read/write a = Val($(Base.datatype_alignment(CeedScalar))) # Cassette context for replacing intrinsics with CUDA versions ctx = LibCEED.CeedCudaContext() for q = (ti.x+(bi.x-1)*bd.x):inc:Q $(read_quads_in...) LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(args...)) $(write_quads_out...) end return end end end function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out) k_fn = Core.eval(def_module, generate_kernel(qf_name, kf, dims_in, dims_out)) tt = Tuple{Ptr{Nothing},Int32,FieldsCuda} host_k = cufunction(k_fn, tt; maxregs=64) return host_k.fun.handle end # COV_EXCL_STOP