xref: /libCEED/julia/LibCEED.jl/src/CeedVector.jl (revision 80a9ef0545a39c00cdcaab1ca26f8053604f3120)
1import LinearAlgebra: norm, axpy!
2
# Common supertype of the Julia-side vector handles: the concrete `CeedVector` and the
# two singleton sentinels below. Each of them supports `v[]` to obtain the raw C handle.
abstract type AbstractCeedVector end

# Singleton wrapper for libCEED's `CEED_VECTOR_ACTIVE` sentinel.
struct CeedVectorActive <: AbstractCeedVector end
function Base.getindex(::CeedVectorActive)
    return C.CEED_VECTOR_ACTIVE[]
end

# Singleton wrapper for libCEED's `CEED_VECTOR_NONE` sentinel.
struct CeedVectorNone <: AbstractCeedVector end
function Base.getindex(::CeedVectorNone)
    return C.CEED_VECTOR_NONE[]
end
10
# Julia-side owner of a libCEED `CeedVector` handle.
#
# `ref` holds the C handle, which `v[]` unwraps. `arr` starts as `nothing`; `setarray!`
# stores the user array there when called with `USE_POINTER`, keeping a Julia reference
# to it (presumably so the GC cannot reclaim memory libCEED is borrowing).
mutable struct CeedVector <: AbstractCeedVector
    ref::RefValue{C.CeedVector}
    arr::Union{Nothing,AbstractArray}
    function CeedVector(ref::Ref{C.CeedVector})
        return new(ref, nothing)
    end
end
16
"""
    CeedVector(c::Ceed, len::Integer)

Create a `CeedVector` with `len` entries using the `Ceed` object `c`.
"""
function CeedVector(c::Ceed, len::Integer)
    handle = Ref{C.CeedVector}()
    C.CeedVectorCreate(c[], len, handle)
    # Destroy the libCEED object when the Julia wrapper is garbage-collected;
    # `finalizer` returns its second argument, i.e. the new wrapper.
    return finalizer(destroy, CeedVector(handle))
end
# Release the libCEED vector via `CeedVectorDestroy`; installed as a finalizer by `CeedVector(c, len)`.
destroy(vec::CeedVector) = C.CeedVectorDestroy(vec.ref) # COV_EXCL_LINE
# `v[]` yields the raw C handle to pass to libCEED functions.
Base.getindex(vec::CeedVector) = vec.ref[]
34
# One-line description, e.g. "3-element CeedVector".
Base.summary(io::IO, v::CeedVector) = print(io, "$(length(v))-element CeedVector")

# Multi-line (REPL) display: the summary header, then the entries read on the host.
function Base.show(io::IO, ::MIME"text/plain", v::CeedVector)
    summary(io, v)
    println(io, ":")
    return witharray_read(arr -> Base.print_array(io, arr), v, MEM_HOST)
end

# Compact display: `show` of the host-side array contents.
function Base.show(io::IO, v::CeedVector)
    return witharray_read(v, MEM_HOST) do arr
        show(io, arr)
    end
end
44
# Ask libCEED for the number of entries and convert it to the integer type `T`.
function Base.length(::Type{T}, v::CeedVector) where {T}
    n = Ref{C.CeedInt}()
    C.CeedVectorGetLength(v[], n)
    return T(n[])
end

# Array-like interface: a CeedVector is one-dimensional with 1-based indices.
Base.length(v::CeedVector) = length(Int, v)
Base.size(v::CeedVector) = (length(v),)
Base.axes(v::CeedVector) = (Base.OneTo(length(v)),)
Base.ndims(::Type{CeedVector}) = 1
Base.ndims(::CeedVector) = 1
56
"""
    setvalue!(v::CeedVector, val::Real)

Fill every entry of the [`CeedVector`](@ref) `v` with the constant `val`.
"""
function setvalue!(v::CeedVector, val::Real)
    return C.CeedVectorSetValue(v[], val)
end
"""
    setindex!(v::CeedVector, val::Real)
    v[] = val

Fill the [`CeedVector`](@ref) with the constant `val`; this is the same as calling
[`setvalue!`](@ref).
"""
function Base.setindex!(v::CeedVector, val::Real)
    return setvalue!(v, val)
end
70
"""
    norm(v::CeedVector, ntype::NormType)

Compute the norm of the given [`CeedVector`](@ref).

The kind of norm is selected by `ntype`, one of `NORM_1`, `NORM_2`, `NORM_MAX`.
"""
function norm(v::CeedVector, ntype::NormType)
    result = Ref{CeedScalar}()
    C.CeedVectorNorm(v[], ntype, result)
    return result[]
