import LinearAlgebra: norm, axpy!

abstract type AbstractCeedVector end

# Singleton whose `[]` access yields the value stored in `C.CEED_VECTOR_ACTIVE`.
struct CeedVectorActive <: AbstractCeedVector end
function Base.getindex(::CeedVectorActive)
    return C.CEED_VECTOR_ACTIVE[]
end

# Singleton whose `[]` access yields the value stored in `C.CEED_VECTOR_NONE`.
struct CeedVectorNone <: AbstractCeedVector end
function Base.getindex(::CeedVectorNone)
    return C.CEED_VECTOR_NONE[]
end

mutable struct CeedVector <: AbstractCeedVector
    # Reference to the underlying `C.CeedVector` handle.
    ref::RefValue{C.CeedVector}
    # Array handed to `setarray!` with `USE_POINTER` (otherwise `nothing`); presumably held
    # here so that the Julia array stays reachable while libCEED uses its memory.
    arr::Union{Nothing,AbstractArray}
    CeedVector(ref::Ref{C.CeedVector}) = new(ref, nothing)
end

"""
    CeedVector(c::Ceed, len::Integer; allocate::Bool=true)

Create a `CeedVector` of length `len`. When `allocate` is true (the default), libCEED
allocates (host) memory for the vector. When `allocate` is false, no memory is allocated,
and attempting to access the vector before calling `setarray!` will fail.
"""
function CeedVector(c::Ceed, len::Integer; allocate::Bool=true)
    handle = Ref{C.CeedVector}()
    C.CeedVectorCreate(c[], len, handle)
    vec = CeedVector(handle)
    # Destroy the libCEED object when the Julia wrapper is garbage collected.
    finalizer(destroy, vec)
    allocate && setarray!(vec, MEM_HOST, COPY_VALUES, C_NULL)
    return vec
end
destroy(v::CeedVector) = C.CeedVectorDestroy(v.ref) # COV_EXCL_LINE
function Base.getindex(v::CeedVector)
    return v.ref[]
end

function Base.summary(io::IO, v::CeedVector)
    return print(io, length(v), "-element CeedVector")
end
# Multi-line display: a summary header followed by the entries, read from host memory.
function Base.show(io::IO, ::MIME"text/plain", v::CeedVector)
    summary(io, v)
    println(io, ":")
    return witharray_read(arr -> Base.print_array(io, arr), v, MEM_HOST)
end
# Compact display: show the entries exactly as the host array would be shown.
function Base.show(io::IO, v::CeedVector)
    return witharray_read(v, MEM_HOST) do a
        show(io, a)
    end
end

# Length of `v` as reported by `CeedVectorGetLength`, converted to the requested type `T`.
function Base.length(::Type{T}, v::CeedVector) where {T}
    n = Ref{C.CeedSize}()
    C.CeedVectorGetLength(v[], n)
    return T(n[])
end

Base.ndims(::Type{CeedVector}) = 1
Base.ndims(::CeedVector) = 1
function Base.axes(v::CeedVector)
    return (Base.OneTo(length(v)),)
end
function Base.size(v::CeedVector)
    return (length(Int, v),)
end
Base.length(v::CeedVector) = length(Int, v)

"""
    setvalue!(v::CeedVector, val::Real)

Set the [`CeedVector`](@ref) to a constant value.
"""
setvalue!(v::CeedVector, val::Real) = C.CeedVectorSetValue(v[], val)
"""
    setindex!(v::CeedVector, val::Real)
    v[] = val

Set the [`CeedVector`](@ref) to a constant value, synonymous to [`setvalue!`](@ref).
"""
Base.setindex!(v::CeedVector, val::Real) = setvalue!(v, val)

"""
    norm(v::CeedVector, ntype::NormType)

Return the norm of the given [`CeedVector`](@ref).

The norm type must be one of `NORM_1`, `NORM_2`, or `NORM_MAX`.
"""
function norm(v::CeedVector, ntype::NormType)
    nrm = Ref{CeedScalar}()
    C.CeedVectorNorm(v[], ntype, nrm)
    return nrm[]
end

"""
    norm(v::CeedVector, p::Real)

Return the norm of the given [`CeedVector`](@ref), see [`norm(::CeedVector,
::NormType)`](@ref).

`p` can have value 1, 2, or Inf, corresponding to `NORM_1`, `NORM_2`, and `NORM_MAX`,
respectively. Any other value of `p` raises an error.
"""
function norm(v::CeedVector, p::Real)
    if p == 1
        ntype = NORM_1
    elseif p == 2
        ntype = NORM_2
    elseif isinf(p)
        ntype = NORM_MAX
    else
        error("norm(v::CeedVector, p): p must be 1, 2, or Inf")
    end
    return norm(v, ntype)
