import LinearAlgebra: norm, axpy!

abstract type AbstractCeedVector end

# Singleton whose `getindex` yields the libCEED constant `C.CEED_VECTOR_ACTIVE[]`.
struct CeedVectorActive <: AbstractCeedVector end
Base.getindex(::CeedVectorActive) = C.CEED_VECTOR_ACTIVE[]

# Singleton whose `getindex` yields the libCEED constant `C.CEED_VECTOR_NONE[]`.
struct CeedVectorNone <: AbstractCeedVector end
Base.getindex(::CeedVectorNone) = C.CEED_VECTOR_NONE[]

mutable struct CeedVector <: AbstractCeedVector
    # Reference to the underlying libCEED handle; `v[]` dereferences it.
    ref::RefValue{C.CeedVector}
    # Array handed to `setarray!` with `USE_POINTER` (reset to `nothing` by `takearray!`).
    # NOTE(review): presumably stored so the borrowed Julia array stays reachable while
    # libCEED uses it -- confirm.
    arr::Union{Nothing,AbstractArray}
    CeedVector(ref::Ref{C.CeedVector}) = new(ref, nothing)
end

"""
    CeedVector(c::Ceed, len::Integer)

Creates a `CeedVector` of given length.
"""
function CeedVector(c::Ceed, len::Integer)
    ref = Ref{C.CeedVector}()
    C.CeedVectorCreate(c[], len, ref)
    obj = CeedVector(ref)
    finalizer(obj) do x
        # ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring), "Finalizing %s.\n", repr(x))
        destroy(x)
    end
    return obj
end
destroy(v::CeedVector) = C.CeedVectorDestroy(v.ref) # COV_EXCL_LINE
Base.getindex(v::CeedVector) = v.ref[]

Base.summary(io::IO, v::CeedVector) = print(io, length(v), "-element CeedVector")
function Base.show(io::IO, ::MIME"text/plain", v::CeedVector)
    summary(io, v)
    println(io, ":")
    witharray_read(v, MEM_HOST) do arr
        Base.print_array(io, arr)
    end
end
Base.show(io::IO, v::CeedVector) = witharray_read(a -> show(io, a), v, MEM_HOST)

# Query the length from libCEED and convert it to the requested integer type `T`.
function Base.length(::Type{T}, v::CeedVector) where {T}
    len = Ref{C.CeedInt}()
    C.CeedVectorGetLength(v[], len)
    return T(len[])
end

Base.ndims(::CeedVector) = 1
Base.ndims(::Type{CeedVector}) = 1
Base.axes(v::CeedVector) = (Base.OneTo(length(v)),)
Base.size(v::CeedVector) = (length(Int, v),)
Base.length(v::CeedVector) = length(Int, v)

"""
    setvalue!(v::CeedVector, val::CeedScalar)

Set the [`CeedVector`](@ref) to a constant value.
"""
setvalue!(v::CeedVector, val::CeedScalar) = C.CeedVectorSetValue(v[], val)
"""
    setindex!(v::CeedVector, val::CeedScalar)
    v[] = val

Set the [`CeedVector`](@ref) to a constant value, synonymous to [`setvalue!`](@ref).
"""
Base.setindex!(v::CeedVector, val::CeedScalar) = setvalue!(v, val)

"""
    norm(v::CeedVector, ntype::NormType)

Return the norm of the given [`CeedVector`](@ref).

The norm type can either be specified as one of `NORM_1`, `NORM_2`, `NORM_MAX`.
"""
function norm(v::CeedVector, ntype::NormType)
    nrm = Ref{CeedScalar}()
    C.CeedVectorNorm(v[], ntype, nrm)
    return nrm[]
end

"""
    norm(v::CeedVector, p::Real)

Return the norm of the given [`CeedVector`](@ref), see [`norm(::CeedVector,
::NormType)`](@ref).

`p` can have value 1, 2, or Inf, corresponding to `NORM_1`, `NORM_2`, and `NORM_MAX`,
respectively.
"""
function norm(v::CeedVector, p::Real)
    if p == 1
        ntype = NORM_1
    elseif p == 2
        ntype = NORM_2
    elseif isinf(p)
        ntype = NORM_MAX
    else
        error("norm(v::CeedVector, p): p must be 1, 2, or Inf")
    end
    return norm(v, ntype)
