xref: /libCEED/python/ceed_vector.py (revision b0d170e7bc2e5c930ee481a47eb73044935a48a4)
1# Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors
2# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6# This file is part of CEED:  http://github.com/ceed
7
8from _ceed_cffi import ffi, lib
9import tempfile
10import numpy as np
11import contextlib
12from .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES, NORM_2, scalar_types
13
14# ------------------------------------------------------------------------------
15
16
class Vector():
    """Ceed Vector: storing and manipulating vectors.

    Python wrapper around a libCEED CeedVector handle. Every method forwards
    to the corresponding CeedVector* function of the cffi library and passes
    the returned error code to the owning Ceed object's _check_error()."""

    # Constructor
    def __init__(self, ceed, size):
        """Create a Vector of the given size.

           Args:
             ceed: Ceed object the Vector is created on; its _pointer and
                     _check_error() are used by every libCEED call
             size: length of the Vector"""

        # CeedVector object
        self._pointer = ffi.new("CeedVector *")

        # Reference to Ceed
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedVectorCreate(
            self._ceed._pointer[0], size, self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        """Destroy the underlying CeedVector."""

        # libCEED call
        err_code = lib.CeedVectorDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedVector instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a Vector via print().

           Returns:
             string: text written by CeedVectorView using the "%f" format"""

        # libCEED call
        fmt = ffi.new("char[]", "%f".encode('ascii'))
        # CeedVectorView writes to a C FILE*, so the output is routed through
        # a temporary file that is reopened by name and read back afterwards
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                err_code = lib.CeedVectorView(self._pointer[0], fmt, stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set Vector's data array
    def set_array(self, array, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the array used by a Vector, freeing any previously allocated
           array if applicable.

           Args:
             *array: Numpy or Numba array to be used; only the address of its
                       data buffer is passed to libCEED
             **memtype: memory type of the array being passed, default CEED_MEM_HOST
             **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # Address of the data buffer, taken from the (CUDA) array interface
        if memtype == MEM_HOST:
            address = array.__array_interface__['data'][0]
        else:
            address = array.__cuda_array_interface__['data'][0]
        array_pointer = ffi.cast("CeedScalar *", address)

        # libCEED call
        err_code = lib.CeedVectorSetArray(
            self._pointer[0], memtype, cmode, array_pointer)
        self._ceed._check_error(err_code)

        # Store array reference if needed. This is done only after the call
        # succeeded so that a failed call does not drop the reference that
        # keeps a previously set USE_POINTER array alive.
        self._array_reference = array if cmode == USE_POINTER else None

    # Shared implementation of get_array, get_array_read and get_array_write
    def _get_array_with(self, getter, memtype, read_only=False):
        """Obtain the Vector's data through a CeedVectorGetArray* function
           and wrap it in an array object without copying.

           Args:
             getter: libCEED function called as
                       getter(vector, memtype, pointer to array pointer)
             memtype: memory type requested from libCEED
             **read_only: mark the returned array as read-only, default False

           Returns:
             *array: Numpy array for CEED_MEM_HOST, otherwise a Numba device
                       array"""

        # Retrieve the length of the array
        length = self.get_length()

        # Setup the pointer's pointer
        array_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = getter(self._pointer[0], memtype, array_pointer)
        self._ceed._check_error(err_code)

        scalar_type = scalar_types[lib.CEED_SCALAR_TYPE]

        if memtype == MEM_HOST:
            # Create buffer object from returned pointer
            buff = ffi.buffer(
                array_pointer[0],
                ffi.sizeof("CeedScalar") * length)
            host_array = np.frombuffer(buff, dtype=scalar_type)
            if read_only:
                host_array.flags['WRITEABLE'] = False
            return host_array

        # CUDA array interface
        # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        import numba.cuda as nbcuda
        desc = {
            # the interface specifies the shape as a tuple
            'shape': (length,),
            # same scalar type as the host branch, in native byte order
            'typestr': np.dtype(scalar_type).str,
            # (device address, read-only flag)
            'data': (int(ffi.cast("intptr_t", array_pointer[0])), read_only),
            'version': 2
        }
        return nbcuda.from_cuda_array_interface(desc)

    # Get Vector's data array
    def get_array(self, memtype=MEM_HOST):
        """Get read/write access to a Vector via the specified memory type.
           Access is given back with restore_array().

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *array: Numpy or Numba array"""

        return self._get_array_with(lib.CeedVectorGetArray, memtype)

    # Get Vector's data array in read-only mode
    def get_array_read(self, memtype=MEM_HOST):
        """Get read-only access to a Vector via the specified memory type.
           Access is given back with restore_array_read().

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *array: Numpy or Numba array"""

        return self._get_array_with(
            lib.CeedVectorGetArrayRead, memtype, read_only=True)

    # Get Vector's data array in write-only mode
    def get_array_write(self, memtype=MEM_HOST):
        """Get write-only access to a Vector via the specified memory type.
           All old values should be considered invalid.
           Access is given back with restore_array().

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *array: Numpy or Numba array"""

        return self._get_array_with(lib.CeedVectorGetArrayWrite, memtype)

    # Restore the Vector's data array
    def restore_array(self):
        """Restore an array obtained using get_array() or get_array_write()."""

        # Setup the pointer's pointer
        array_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedVectorRestoreArray(self._pointer[0], array_pointer)
        self._ceed._check_error(err_code)

    # Restore an array obtained using getArrayRead
    def restore_array_read(self):
        """Restore an array obtained using get_array_read()."""

        # Setup the pointer's pointer
        array_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedVectorRestoreArrayRead(
            self._pointer[0], array_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def array(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        The array is restored when the block is left, including when the
        block (or the reshape) raises.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the array being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of vector

        Examples:
          Constructing the identity inside a libceed.Vector:

          >>> vec = ceed.Vector(16)
          >>> with vec.array(4, 4) as x:
          ...     x[...] = np.eye(4)
        """
        x = self.get_array(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            # always give access back, otherwise the Vector stays locked
            self.restore_array()

    @contextlib.contextmanager
    def array_read(self, *shape, memtype=MEM_HOST):
        """Context manager for read-only array access.

        The array is restored when the block is left, including when the
        block (or the reshape) raises.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the array being passed, default CEED_MEM_HOST

        Returns:
          np.array: read-only view of vector

        Examples:
          Viewing contents of a reshaped libceed.Vector view:

          >>> vec = ceed.Vector(6)
          >>> vec.set_value(1.3)
          >>> with vec.array_read(2, 3) as x:
          ...     print(x)
        """
        x = self.get_array_read(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            # always give access back, otherwise the Vector stays locked
            self.restore_array_read()

    @contextlib.contextmanager
    def array_write(self, *shape, memtype=MEM_HOST):
        """Context manager for write-only array access.
           All old values should be considered invalid.

        The array is restored when the block is left, including when the
        block (or the reshape) raises.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the array being passed, default CEED_MEM_HOST

        Returns:
          np.array: write-only view of vector

        Examples:
          Filling a reshaped libceed.Vector view:

          >>> vec = ceed.Vector(6)
          >>> with vec.array_write(2, 3) as x:
          ...     x[...] = 1.3
        """
        x = self.get_array_write(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            # always give access back, otherwise the Vector stays locked
            self.restore_array()

    # Get the length of a Vector
    def get_length(self):
        """Get the length of a Vector.

           Returns:
             length: length of the Vector"""

        length_pointer = ffi.new("CeedSize *")

        # libCEED call
        err_code = lib.CeedVectorGetLength(self._pointer[0], length_pointer)
        self._ceed._check_error(err_code)

        return length_pointer[0]

    # Get the length of a Vector
    def __len__(self):
        """Get the length of a Vector.

           Returns:
             length: length of the Vector"""

        return self.get_length()

    # Set the Vector to a given constant value
    def set_value(self, value):
        """Set the Vector to a constant value.

           Args:
             value: value to be used"""

        # libCEED call
        err_code = lib.CeedVectorSetValue(self._pointer[0], value)
        self._ceed._check_error(err_code)

    # Sync the Vector to a specified memtype
    def sync_array(self, memtype=MEM_HOST):
        """Sync the Vector to a specified memtype.

           Args:
             **memtype: memtype to be synced"""

        # libCEED call
        err_code = lib.CeedVectorSyncArray(self._pointer[0], memtype)
        self._ceed._check_error(err_code)

    # Compute the norm of a vector
    def norm(self, normtype=NORM_2):
        """Get the norm of a Vector.

           Args:
             **normtype: type of norm to be computed

           Returns:
             norm: value written by CeedVectorNorm"""

        norm_pointer = ffi.new("CeedScalar *")

        # libCEED call
        err_code = lib.CeedVectorNorm(self._pointer[0], normtype, norm_pointer)
        self._ceed._check_error(err_code)

        return norm_pointer[0]

    # Take the reciprocal of a vector
    def reciprocal(self):
        """Take the reciprocal of a Vector.

           Returns:
             self: this Vector, to allow chaining"""

        # libCEED call
        err_code = lib.CeedVectorReciprocal(self._pointer[0])
        self._ceed._check_error(err_code)

        return self

    # Compute self = alpha self
    def scale(self, alpha):
        """Compute self = alpha self.

           Args:
             alpha: scaling factor

           Returns:
             self: this Vector, to allow chaining"""

        # libCEED call
        err_code = lib.CeedVectorScale(self._pointer[0], alpha)
        self._ceed._check_error(err_code)

        return self

    # Compute self = alpha x + self
    def axpy(self, alpha, x):
        """Compute self = alpha x + self.

           Args:
             alpha: scaling factor
             x: Vector to be scaled and added

           Returns:
             self: this Vector, to allow chaining"""

        # libCEED call
        err_code = lib.CeedVectorAXPY(self._pointer[0], alpha, x._pointer[0])
        self._ceed._check_error(err_code)

        return self

    # Compute the pointwise multiplication self = x .* y
    def pointwise_mult(self, x, y):
        """Compute the pointwise multiplication self = x .* y.

           Args:
             x: first Vector for the product
             y: second Vector for the product

           Returns:
             self: this Vector, to allow chaining"""

        # libCEED call
        err_code = lib.CeedVectorPointwiseMult(
            self._pointer[0], x._pointer[0], y._pointer[0]
        )
        self._ceed._check_error(err_code)

        return self

    def _state(self):
        """Return the modification state of the Vector.

        State is incremented each time the Vector is mutated, and is odd whenever a
        mutable reference has not been returned.
        """

        state_pointer = ffi.new("uint64_t *")
        err_code = lib.CeedVectorGetState(self._pointer[0], state_pointer)
        self._ceed._check_error(err_code)
        return state_pointer[0]
450
451# ------------------------------------------------------------------------------
452
453
class _VectorWrap(Vector):
    """Vector interface over a CeedVector pointer that already exists.

    Vector.__init__ is not invoked, so no CeedVectorCreate call is made; the
    supplied pointer is adopted as-is. All other behavior, including the
    destructor, is inherited from Vector."""

    # Constructor
    def __init__(self, ceed, pointer):
        # Reference to Ceed and the existing CeedVector object to wrap
        self._ceed, self._pointer = ceed, pointer
464
465# ------------------------------------------------------------------------------
466