# COV_EXCL_START
using .CUDA, Cassette

# Base math functions that are redirected to the CUDA.jl function of the same name when
# they are called with a single `Float32`/`Float64` argument inside `CeedCudaContext`
# (see the `@eval` loop below).
# NOTE(review): `ldexp` is normally called with two arguments, so the one-argument
# overdub generated for it never matches such calls — confirm whether that is intended.
#! format: off
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1, :ldexp,
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)
#! format: on

# Cassette context in which the generated kernels invoke the user's Q-function, so that
# the overdubs below can substitute or short-circuit individual calls.
Cassette.@context CeedCudaContext

# These calls are executed directly instead of being recursed into by Cassette
# (presumably to keep the traced code compilable for the GPU — confirm).
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f) = Core.kwfunc(f)
@inline Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...) =
    Core.apply_type(args...)
@inline Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N} = StaticArrays.Size(x)

# Route each listed `Base` function on Float32/Float64 to its CUDA.jl counterpart.
for fname in cudafuns
    @eval @inline Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$fname),
        x::Union{Float32,Float64},
    ) = CUDA.$fname(x)
end

# Hand the data of the CUDA array `arr` to the libCEED vector `v`.
#
# The pointer to `arr`'s data is converted to a `Ptr{CeedScalar}` through its integer
# address and passed to `CeedVectorSetArray` together with `mtype` and `cmode`. With
# `cmode == USE_POINTER`, `arr` is additionally stored in `v.arr` (presumably to keep it
# alive while libCEED holds the pointer).
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    addr = UInt64(pointer(arr))
    C.CeedVectorSetArray(v[], mtype, cmode, Ptr{CeedScalar}(addr))
    if cmode == USE_POINTER
        v.arr = arr
    end
end

# Raw addresses of the input and output field arrays, passed by value to a generated
# kernel. Inside the kernel, entry `i` is reinterpreted as a global-memory
# `Core.LLVMPtr{CeedScalar}`; only the first `ninputs`/`noutputs` entries are read.
# The fixed capacity of 16 presumably mirrors libCEED's maximum field count — confirm.
struct FieldsCuda
    inputs::NTuple{16,Int}
    outputs::NTuple{16,Int}
end

# Build (without evaluating) an expression that defines a CUDA kernel for Q-function `kf`.
#
# `dims_in`/`dims_out` hold the per-quadrature-point shape of every input/output field;
# an empty shape denotes a scalar field. The generated function has the signature
# `(ctx_ptr, Q, fields)` and, for each quadrature point `q` visited by the calling thread
# in a grid-stride loop over `1:Q`:
#   1. loads component `j` of input `i` from `fields.inputs[i]` at index `q + (j - 1)*Q`;
#   2. calls `kf(ctx_ptr, inputs..., outputs...)` through `CeedCudaContext`, passing
#      scalar inputs as scalars, shaped inputs as `SArray`s and outputs as `MArray`s;
#   3. stores component `j` of output `i` to `fields.outputs[i]` at the same index.
# The kernel's name is a `gensym` derived from `qf_name`.
function generate_kernel(qf_name, kf, dims_in, dims_out)
    nin = length(dims_in)
    nout = length(dims_out)

    # Names of the per-thread buffers holding each field at the current point.
    in_names = [Symbol("rqi$i") for i = 1:nin]
    out_names = [Symbol("rqo$i") for i = 1:nout]

    # Per input: its declaration, the target receiving component `j` during the load,
    # and the argument expression handed to `kf`.
    in_decls = Vector{Expr}(undef, nin)
    in_targets = Vector{Union{Symbol,Expr}}(undef, nin)
    in_args = Vector{Union{Symbol,Expr}}(undef, nin)
    for i = 1:nin
        name, dims = in_names[i], dims_in[i]
        if isempty(dims)
            # Scalar field: a plain local, passed to `kf` as-is.
            in_decls[i] = :(local $name)
            in_targets[i] = name
            in_args[i] = name
        else
            # Shaped field: filled into a mutable buffer, passed as an immutable copy.
            in_decls[i] = :($name = LibCEED.MArray{Tuple{$(dims...)},CeedScalar}(undef))
            in_targets[i] = :($name[j])
            in_args[i] = :(LibCEED.SArray{Tuple{$(dims...)},CeedScalar}($name))
        end
    end
    call_args = Union{Symbol,Expr}[in_args; out_names]

    out_decls = map(1:nout) do i
        :($(out_names[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},CeedScalar}(undef))
    end

    gptr = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # Component `j` of a field at point `q` lives at index `q + (j - 1)*Q`.
    loads = map(1:nin) do i
        :(
            for j = 1:$(prod(dims_in[i]))
                $(in_targets[i]) = unsafe_load(
                    reinterpret($gptr, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    align,
                )
            end
        )
    end

    stores = map(1:nout) do i
        :(
            for j = 1:$(prod(dims_out[i]))
                unsafe_store!(
                    reinterpret($gptr, fields.outputs[$i]),
                    $(out_names[i])[j],
                    q + (j - 1)*Q,
                    align,
                )
            end
        )
    end

    kname = gensym(qf_name)
    return quote
        function $kname(ctx_ptr, Q, fields)
            gd = LibCEED.gridDim()
            bi = LibCEED.blockIdx()
            bd = LibCEED.blockDim()
            ti = LibCEED.threadIdx()

            $(in_decls...)
            $(out_decls...)

            # Alignment argument for every unsafe_load / unsafe_store! below.
            align = Val($(Base.datatype_alignment(CeedScalar)))

            # Cassette context that swaps intrinsics for their CUDA.jl versions.
            ctx = LibCEED.CeedCudaContext()

            # Grid-stride loop over the quadrature points 1:Q.
            first_q = ti.x + (bi.x - 1)*bd.x
            step_q = bd.x*gd.x
            for q = first_q:step_q:Q
                $(loads...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(call_args...))
                $(stores...)
            end
            return
        end
    end
end

# Define the kernel for `kf` in `def_module`, compile it with CUDA.jl for argument types
# `(Ptr{Nothing}, Int32, FieldsCuda)` with at most 64 registers, and return
# `host_k.fun.handle` (presumably the raw CUfunction handle consumed by libCEED).
# `ceed` is not used here.
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    kernel_expr = generate_kernel(qf_name, kf, dims_in, dims_out)
    kernel = Core.eval(def_module, kernel_expr)
    argtypes = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    compiled = cufunction(kernel, argtypes; maxregs=64)
    return compiled.fun.handle
end
# COV_EXCL_STOP