# COV_EXCL_START
using .CUDA, Cassette

# Flag set to `true` once this CUDA support code has been loaded.
# NOTE(review): presumably read elsewhere in the package to decide whether CUDA
# QFunctions can be built — confirm against callers.
cuda_is_loaded = true

# Scalar math functions that, when called with a Float32/Float64 argument inside a
# user QFunction executed under `CeedCudaContext`, are redirected from `Base.f` to
# `CUDA.f` (see the `Cassette.overdub` methods generated from this list below).
# NOTE(review): only a one-argument overdub is generated per name, so the
# two-argument `ldexp(x, n)` is never intercepted by it.
#! format: off
const cudafuns = (
    :cos, :cospi, :sin, :sinpi, :tan,
    :acos, :asin, :atan,
    :cosh, :sinh, :tanh,
    :acosh, :asinh, :atanh,
    :log, :log10, :log1p, :log2,
    :exp, :exp2, :exp10, :expm1, :ldexp,
    :abs,
    :sqrt, :cbrt,
    :ceil, :floor,
)
#! format: on

# Cassette context in which generated kernels run the user QFunction, so that the
# overdub methods below can substitute device-compatible implementations.
Cassette.@context CeedCudaContext

# The following calls are executed directly instead of being recursed into by
# Cassette.
@inline function Cassette.overdub(::CeedCudaContext, ::typeof(Core.kwfunc), f)
    return Core.kwfunc(f)
end
@inline function Cassette.overdub(::CeedCudaContext, ::typeof(Core.apply_type), args...)
    return Core.apply_type(args...)
end
@inline function Cassette.overdub(
    ::CeedCudaContext,
    ::typeof(StaticArrays.Size),
    x::Type{<:AbstractArray{<:Any,N}},
) where {N}
    return StaticArrays.Size(x)
end

# For every name `f` in `cudafuns`, route `Base.f(x)` on a Float32/Float64 scalar
# to `CUDA.f(x)` when running under `CeedCudaContext`.
for f in cudafuns
    @eval @inline function Cassette.overdub(
        ::CeedCudaContext,
        ::typeof(Base.$f),
        x::Union{Float32,Float64},
    )
        return CUDA.$f(x)
    end
end

"""
    setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)

Hand the device array `arr` to the libCEED vector `v` by passing its raw address
(as a `Ptr{CeedScalar}`) to `CeedVectorSetArray` with memory type `mtype` and copy
mode `cmode`. When `cmode == USE_POINTER`, `arr` is additionally stored in `v.arr`
(presumably to keep the Julia array alive while libCEED refers to its memory).
"""
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr::CuArray)
    ptr = Ptr{CeedScalar}(UInt64(pointer(arr)))
    C.CeedVectorSetArray(v[], mtype, cmode, ptr)
    if cmode == USE_POINTER
        v.arr = arr
    end
end

# Kernel argument carrying the addresses of up to 16 input and 16 output field
# arrays as plain integers; the generated kernel reinterprets them as device
# pointers to `CeedScalar`.
struct FieldsCuda
    inputs::NTuple{16,Int}
    outputs::NTuple{16,Int}
end