end

"""
    reciprocal!(v::CeedVector)

Set `v` to be equal to its elementwise reciprocal.
"""
reciprocal!(v::CeedVector) = C.CeedVectorReciprocal(v[])

"""
    setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)

Set the array used by a [`CeedVector`](@ref), freeing any previously allocated array if
applicable. The backend may copy values to a different [`MemType`](@ref). See also
[`syncarray!`](@ref) and [`takearray!`](@ref).

!!! warning "Avoid OWN_POINTER CopyMode"
    The [`CopyMode`](@ref) `OWN_POINTER` is not suitable for use with arrays that are
    allocated by Julia, since those cannot be properly freed from libCEED.
"""
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)
    C.CeedVectorSetArray(v[], mtype, cmode, arr)
    if cmode == USE_POINTER
        # Keep a Julia-side reference to the array; presumably this keeps it reachable for
        # as long as libCEED is using its memory.
        v.arr = arr
    end
end

"""
    syncarray!(v::CeedVector, mtype::MemType)

Sync the [`CeedVector`](@ref) to a specified [`MemType`](@ref). This function is used to
force synchronization of arrays set with [`setarray!`](@ref). If the requested memtype is
already synchronized, this function results in a no-op.
"""
syncarray!(v::CeedVector, mtype::MemType) = C.CeedVectorSyncArray(v[], mtype)

"""
    takearray!(v::CeedVector, mtype::MemType)

Take ownership of the [`CeedVector`](@ref) array and remove the array from the
[`CeedVector`](@ref). The caller is responsible for managing and freeing the array. The
array is returned as a `Ptr{CeedScalar}`.
"""
function takearray!(v::CeedVector, mtype::MemType)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedVectorTakeArray(v[], mtype, ptr)
    # The vector no longer uses the array, so drop the reference stored by `setarray!`.
    v.arr = nothing
    return ptr[]
end

# Helper function to parse arguments of @witharray and @witharray_read.
#
# `assignment` must be an expression of the form `v_arr=v`. `args` holds zero or more
# `mtype=...` / `size=...` options followed by the body expression (always the last
# element). Returns the tuple `(arr, v, sz, mtype, body)`; `sz` is already escaped, the
# other entries are escaped by the calling macro.
function witharray_parse(assignment, args)
    if !Meta.isexpr(assignment, :(=))
        error("@witharray must have first argument of the form v_arr=v") # COV_EXCL_LINE
    end
    if isempty(args)
        # No body was supplied; fail with a descriptive message instead of a BoundsError.
        error("Incorrect call to @witharray or @witharray_read") # COV_EXCL_LINE
    end
    arr = assignment.args[1]
    v = assignment.args[2]
    mtype = MEM_HOST
    # Default shape: a flat vector with the same length as `v`.
    sz = :((length($(esc(v))),))
    body = args[end]
    for a in args[1:end-1]
        if !Meta.isexpr(a, :(=))
            error("Incorrect call to @witharray or @witharray_read") # COV_EXCL_LINE
        end
        key = a.args[1]
        if key == :mtype
            mtype = a.args[2]
        elseif key == :size
            sz = esc(a.args[2])
        else
            # Reject unknown options rather than silently falling back to the defaults.
            error("Incorrect call to @witharray or @witharray_read") # COV_EXCL_LINE
        end
    end
    return arr, v, sz, mtype, body
end

