xref: /libCEED/julia/LibCEED.jl/src/CeedVector.jl (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1import LinearAlgebra: norm, axpy!
2
# Common supertype of `CeedVector` and the marker types `CeedVectorActive` and
# `CeedVectorNone`. Each subtype implements `v[]` (`getindex`) to yield the underlying
# `C.CeedVector` handle.
abstract type AbstractCeedVector end
4
# Singleton marker standing for libCEED's `CEED_VECTOR_ACTIVE`; `v[]` returns that handle.
struct CeedVectorActive <: AbstractCeedVector end
function Base.getindex(::CeedVectorActive)
    return C.CEED_VECTOR_ACTIVE[]
end
7
# Singleton marker standing for libCEED's `CEED_VECTOR_NONE`; `v[]` returns that handle.
struct CeedVectorNone <: AbstractCeedVector end
function Base.getindex(::CeedVectorNone)
    return C.CEED_VECTOR_NONE[]
end
10
mutable struct CeedVector <: AbstractCeedVector
    # Handle to the underlying libCEED vector object.
    ref::RefValue{C.CeedVector}
    # Julia array lent to libCEED through `setarray!` with `USE_POINTER`. Holding it here
    # keeps it reachable for the garbage collector while libCEED may still access it.
    arr::Union{Nothing,AbstractArray}
    function CeedVector(ref::Ref{C.CeedVector})
        return new(ref, nothing)
    end
end
16
"""
    CeedVector(c::Ceed, len::Integer; allocate::Bool=true)

Create a `CeedVector` of the given length. If `allocate` is false, no memory is allocated,
and attempting to access the vector before calling [`setarray!`](@ref) will fail. By
default, `allocate` is true, and libCEED will allocate (host) memory for the vector.

The underlying libCEED object is destroyed when the returned object is garbage collected.
"""
function CeedVector(c::Ceed, len::Integer; allocate::Bool=true)
    ref = Ref{C.CeedVector}()
    C.CeedVectorCreate(c[], len, ref)
    obj = CeedVector(ref)
    # Release the libCEED vector together with its Julia wrapper.
    finalizer(destroy, obj)
    if allocate
        # A null array with COPY_VALUES makes libCEED allocate its own host storage.
        setarray!(obj, MEM_HOST, COPY_VALUES, C_NULL)
    end
    return obj
end
# Destroy the libCEED object behind `v`; used as the finalizer by the
# `CeedVector(c, len)` constructor.
# COV_EXCL_START
function destroy(v::CeedVector)
    return C.CeedVectorDestroy(v.ref)
end
# COV_EXCL_STOP
# `v[]` yields the raw `C.CeedVector` handle.
Base.getindex(v::CeedVector) = getindex(v.ref)
39
Base.summary(io::IO, v::CeedVector) = print(io, length(v), "-element CeedVector")

# Multi-line REPL display: a header such as "3-element CeedVector:" followed by the
# entries, read from host memory.
function Base.show(io::IO, ::MIME"text/plain", v::CeedVector)
    summary(io, v)
    println(io, ":")
    witharray_read(data -> Base.print_array(io, data), v, MEM_HOST)
end

# Compact display: show the entries exactly as the equivalent host array would.
function Base.show(io::IO, v::CeedVector)
    witharray_read(v, MEM_HOST) do data
        show(io, data)
    end
end
49
# Length of `v` converted to the integer type `T` (libCEED reports it as a `CeedSize`).
function Base.length(::Type{T}, v::CeedVector) where {T}
    n = Ref{C.CeedSize}()
    C.CeedVectorGetLength(v[], n)
    return T(n[])
end

# A CeedVector presents itself as a one-dimensional array indexed from 1.
Base.length(v::CeedVector) = length(Int, v)
Base.size(v::CeedVector) = (length(v),)
Base.axes(v::CeedVector) = (Base.OneTo(length(v)),)
Base.ndims(::CeedVector) = 1
Base.ndims(::Type{CeedVector}) = 1
61
"""
    setvalue!(v::CeedVector, val::Real)

Fill every entry of the [`CeedVector`](@ref) `v` with the constant `val`.
"""
function setvalue!(v::CeedVector, val::Real)
    return C.CeedVectorSetValue(v[], val)
end
"""
    setindex!(v::CeedVector, val::Real)
    v[] = val

Fill every entry of the [`CeedVector`](@ref) with the constant `val`; equivalent to
[`setvalue!`](@ref).
"""
function Base.setindex!(v::CeedVector, val::Real)
    return setvalue!(v, val)
end
75
"""
    norm(v::CeedVector, ntype::NormType)

Return the norm of the given [`CeedVector`](@ref).

`ntype` selects the norm and is one of `NORM_1`, `NORM_2`, or `NORM_MAX`.
"""
function norm(v::CeedVector, ntype::NormType)
    result = Ref{CeedScalar}()
    C.CeedVectorNorm(v[], ntype, result)
    return result[]
end
88
"""
    norm(v::CeedVector, p::Real)

Return the `p`-norm of the given [`CeedVector`](@ref), see [`norm(::CeedVector,
::NormType)`](@ref).

`p` can have value 1, 2, or `Inf`, corresponding to `NORM_1`, `NORM_2`, and `NORM_MAX`,
respectively. Any other value, including `-Inf`, raises an error.
"""
function norm(v::CeedVector, p::Real)
    ntype = if p == 1
        NORM_1
    elseif p == 2
        NORM_2
    elseif p == Inf
        # Only +Inf is the max norm. `-Inf` conventionally denotes the minimum absolute
        # entry (as in `LinearAlgebra.norm`), which libCEED does not provide.
        NORM_MAX
    else
        error("norm(v::CeedVector, p): p must be 1, 2, or Inf")
    end
    return norm(v, ntype)
end
110
"""
    reciprocal!(v::CeedVector)

Replace every entry of `v` with its reciprocal, in place.
"""
function reciprocal!(v::CeedVector)
    return C.CeedVectorReciprocal(v[])
end
117
"""
    setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)

Set the array used by a [`CeedVector`](@ref), freeing any previously allocated array if
applicable. The backend may copy values to a different [`MemType`](@ref). See also
[`syncarray!`](@ref) and [`takearray!`](@ref).

When `cmode` is `USE_POINTER` and `arr` is a Julia array, `v` keeps a reference to `arr` so
that it is not garbage collected while libCEED uses it. Raw pointers (e.g. `C_NULL` or a
pointer obtained from [`takearray!`](@ref)) are passed through without being retained.

!!! warning "Avoid OWN_POINTER CopyMode"
    The [`CopyMode`](@ref) `OWN_POINTER` is not suitable for use with arrays that are
    allocated by Julia, since those cannot be properly freed from libCEED.
"""
function setarray!(v::CeedVector, mtype::MemType, cmode::CopyMode, arr)
    C.CeedVectorSetArray(v[], mtype, cmode, arr)
    # Root borrowed Julia arrays. A raw pointer cannot be stored in the
    # `Union{Nothing,AbstractArray}` field: the conversion would throw after libCEED had
    # already accepted the pointer. It also needs no GC rooting.
    if cmode == USE_POINTER && arr isa AbstractArray
        v.arr = arr
    end
end
135
"""
    syncarray!(v::CeedVector, mtype::MemType)

Synchronize the [`CeedVector`](@ref) to the [`MemType`](@ref) `mtype`. Use this to force
synchronization of arrays set with [`setarray!`](@ref). If `mtype` is already synchronized,
nothing happens.
"""
function syncarray!(v::CeedVector, mtype::MemType)
    return C.CeedVectorSyncArray(v[], mtype)
end
144
"""
    takearray!(v::CeedVector, mtype::MemType)

Take ownership of the [`CeedVector`](@ref) array and remove the array from the
[`CeedVector`](@ref). The caller is responsible for managing and freeing the array. The
array is returned as a `Ptr{CeedScalar}`.
"""
function takearray!(v::CeedVector, mtype::MemType)
    ptr = Ref{Ptr{CeedScalar}}()
    C.CeedVectorTakeArray(v[], mtype, ptr)
    # The array has been removed from the vector, so drop any Julia array we were rooting.
    v.arr = nothing
    return ptr[]
end
158
# Helper function to parse arguments of @witharray and @witharray_read.
#
# `assignment` must have the form `v_arr = v`. `args` holds optional `size=...` and
# `mtype=...` options followed by the body expression (always the last argument).
# Returns `(arr, v, sz, mtype, body)`:
# - `sz` is an escaped expression for the array dimensions, defaulting to `(length(v),)`.
# - `mtype` defaults to `MEM_HOST`.
function witharray_parse(assignment, args)
    if !Meta.isexpr(assignment, :(=))
        error("@witharray must have first argument of the form v_arr=v") # COV_EXCL_LINE
    end
    if isempty(args)
        error("@witharray and @witharray_read require a body expression") # COV_EXCL_LINE
    end
    arr, v = assignment.args
    mtype = MEM_HOST
    sz = :((length($(esc(v))),))
    body = args[end]
    for a in args[1:end-1]
        if !Meta.isexpr(a, :(=))
            error("Incorrect call to @witharray or @witharray_read") # COV_EXCL_LINE
        end
        key, val = a.args
        if key == :mtype
            mtype = val
        elseif key == :size
            sz = esc(val)
        else
            # Reject misspelled options instead of silently ignoring them.
            error("@witharray: unknown option `$key`, expected `size` or `mtype`") # COV_EXCL_LINE
        end
    end
    return arr, v, sz, mtype, body
end
182
"""
    @witharray(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Run `body` with read-write access to the data of the [`CeedVector`](@ref) `v`, bound to the
array name `v_arr`. The [`memory type`](@ref MemType) `mtype` defaults to `MEM_HOST`, and
when no size is given the array is a flat vector of length `length(v)`. The data is handed
back to libCEED when `body` finishes (even if it throws), so `v_arr` should not be used
after the block.

# Examples
Negate the contents of `CeedVector` `v`:
```
@witharray v_arr=v v_arr .*= -1.0
```
"""
macro witharray(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    vec = esc(v)
    mem = esc(mtype)
    return quote
        ptr = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArray($(vec)[], $(mem), ptr)
        try
            $(esc(arr)) = UnsafeArray(ptr[], Int.($sz))
            $(esc(body))
        finally
            C.CeedVectorRestoreArray($(vec)[], ptr)
        end
    end
end
209
"""
    @witharray_read(v_arr=v, [size=(dims...)], [mtype=MEM_HOST], body)

Like [`@witharray`](@ref), but the data of `v` is accessed read-only.
"""
macro witharray_read(assignment, args...)
    arr, v, sz, mtype, body = witharray_parse(assignment, args)
    vec = esc(v)
    mem = esc(mtype)
    return quote
        ptr = Ref{Ptr{C.CeedScalar}}()
        C.CeedVectorGetArrayRead($(vec)[], $(mem), ptr)
        try
            $(esc(arr)) = UnsafeArray(ptr[], Int.($sz))
            $(esc(body))
        finally
            C.CeedVectorRestoreArrayRead($(vec)[], ptr)
        end
    end
end
228
"""
    setindex!(v::CeedVector, v2::AbstractArray)
    v[] = v2

Set the values of [`CeedVector`](@ref) `v` equal to those of `v2` using broadcasting, and
return `v`.
"""
function Base.setindex!(v::CeedVector, v2::AbstractArray)
    @witharray(a = v, a .= v2)
    # Return `v`, not the temporary array wrapper, whose memory is handed back to libCEED
    # as soon as the block above finishes.
    return v
end
236
"""
    CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)

Create a new [`CeedVector`](@ref) holding the contents of `v2`. The values are copied by
default; pass a different `cmode` to change how `v2` is used.
"""
function CeedVector(c::Ceed, v2::AbstractVector; mtype=MEM_HOST, cmode=COPY_VALUES)
    # Skip the default allocation: the storage is supplied from `v2` right below.
    vec = CeedVector(c, length(v2); allocate=false)
    setarray!(vec, mtype, cmode, v2)
    return vec
end
249
"""
    Vector(v::CeedVector)

Return a new `Vector{CeedScalar}` holding a copy of the entries of `v`.
"""
function Base.Vector(v::CeedVector)
    out = Vector{CeedScalar}(undef, length(v))
    @witharray_read(a = v, out .= a)
    return out
end
259
"""
    witharray(f, v::CeedVector, mtype=MEM_HOST)

Call `f` on an array wrapping the data of the `CeedVector` `v` in [`memory
type`](@ref MemType) `mtype`, and return the result of `f`. The data is handed back to
libCEED afterwards, even if `f` throws.

Because of performance issues involving closures, if `f` is a complex operation, it may be
more efficient to use the macro version `@witharray` (cf. the section on "Performance of
captured variable" in the [Julia
documentation](https://docs.julialang.org/en/v1/manual/performance-tips) and related [GitHub
issue](https://github.com/JuliaLang/julia/issues/15276).

# Examples

Return the sum of a vector:
```
witharray(sum, v)
```
"""
function witharray(f, v::CeedVector, mtype::MemType=MEM_HOST)
    ptr = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArray(v[], mtype, ptr)
    data = UnsafeArray(ptr[], (length(v),))
    try
        return f(data)
    finally
        C.CeedVectorRestoreArray(v[], ptr)
    end
end
290
"""
    witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)

Like [`witharray`](@ref), but the data of `v` is accessed read-only.

# Examples

Display the contents of a vector:
```
witharray_read(display, v)
```
"""
function witharray_read(f, v::CeedVector, mtype::MemType=MEM_HOST)
    ptr = Ref{Ptr{C.CeedScalar}}()
    C.CeedVectorGetArrayRead(v[], mtype, ptr)
    data = UnsafeArray(ptr[], (length(v),))
    try
        return f(data)
    finally
        C.CeedVectorRestoreArrayRead(v[], ptr)
    end
end
314
"""
    scale!(v::CeedVector, a::Real)

Multiply every entry of `v` by the scalar `a` in place, i.e. `v ← a*v`. Returns `v`.
"""
function scale!(v::CeedVector, a::Real)
    handle = v[]
    C.CeedVectorScale(handle, a)
    return v
end
324
"""
    axpy!(a::Real, x::CeedVector, y::CeedVector)

Update `y` in place to `a*x + y` for the scalar `a`, and return `y`.

!!! warning "Different argument order"
    The arguments follow `LinearAlgebra.axpy!` (`a`, `x`, `y`), which differs from the
    argument order of the C function `CeedVectorAXPY`.
"""
function axpy!(a::Real, x::CeedVector, y::CeedVector)
    # libCEED expects the target vector first: CeedVectorAXPY(y, a, x).
    target, source = y[], x[]
    C.CeedVectorAXPY(target, a, source)
    return y
end
338
"""
    pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)

Store the elementwise product `x .* y` in `w` and return `w`. Any of `w`, `x`, and `y` may
refer to the same vector.
"""
function pointwisemult!(w::CeedVector, x::CeedVector, y::CeedVector)
    out, lhs, rhs = w[], x[], y[]
    C.CeedVectorPointwiseMult(out, lhs, rhs)
    return w
end
348