end
83
"""
    norm(v::CeedVector, p::Real)

Return the norm of the given [`CeedVector`](@ref), see [`norm(::CeedVector,
::NormType)`](@ref).

`p` can have value 1, 2, or Inf, corresponding to `NORM_1`, `NORM_2`, and `NORM_MAX`,
respectively. Any other value raises an error. This includes `-Inf`, which
`LinearAlgebra.norm` interprets as the minimum absolute entry.
"""
function norm(v::CeedVector, p::Real)
    ntype = if p == 1
        NORM_1
    elseif p == 2
        NORM_2
    elseif p == Inf
        # Only +Inf is the max-norm; `isinf` would also accept -Inf, which means the
        # minimum absolute value in LinearAlgebra and has no libCEED equivalent.
        NORM_MAX
    else
        error("norm(v::CeedVector, p): p must be 1, 2, or Inf")
    end
    return norm(v, ntype)
end
105
"""
    reciprocal!(v::CeedVector)

Replace each entry of `v`, in place, by its reciprocal.
"""
function reciprocal!(v::CeedVector)
    return C.CeedVectorReciprocal(v[])
end
112
"""
    setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)

Set the array used by a [`CeedVector`](@ref), freeing any previously allocated array if
applicable. The backend may copy values to a different [`MemType`](@ref). See also
[`syncarray!`](@ref) and [`takearray!`](@ref).

With `cmode == USE_POINTER` a reference to `arr` is stored in `v`, and `arr` is returned;
otherwise `nothing` is returned.

!!! warning "Avoid OWN_POINTER CopyMode"
    The [`CopyMode`](@ref) `OWN_POINTER` is not suitable for use with arrays that are
    allocated by Julia, since those cannot be properly freed from libCEED.
"""
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)
    C.CeedVectorSetArray(v[], mtype, cmode, arr)
    cmode == USE_POINTER || return nothing
    # libCEED borrows `arr` in this mode; keep a Julia reference to it in `v`
    # (presumably so the GC does not reclaim the memory while libCEED uses it).
    v.arr = arr
    return arr
end
130
"""
    syncarray!(v::CeedVector, mtype::MemType)

Synchronize the [`CeedVector`](@ref) to the [`MemType`](@ref) `mtype`. Use this to force
synchronization of arrays set with [`setarray!`](@ref); when `mtype` is already
synchronized, the call does nothing.
"""
function syncarray!(v::CeedVector, mtype::MemType)
    return C.CeedVectorSyncArray(v[], mtype)
end
139
"""
    takearray!(v::CeedVector, mtype::MemType)

Take ownership of the [`CeedVector`](@ref) array and remove the array from the
[`CeedVector`](@ref). The caller is responsible for managing and freeing the array. The
array is returned as a `Ptr{CeedScalar}`.

!!! warning "Borrowed Julia arrays"
    If the array was attached with [`setarray!`](@ref) using `USE_POINTER`, `v` drops its
    reference to that Julia array here. The caller must keep the original array referenced
    for as long as the returned pointer is used.
"""
function takearray!(v::CeedVector, mtype::MemType)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedVectorTakeArray(v[], mtype, ptr)
    # The vector no longer refers to the array, so stop keeping it alive from `v`.
    v.arr = nothing
    return ptr[]
end
153
# Helper function to parse arguments of @witharray and @witharray_read.
#
# `assignment` must have the form `v_arr=v`. `args` holds optional `size=...` and
# `mtype=...` options followed by the body expression, which is always the last element.
# Returns `(arr, v, sz, mtype, body)`:
# - `sz` is an (escaped) expression for the array dimensions, defaulting to `(length(v),)`;
# - `mtype` defaults to `MEM_HOST`.
# Unknown options and a missing body are reported at macro-expansion time instead of
# being silently ignored or failing with an opaque BoundsError.
function witharray_parse(assignment, args)
    if !Meta.isexpr(assignment, :(=))
        error("@witharray must have first argument of the form v_arr=v") # COV_EXCL_LINE
    end
    if isempty(args)
        error("@witharray and @witharray_read require a body expression") # COV_EXCL_LINE
    end
    arr, v = assignment.args
    mtype = MEM_HOST
    sz = :((length($(esc(v))),))
    body = args[end]
    for a in args[1:end-1]
        if !Meta.isexpr(a, :(=))
            error("Incorrect call to @witharray or @witharray_read") # COV_EXCL_LINE
        end
        key, val = a.args
        if key == :mtype
            mtype = val
        elseif key == :size
            sz = esc(val)
        else
            # e.g. a misspelled `sz=` would otherwise silently fall back to a flat vector
            error("Unknown option $key in call to @witharray or @witharray_read") # COV_EXCL_LINE
        end
    end
    return arr, v, sz, mtype, body