"""
    @witharray(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Executes `body`, having extracted the contents of the [`CeedVector`](@ref) `v` as an array
with name `v_arr`. If the [`memory type`](@ref MemType) `mtype` is not provided, `MEM_HOST`
will be used. If the size is not specified, a flat vector will be assumed.

# Examples
Negate the contents of `CeedVector` `v`:
```
@witharray v_arr=v v_arr .*= -1.0
```
"""
macro witharray(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    quote
        arr_ref = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArray($(esc(v))[], $(esc(mtype)), arr_ref)
        try
            $(esc(arr)) = UnsafeArray(arr_ref[], Int.($sz))
            $(esc(body))
        finally
            # Always hand the array back, even if `body` throws.
            C.CeedVectorRestoreArray($(esc(v))[], arr_ref)
        end
    end
end

"""
    @witharray_read(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Same as [`@witharray`](@ref), but provides read-only access to the data.
"""
macro witharray_read(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    quote
        arr_ref = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArrayRead($(esc(v))[], $(esc(mtype)), arr_ref)
        try
            $(esc(arr)) = UnsafeArray(arr_ref[], Int.($sz))
            $(esc(body))
        finally
            # Always hand the array back, even if `body` throws.
            C.CeedVectorRestoreArrayRead($(esc(v))[], arr_ref)
        end
    end
end

"""
    setindex!(v::CeedVector, v2::AbstractArray)
    v[] = v2

Sets the values of [`CeedVector`](@ref) `v` equal to those of `v2` using broadcasting.
"""
Base.setindex!(v::CeedVector, v2::AbstractArray) = @witharray(a = v, a .= v2)

"""
    CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)

Creates a new [`CeedVector`](@ref) using the contents of the given vector `v2`. By default,
the contents of `v2` will be copied to the new [`CeedVector`](@ref), but this behavior can
be changed by specifying a different `cmode`.
"""
function CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)
    # Skip libCEED's own allocation; the array is supplied by `setarray!` below.
    v = CeedVector(c, length(v2); allocate=false)
    setarray!(v, mtype, cmode, v2)
    return v
end

"""
    Vector(v::CeedVector)

Create a new `Vector` by copying the contents of `v`.
"""
function Base.Vector(v::CeedVector)
    v2 = Vector{CeedScalar}(undef, length(v))
    # The broadcast assignment evaluates to `v2`, which becomes the macro's value.
    return @witharray_read(a = v, v2 .= a)
end

"""
    witharray(f, v::CeedVector, mtype=MEM_HOST)

Calls `f` with an array containing the data of the `CeedVector` `v`, using [`memory
type`](@ref MemType) `mtype`. Returns the value returned by `f`. The array is restored to
`v` even if `f` throws.

Because of performance issues involving closures, if `f` is a complex operation, it may be
more efficient to use the macro version `@witharray` (cf. the section on "Performance of
captured variable" in the [Julia
documentation](https://docs.julialang.org/en/v1/manual/performance-tips) and the related
[GitHub issue](https://github.com/JuliaLang/julia/issues/15276)).

# Examples

Return the sum of a vector:
```
witharray(sum, v)
```
"""
function witharray(f, v::CeedVector, mtype::MemType=MEM_HOST)
    arr_ref = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArray(v[], mtype, arr_ref)
    try
        # Build the wrapper inside `try` (as the macro version does) so that the array is
        # restored even if querying the length or wrapping the pointer fails.
        arr = UnsafeArray(arr_ref[], (length(v),))
        return f(arr)
    finally
        C.CeedVectorRestoreArray(v[], arr_ref)
    end
end

"""
    witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)

Same as [`witharray`](@ref), but with read-only access to the data.

# Examples

Display the contents of a vector:
```
witharray_read(display, v)
```
"""
function witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)
    arr_ref = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArrayRead(v[], mtype, arr_ref)
    try
        # Build the wrapper inside `try` (as the macro version does) so that the array is
        # restored even if querying the length or wrapping the pointer fails.
        arr = UnsafeArray(arr_ref[], (length(v),))
        return f(arr)
    finally
        C.CeedVectorRestoreArrayRead(v[], arr_ref)
    end
end

"""
    scale!(v::CeedVector, a::Real)

Overwrite `v` with `a*v` for scalar `a`. Returns `v`.
"""
function scale!(v::CeedVector, a::Real)
    C.CeedVectorScale(v[], a)
    return v
end

"""
    axpy!(a::Real, x::CeedVector, y::CeedVector)

Overwrite `y` with `x*a + y`, where `a` is a scalar. Returns `y`.

!!! warning "Different argument order"
    In order to be consistent with `LinearAlgebra.axpy!`, the arguments are passed in
    order: `a`, `x`, `y`. This is different than the order of arguments of the C function
    `CeedVectorAXPY`.
"""
function axpy!(a::Real, x::CeedVector, y::CeedVector)
    C.CeedVectorAXPY(y[], a, x[])
    return y
end

"""
    pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)

Overwrite `w` with `x .* y`. Any subset of x, y, and w may be the same vector. Returns `w`.
"""
function pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)
    C.CeedVectorPointwiseMult(w[], x[], y[])
    return w
end