# COV_EXCL_START
using .CUDA, Cassette

# Base math functions whose scalar Float32/Float64 calls are redirected to the
# same-named functions in `CUDA` when code runs under `CeedCudaContext` (see the
# `@eval` loop below).
# NOTE(review): only a one-argument overdub is generated per name, while
# `Base.ldexp` takes two arguments, so the `:ldexp` entry never matches — confirm
# whether a two-argument overdub was intended.
#! format: off
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1, :ldexp,
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)
#! format: on

# Cassette context in which the generated kernels invoke the user Q-function.
Cassette.@context CeedCudaContext

# The following overdubs call the wrapped function directly instead of letting
# Cassette recurse into it.
@inline function Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f)
    return Core.kwfunc(f)
end
@inline function Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...)
    return Core.apply_type(args...)
end
@inline function Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N}
    return StaticArrays.Size(x)
end

# For every name `f` in `cudafuns`, make `Base.f(x)` with a Float32/Float64
# argument call `CUDA.f(x)` when running under `CeedCudaContext`.
for f in cudafuns
    @eval @inline function Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$f),
        x::Union{Float32,Float64},
    )
        return CUDA.$f(x)
    end
end

# Hand the device memory of `arr` to the libCEED vector `v` via
# `CeedVectorSetArray`, using memory type `mtype` and copy mode `cmode`.
#
# When `cmode == USE_POINTER`, `arr` is also stored in `v.arr` (presumably so
# the Julia array stays alive while libCEED uses its memory — confirm).
#
# Throws an `ArgumentError` if `eltype(arr)` is not `CeedScalar`: the device
# address is passed to libCEED as a `Ptr{CeedScalar}`, and the integer round
# trip below bypasses Julia's pointer type checking, so a mismatched element
# type would otherwise be silently reinterpreted.
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    if eltype(arr) !== CeedScalar
        throw(
            ArgumentError(
                "CuArray element type $(eltype(arr)) does not match CeedScalar ($CeedScalar)",
            ),
        )
    end
    # Device pointer -> integer address -> host-typed pointer for the C call.
    ptr = Ptr{CeedScalar}(UInt64(pointer(arr)))
    C.CeedVectorSetArray(v[], mtype, cmode, ptr)
    if cmode == USE_POINTER
        v.arr = arr
    end
end

# Argument passed to the generated kernels. It holds the addresses (as `Int`)
# of each input and output field array. The generated code reinterprets each
# address as a global-memory `CeedScalar` pointer. Only the first
# `length(dims_in)` / `length(dims_out)` entries are read. The fixed size of 16
# presumably mirrors a fixed-size struct on the libCEED CUDA backend side —
# confirm before changing it.
struct FieldsCuda
    inputs::NTuple{16,Int}
    outputs::NTuple{16,Int}
end