end
177
"""
    @witharray(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Run `body` with the contents of the [`CeedVector`](@ref) `v` available as an array named
`v_arr`. When no [`memory type`](@ref MemType) `mtype` is given, `MEM_HOST` is used; when
no `size` is given, `v_arr` is a flat vector. The array is handed back to `v` when `body`
finishes, including when it throws.

# Examples
Negate the contents of `CeedVector` `v`:
```
@witharray v_arr=v v_arr .*= -1.0
```
"""
macro witharray(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    # `ptr_ref` is a hygienic local (renamed during expansion), so it cannot collide with
    # names in `body`; only the caller-supplied pieces are escaped. Note that the `v`
    # expression is evaluated more than once, so it should be a plain variable.
    return quote
        ptr_ref = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArray($(esc(v))[], $(esc(mtype)), ptr_ref)
        try
            $(esc(arr)) = UnsafeArray(ptr_ref[], Int.($sz))
            $(esc(body))
        finally
            C.CeedVectorRestoreArray($(esc(v))[], ptr_ref)
        end
    end
end
204
"""
    @witharray_read(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Same as [`@witharray`](@ref), but the array is obtained with read-only access.
"""
macro witharray_read(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    # Mirrors `@witharray`, but uses the read-only Get/Restore pair. `ptr_ref` is
    # hygienic; the `v` expression is evaluated more than once.
    return quote
        ptr_ref = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArrayRead($(esc(v))[], $(esc(mtype)), ptr_ref)
        try
            $(esc(arr)) = UnsafeArray(ptr_ref[], Int.($sz))
            $(esc(body))
        finally
            C.CeedVectorRestoreArrayRead($(esc(v))[], ptr_ref)
        end
    end
end
223
"""
    setindex!(v::CeedVector, v2::AbstractArray)
    v[] = v2

Sets the values of [`CeedVector`](@ref) `v` equal to those of `v2` using broadcasting.
Returns `v`.
"""
function Base.setindex!(v::CeedVector, v2::AbstractArray)
    @witharray(a = v, a .= v2)
    # Return the vector itself (the `setindex!` convention). Previously the temporary
    # `UnsafeArray` view leaked out, even though its memory had already been restored to
    # libCEED when `@witharray` exited.
    return v
end
231
"""
    CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)

Creates a new [`CeedVector`](@ref) from the contents of the vector `v2`. The data are
copied into the new [`CeedVector`](@ref) by default; pass a different `cmode` to change
how `v2` is attached.
"""
function CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)
    cv = CeedVector(c, length(v2))
    setarray!(cv, mtype, cmode, v2)
    return cv
end
244
"""
    Vector(v::CeedVector)

Copy the contents of `v` into a newly allocated `Vector`.
"""
function Base.Vector(v::CeedVector)
    out = Vector{CeedScalar}(undef, length(v))
    @witharray_read(src = v, out .= src)
    return out
end
254
"""
    witharray(f, v::CeedVector, mtype=MEM_HOST)

Calls `f` with an array containing the data of the `CeedVector` `v`, using [`memory
type`](@ref MemType) `mtype`, and returns the result of `f`. The array is always restored
to `v` afterwards, even if `f` throws.

Because of performance issues involving closures, if `f` is a complex operation, it may be
more efficient to use the macro version `@witharray` (cf. the section on "Performance of
captured variable" in the [Julia
documentation](https://docs.julialang.org/en/v1/manual/performance-tips) and related [GitHub
issue](https://github.com/JuliaLang/julia/issues/15276)).

# Examples

Return the sum of a vector:
```
witharray(sum, v)
```
"""
function witharray(f, v::CeedVector, mtype::MemType=MEM_HOST)
    # Query the length before taking the array, and build the wrapper inside `try`, so
    # no exception can occur between acquiring access and the `finally` releasing it.
    len = length(v)
    arr_ref = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArray(v[], mtype, arr_ref)
    try
        return f(UnsafeArray(arr_ref[], (len,)))
    finally
        C.CeedVectorRestoreArray(v[], arr_ref)
    end
end
285
"""
    witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)

Same as [`witharray`](@ref), but with read-only access to the data.

# Examples

Display the contents of a vector:
```
witharray_read(display, v)
```
"""
function witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)
    # Query the length before taking the array, and build the wrapper inside `try`, so
    # no exception can occur between acquiring read access and the `finally` releasing it.
    len = length(v)
    arr_ref = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArrayRead(v[], mtype, arr_ref)
    try
        return f(UnsafeArray(arr_ref[], (len,)))
    finally
        C.CeedVectorRestoreArrayRead(v[], arr_ref)
    end
end
309
"""
    scale!(v::CeedVector, a::Real)

Multiply `v` in place by the scalar `a` (so `v` becomes `a*v`) and return `v`.
"""
scale!(v::CeedVector, a::Real) = (C.CeedVectorScale(v[], a); v)
319
"""
    axpy!(a::Real, x::CeedVector, y::CeedVector)

Replace `y` by `x*a + y` for the scalar `a`, then return `y`.

!!! warning "Different argument order"
    In order to be consistent with `LinearAlgebra.axpy!`, the arguments are passed in order: `a`,
    `x`, `y`. This is different than the order of arguments of the C function `CeedVectorAXPY`.
"""
axpy!(a::Real, x::CeedVector, y::CeedVector) = (C.CeedVectorAXPY(y[], a, x[]); y)
333
"""
    pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)

Store the elementwise product `x .* y` in `w` and return `w`. Any of `x`, `y`, and `w` may
refer to the same vector.
"""
pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector) =
    (C.CeedVectorPointwiseMult(w[], x[], y[]); w)
343