end

"""
    reciprocal!(v::CeedVector)

Set `v` to be equal to its elementwise reciprocal.
"""
reciprocal!(v::CeedVector) = C.CeedVectorReciprocal(v[])

"""
    setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)

Set the array used by a [`CeedVector`](@ref), freeing any previously allocated array if
applicable. The backend may copy values to a different [`MemType`](@ref). See also
[`syncarray!`](@ref) and [`takearray!`](@ref).

!!! warning "Avoid OWN_POINTER CopyMode"
    The [`CopyMode`](@ref) `OWN_POINTER` is not suitable for use with arrays that are
    allocated by Julia, since those cannot be properly freed from libCEED.
"""
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)
    C.CeedVectorSetArray(v[], mtype, cmode, arr)
    # Only a borrowed array is recorded on the Julia side. A previously recorded array is
    # intentionally left in place for the other copy modes.
    if cmode == USE_POINTER
        v.arr = arr
    end
end

"""
    syncarray!(v::CeedVector, mtype::MemType)

Sync the [`CeedVector`](@ref) to a specified [`MemType`](@ref). This function is used to
force synchronization of arrays set with [`setarray!`](@ref). If the requested memtype is
already synchronized, this function results in a no-op.
"""
syncarray!(v::CeedVector, mtype::MemType) = C.CeedVectorSyncArray(v[], mtype)

"""
    takearray!(v::CeedVector, mtype::MemType)

Take ownership of the [`CeedVector`](@ref) array and remove the array from the
[`CeedVector`](@ref). The caller is responsible for managing and freeing the array. The
array is returned as a `Ptr{CeedScalar}`.
"""
function takearray!(v::CeedVector, mtype::MemType)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedVectorTakeArray(v[], mtype, ptr)
    v.arr = nothing
    return ptr[]
end

# Helper function to parse arguments of @witharray and @witharray_read.
#
# `assignment` must be an expression of the form `v_arr=v`. `args` holds zero or more
# `key=value` options (`mtype=...`, `size=...`) followed by the body expression, which is
# always the last element. Returns the tuple `(arr, v, sz, mtype, body)`; `sz` is already
# escaped, while the other entries are escaped by the calling macro.
function witharray_parse(assignment, args)
    if !Meta.isexpr(assignment, :(=))
        error("@witharray must have first argument of the form v_arr=v") # COV_EXCL_LINE
    end
    if isempty(args)
        # Without this check `args[end]` below would fail with an opaque BoundsError.
        error("Incorrect call to @witharray or @witharray_read: missing body") # COV_EXCL_LINE
    end
    arr = assignment.args[1]
    v = assignment.args[2]
    mtype = MEM_HOST
    # Default shape: a flat vector as long as the CeedVector itself.
    sz = :((length($(esc(v))),))
    body = args[end]
    for a in args[1:end-1]
        if !Meta.isexpr(a, :(=))
            error("Incorrect call to @witharray or @witharray_read") # COV_EXCL_LINE
        end
        key = a.args[1]
        if key === :mtype
            mtype = a.args[2]
        elseif key === :size
            sz = esc(a.args[2])
        else
            # Reject unknown options instead of silently falling back to the defaults.
            error("Incorrect call to @witharray or @witharray_read: unknown option $(key)") # COV_EXCL_LINE
        end
    end
    return arr, v, sz, mtype, body
end