# Build (but do not evaluate) the expression defining a CUDA kernel that
# applies the user Q-function `kf` at every quadrature point.
#
# Arguments:
# - `qf_name`: base name passed to `gensym` for the kernel function's name.
# - `kf`: Q-function, called under `CeedCudaContext` as
#   `kf(ctx_ptr, inputs..., outputs...)` once per quadrature point.
# - `dims_in`, `dims_out`: per-point dimensions of each input/output field. An
#   input with empty dimensions is passed to `kf` as a scalar. Every other input
#   is passed as an `SArray`. Outputs are passed as `MArray`s.
#
# The generated kernel has signature `(ctx_ptr, Q, fields)` and returns
# `nothing`. Thread `t` handles points `q = t, t + nthreads, ...` up to `Q`,
# where `t` is the 1-based global thread index. Component `j` of each field at
# point `q` lives at element `q + (j - 1)*Q` of that field's array.
#
# Throws an `ArgumentError` if there are more input or output fields than
# `FieldsCuda` has slots for.
function generate_kernel(qf_name, kf, dims_in, dims_out)
    ninputs = length(dims_in)
    noutputs = length(dims_out)

    # The kernel reads `fields.inputs[i]` / `fields.outputs[i]`, which are
    # fixed-size tuples; fail early with a clear message instead of during GPU
    # compilation or execution.
    max_in = fieldcount(fieldtype(FieldsCuda, :inputs))
    max_out = fieldcount(fieldtype(FieldsCuda, :outputs))
    if ninputs > max_in || noutputs > max_out
        throw(
            ArgumentError(
                "CUDA Q-functions support at most $max_in input and $max_out output " *
                "fields (got $ninputs inputs and $noutputs outputs)",
            ),
        )
    end

    # Number of components per quadrature point for each field.
    input_sz = prod.(dims_in)
    output_sz = prod.(dims_out)

    f_ins = [Symbol("rqi$i") for i = 1:ninputs]
    f_outs = [Symbol("rqo$i") for i = 1:noutputs]

    # args: expressions passed to `kf` after `ctx_ptr`.
    # def_ins: per-input declarations placed at the top of the kernel.
    # f_ins_j: assignment target for component `j` of each input.
    args = Vector{Union{Symbol,Expr}}(undef, ninputs + noutputs)
    def_ins = Vector{Expr}(undef, ninputs)
    f_ins_j = Vector{Union{Symbol,Expr}}(undef, ninputs)
    for i = 1:ninputs
        if length(dims_in[i]) == 0
            # Scalar input: a plain local variable, assigned from the inner loop.
            def_ins[i] = :(local $(f_ins[i]))
            f_ins_j[i] = f_ins[i]
            args[i] = f_ins[i]
        else
            def_ins[i] =
                :($(f_ins[i]) = LibCEED.MArray{Tuple{$(dims_in[i]...)},CeedScalar}(undef))
            f_ins_j[i] = :($(f_ins[i])[j])
            args[i] = :(LibCEED.SArray{Tuple{$(dims_in[i]...)},CeedScalar}($(f_ins[i])))
        end
    end
    for i = 1:noutputs
        args[ninputs+i] = f_outs[i]
    end

    def_outs = [
        :($(f_outs[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},CeedScalar}(undef)) for
        i = 1:noutputs
    ]

    device_ptr_type = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # Loads of every component of each input at quadrature point `q`.
    read_quads_in = [
        :(
            for j = 1:$(input_sz[i])
                $(f_ins_j[i]) = unsafe_load(
                    reinterpret($device_ptr_type, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:ninputs
    ]

    # Stores of every component of each output at quadrature point `q`.
    write_quads_out = [
        :(
            for j = 1:$(output_sz[i])
                unsafe_store!(
                    reinterpret($device_ptr_type, fields.outputs[$i]),
                    $(f_outs[i])[j],
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:noutputs
    ]

    qf = gensym(qf_name)
    quote
        function $qf(ctx_ptr, Q, fields)
            gd = LibCEED.gridDim()
            bi = LibCEED.blockIdx()
            bd = LibCEED.blockDim()
            ti = LibCEED.threadIdx()

            inc = bd.x*gd.x

            $(def_ins...)
            $(def_outs...)

            # Alignment for data read/write
            a = Val($(Base.datatype_alignment(CeedScalar)))

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            for q = (ti.x+(bi.x-1)*bd.x):inc:Q
                $(read_quads_in...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(args...))
                $(write_quads_out...)
            end
            return
        end
    end
end

# Evaluate the kernel expression from `generate_kernel` in `def_module`.
# Compile it with `cufunction` for argument types
# `(Ptr{Nothing}, Int32, FieldsCuda)`, limited to 64 registers, and return the
# compiled function's `fun.handle`. The `ceed` argument is not used here.
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    k_fn = Core.eval(def_module, generate_kernel(qf_name, kf, dims_in, dims_out))
    tt = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    host_k = cufunction(k_fn, tt; maxregs=64)
    return host_k.fun.handle
end
# COV_EXCL_STOP