# COV_EXCL_START
using .CUDA, Cassette

# Flag set when this file is loaded.
cuda_is_loaded = true

# Names of `Base` functions that receive a `CeedCudaContext` overdub below,
# forwarding single Float32/Float64 arguments to the same-named `CUDA` function.
#! format: off
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1, :ldexp,
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)
#! format: on

# Cassette context under which user Q-functions are executed inside the kernel.
Cassette.@context CeedCudaContext

# The following overdubs call the underlying function directly instead of
# letting Cassette recurse into it.
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) = Core.kwfunc(f)

@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) =
    Core.apply_type(args...)

@inline function Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N}
    return StaticArrays.Size(x)
end

# One overdub per entry of `cudafuns`: Base.<name>(x) -> CUDA.<name>(x) for
# a single Float32/Float64 argument.
# NOTE(review): `:ldexp` is normally called with two arguments, so the
# one-argument method generated here looks unreachable -- confirm.
for fname in cudafuns
    @eval @inline Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$fname),
        x::Union{Float32,Float64},
    ) = CUDA.$fname(x)
end

# Hand the memory of the `CuArray` `arr` to the CeedVector `v` through
# `C.CeedVectorSetArray`, using memory type `mtype` and copy mode `cmode`.
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    # The array's pointer is converted to an integer address and re-wrapped as
    # a `Ptr{CeedScalar}` for the C call.
    dev_addr = UInt64(pointer(arr))
    C.CeedVectorSetArray(v[], mtype, cmode, Ptr{CeedScalar}(dev_addr))
    # With USE_POINTER the array is stored on the vector (presumably so it
    # stays alive while the library refers to its memory -- verify).
    if cmode == USE_POINTER
        v.arr = arr
    end
end

# Kernel argument holding 16 input and 16 output field slots as integers; the
# generated kernel reinterprets each used slot as a device pointer.
struct FieldsCuda
    inputs::NTuple{16,Int}
    outputs::NTuple{16,Int}
end

# Build the expression defining a CUDA kernel that wraps the Q-function `kf`.
#
# - `qf_name`: base name handed to `gensym` for the kernel function name.
# - `kf`: the Q-function, interpolated into the kernel and run via Cassette.
# - `dims_in`, `dims_out`: per-field dimensions; an input whose dimensions have
#   length 0 is passed to `kf` as a plain local instead of a static array.
#
# Returns a `quote` block defining `kernel(ctx_ptr, Q, fields)`; evaluating it
# yields the kernel function.
function generate_kernel(qf_name, kf, dims_in, dims_out)
    n_in = length(dims_in)
    n_out = length(dims_out)

    # Number of entries loaded/stored per quadrature point for each field.
    len_in = [prod(d) for d in dims_in]
    len_out = [prod(d) for d in dims_out]

    # Kernel-local variable names: rqi1, rqi2, ... and rqo1, rqo2, ...
    sym_in = [Symbol("rqi", i) for i = 1:n_in]
    sym_out = [Symbol("rqo", i) for i = 1:n_out]

    # For every input: its declaration, the lvalue written while loading, and
    # the expression passed to `kf`.
    # NOTE(review): the staging arrays are hard-coded to Float64 while the
    # device pointers use CeedScalar -- confirm this is intended.
    decl_in = Vector{Expr}(undef, n_in)
    slot_in = Vector{Union{Symbol,Expr}}(undef, n_in)
    in_args = Vector{Union{Symbol,Expr}}(undef, n_in)
    for i in eachindex(sym_in)
        sym = sym_in[i]
        dims = dims_in[i]
        if length(dims) == 0
            decl_in[i] = :(local $(sym))
            slot_in[i] = sym
            in_args[i] = sym
        else
            decl_in[i] = :($(sym) = LibCEED.MArray{Tuple{$(dims...)},Float64}(undef))
            slot_in[i] = :($(sym)[j])
            in_args[i] = :(LibCEED.SArray{Tuple{$(dims...)},Float64}($(sym)))
        end
    end

    # Outputs are always mutable static arrays and are passed to `kf` as-is.
    decl_out = [
        :($(sym_out[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},Float64}(undef))
        for i = 1:n_out
    ]

    # Argument list for `kf`: all inputs followed by all outputs.
    call_args = [in_args; sym_out]

    device_ptr_type = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # Loads for input field i: entry j of point q is read at index q + (j - 1)*Q.
    load_exprs = map(1:n_in) do i
        :(
            for j = 1:$(len_in[i])
                $(slot_in[i]) = unsafe_load(
                    reinterpret($device_ptr_type, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    align,
                )
            end
        )
    end

    # Stores for output field i, using the same indexing as the loads.
    store_exprs = map(1:n_out) do i
        :(
            for j = 1:$(len_out[i])
                unsafe_store!(
                    reinterpret($device_ptr_type, fields.outputs[$i]),
                    $(sym_out[i])[j],
                    q + (j - 1)*Q,
                    align,
                )
            end
        )
    end

    kernel_name = gensym(qf_name)
    quote
        function $kernel_name(ctx_ptr, Q, fields)
            grid_dim = LibCEED.gridDim()
            block_idx = LibCEED.blockIdx()
            block_dim = LibCEED.blockDim()
            thread_idx = LibCEED.threadIdx()

            # Total number of threads along x; used as the loop step below
            stride = block_dim.x*grid_dim.x

            $(decl_in...)
            $(decl_out...)

            # Alignment for data read/write
            align = Val($(Base.datatype_alignment(CeedScalar)))

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            # Each thread starts at its global x index and advances by `stride`
            for q = (thread_idx.x+(block_idx.x-1)*block_dim.x):stride:Q
                $(load_exprs...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(call_args...))
                $(store_exprs...)
            end
            return
        end
    end
end

# Generate the kernel for `kf`, evaluate it in `def_module`, compile it with
# `cufunction` for the argument types (Ptr{Nothing}, Int32, FieldsCuda) and
# return the handle of the compiled function.
# `ceed` is not used in this function body.
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    kernel_expr = generate_kernel(qf_name, kf, dims_in, dims_out)
    k_fn = Core.eval(def_module, kernel_expr)
    kernel_sig = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    host_k = cufunction(k_fn, kernel_sig; maxregs=64)
    return host_k.fun.handle
end
# COV_EXCL_STOP