"""
    generate_kernel(qf_name, kf, dims_in, dims_out)

Build (without evaluating) the expression defining a CUDA kernel that applies the
user QFunction `kf` at every quadrature point.

# Arguments
- `qf_name`: base name; the kernel is given a fresh name via `gensym(qf_name)`.
- `kf`: user function, called through `Cassette.overdub` in a `CeedCudaContext` as
  `kf(ctx_ptr, in_1, ..., in_n, out_1, ..., out_m)`.
- `dims_in`, `dims_out`: one dimension tuple per input/output field. An input with
  an empty dimension tuple is passed to `kf` as a plain scalar; other inputs are
  passed as `SArray`s and outputs as `MArray`s of the given size, all with element
  type `CeedScalar`.

# Returns
A `quote` block defining `kernel(ctx_ptr, Q, fields)`, where `fields` holds the
field addresses (see `FieldsCuda`). The kernel visits points `q` in `1:Q` with a
grid-stride loop; component `j` of each field is loaded from / stored to 1-based
index `q + (j - 1) * Q` of that field's device array.
"""
function generate_kernel(qf_name, kf, dims_in, dims_out)
    ninputs = length(dims_in)
    noutputs = length(dims_out)

    # Number of scalar components per quadrature point for each field.
    input_sz = prod.(dims_in)
    output_sz = prod.(dims_out)

    # Local variable names used in the generated kernel for each field.
    f_ins = [Symbol("rqi$i") for i = 1:ninputs]
    f_outs = [Symbol("rqo$i") for i = 1:noutputs]

    # `args`: expressions passed to `kf` after `ctx_ptr`.
    # `def_ins`: per-input declarations placed before the quadrature loop.
    # `f_ins_j`: per-input assignment target for component `j`.
    args = Vector{Union{Symbol,Expr}}(undef, ninputs + noutputs)
    def_ins = Vector{Expr}(undef, ninputs)
    f_ins_j = Vector{Union{Symbol,Expr}}(undef, ninputs)
    for i = 1:ninputs
        if length(dims_in[i]) == 0
            # Scalar input: a plain local that is overwritten by the single load.
            def_ins[i] = :(local $(f_ins[i]))
            f_ins_j[i] = f_ins[i]
            args[i] = f_ins[i]
        else
            # Array input: fill a mutable buffer, then pass an immutable copy.
            # Element type follows `CeedScalar` so it matches the device pointers.
            def_ins[i] =
                :($(f_ins[i]) = LibCEED.MArray{Tuple{$(dims_in[i]...)},$CeedScalar}(undef))
            f_ins_j[i] = :($(f_ins[i])[j])
            args[i] = :(LibCEED.SArray{Tuple{$(dims_in[i]...)},$CeedScalar}($(f_ins[i])))
        end
    end
    for i = 1:noutputs
        args[ninputs+i] = f_outs[i]
    end

    def_outs = [
        :($(f_outs[i]) = LibCEED.MArray{Tuple{$(dims_out[i]...)},$CeedScalar}(undef))
        for i = 1:noutputs
    ]

    device_ptr_type = Core.LLVMPtr{CeedScalar,LibCEED.AS.Global}

    # Load every component of every input at quadrature point `q`.
    read_quads_in = [
        :(
            for j = 1:$(input_sz[i])
                $(f_ins_j[i]) = unsafe_load(
                    reinterpret($device_ptr_type, fields.inputs[$i]),
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:ninputs
    ]

    # Store every component of every output at quadrature point `q`.
    write_quads_out = [
        :(
            for j = 1:$(output_sz[i])
                unsafe_store!(
                    reinterpret($device_ptr_type, fields.outputs[$i]),
                    $(f_outs[i])[j],
                    q + (j - 1)*Q,
                    a,
                )
            end
        ) for i = 1:noutputs
    ]

    qf = gensym(qf_name)
    return quote
        function $qf(ctx_ptr, Q, fields)
            gd = LibCEED.gridDim()
            bi = LibCEED.blockIdx()
            bd = LibCEED.blockDim()
            ti = LibCEED.threadIdx()

            # Total number of threads in the grid: the stride of the point loop.
            inc = bd.x*gd.x

            $(def_ins...)
            $(def_outs...)

            # Alignment for data read/write
            a = Val($(Base.datatype_alignment(CeedScalar)))

            # Cassette context for replacing intrinsics with CUDA versions
            ctx = LibCEED.CeedCudaContext()

            for q = (ti.x+(bi.x-1)*bd.x):inc:Q
                $(read_quads_in...)
                LibCEED.Cassette.overdub(ctx, $kf, ctx_ptr, $(args...))
                $(write_quads_out...)
            end
            return
        end
    end
end

"""
    mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)

Generate the kernel for `kf` with `generate_kernel`, evaluate its definition in
`def_module`, and compile it with `cufunction` for argument types
`(Ptr{Nothing}, Int32, FieldsCuda)` with `maxregs=64`. Return the raw handle
(`host_k.fun.handle`) of the compiled device function.

`ceed` is accepted but not used in this function body.
"""
function mk_cufunction(ceed, def_module, qf_name, kf, dims_in, dims_out)
    k_fn = Core.eval(def_module, generate_kernel(qf_name, kf, dims_in, dims_out))
    tt = Tuple{Ptr{Nothing},Int32,FieldsCuda}
    host_k = cufunction(k_fn, tt; maxregs=64)
    return host_k.fun.handle
end
# COV_EXCL_STOP