"""
    @witharray(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Executes `body`, having extracted the contents of the [`CeedVector`](@ref) `v` as an array
with name `v_arr`. If the [`memory type`](@ref MemType) `mtype` is not provided, `MEM_HOST`
will be used. If the size is not specified, a flat vector will be assumed.

# Examples
Negate the contents of `CeedVector` `v`:
```
@witharray v_arr=v v_arr .*= -1.0
```
"""
macro witharray(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    quote
        arr_ref = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArray($(esc(v))[], $(esc(mtype)), arr_ref)
        try
            $(esc(arr)) = UnsafeArray(arr_ref[], Int.($sz))
            $(esc(body))
        finally
            C.CeedVectorRestoreArray($(esc(v))[], arr_ref)
        end
    end
end

"""
    @witharray_read(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Same as [`@witharray`](@ref), but provides read-only access to the data.
"""
macro witharray_read(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    quote
        arr_ref = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArrayRead($(esc(v))[], $(esc(mtype)), arr_ref)
        try
            $(esc(arr)) = UnsafeArray(arr_ref[], Int.($sz))
            $(esc(body))
        finally
            C.CeedVectorRestoreArrayRead($(esc(v))[], arr_ref)
        end
    end
end

"""
    setindex!(v::CeedVector, v2::AbstractArray)
    v[] = v2

Sets the values of [`CeedVector`](@ref) `v` equal to those of `v2` using broadcasting.
"""
Base.setindex!(v::CeedVector, v2::AbstractArray) = @witharray(a = v, a .= v2)

"""
    CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)

Creates a new [`CeedVector`](@ref) using the contents of the given vector `v2`. By default,
the contents of `v2` will be copied to the new [`CeedVector`](@ref), but this behavior can
be changed by specifying a different `cmode`.
"""
function CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)
    v = CeedVector(c, length(v2))
    setarray!(v, mtype, cmode, v2)
    return v
end

"""
    Vector(v::CeedVector)

Create a new `Vector` by copying the contents of `v`.
"""
function Base.Vector(v::CeedVector)
    v2 = Vector{CeedScalar}(undef, length(v))
    # The broadcast assignment evaluates to `v2`, which becomes the return value.
    return @witharray_read(a = v, v2 .= a)
end

"""
    witharray(f, v::CeedVector, mtype=MEM_HOST)

Calls `f` with an array containing the data of the `CeedVector` `v`, using [`memory
type`](@ref MemType) `mtype`.

Because of performance issues involving closures, if `f` is a complex operation, it may be
more efficient to use the macro version `@witharray` (cf. the section on "Performance of
captured variable" in the [Julia
documentation](https://docs.julialang.org/en/v1/manual/performance-tips) and related [GitHub
issue](https://github.com/JuliaLang/julia/issues/15276)).

# Examples

Return the sum of a vector:
```
witharray(sum, v)
```
"""
function witharray(f, v::CeedVector, mtype::MemType=MEM_HOST)
    arr_ref = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArray(v[], mtype, arr_ref)
    res = try
        # Build the wrapper inside the `try` (as the macro version does) so the array is
        # restored even if querying the length or wrapping the pointer fails.
        f(UnsafeArray(arr_ref[], (length(v),)))
    finally
        C.CeedVectorRestoreArray(v[], arr_ref)
    end
    return res
end

"""
    witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)

Same as [`witharray`](@ref), but with read-only access to the data.

# Examples

Display the contents of a vector:
```
witharray_read(display, v)
```
"""
function witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)
    arr_ref = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArrayRead(v[], mtype, arr_ref)
    res = try
        # Build the wrapper inside the `try` (as the macro version does) so the array is
        # restored even if querying the length or wrapping the pointer fails.
        f(UnsafeArray(arr_ref[], (length(v),)))
    finally
        C.CeedVectorRestoreArrayRead(v[], arr_ref)
    end
    return res
end

"""
    scale!(v::CeedVector, a::Real)

Overwrite `v` with `a*v` for scalar `a`. Returns `v`.
"""
function scale!(v::CeedVector, a::Real)
    C.CeedVectorScale(v[], a)
    return v
end

"""
    axpy!(a::Real, x::CeedVector, y::CeedVector)

Overwrite `y` with `x*a + y`, where `a` is a scalar. Returns `y`.

!!! warning "Different argument order"
    In order to be consistent with `LinearAlgebra.axpy!`, the arguments are passed in order: `a`,
    `x`, `y`. This is different than the order of arguments of the C function `CeedVectorAXPY`.
"""
function axpy!(a::Real, x::CeedVector, y::CeedVector)
    C.CeedVectorAXPY(y[], a, x[])
    return y
end

"""
    pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)

Overwrite `w` with `x .* y`. Any subset of x, y, and w may be the same vector. Returns `w`.
"""
function pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)
    C.CeedVectorPointwiseMult(w[], x[], y[])
    return w
end