xref: /libCEED/python/ceed_qfunctioncontext.py (revision 7be1e82bee12be3372d0edb3771f3d66d6ac97b8)
1# Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2# the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3# reserved. See files LICENSE and NOTICE for details.
4#
5# This file is part of CEED, a collection of benchmarks, miniapps, software
6# libraries and APIs for efficient high-order finite element and spectral
7# element discretizations for exascale applications. For more information and
8# source code availability see http://github.com/ceed.
9#
10# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11# a collaborative effort of two U.S. Department of Energy organizations (Office
12# of Science and the National Nuclear Security Administration) responsible for
13# the planning and preparation of a capable exascale ecosystem, including
14# software, applications, hardware, advanced system engineering and early
15# testbed platforms, in support of the nation's exascale computing imperative.
16
17from _ceed_cffi import ffi, lib
18import tempfile
19import numpy as np
20import contextlib
21from .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES
22
23# ------------------------------------------------------------------------------
24
25
class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Constructor
    def __init__(self, ceed):
        """Create an empty QFunction Context.

           Args:
             ceed: Ceed object used to create the context and to check errors"""

        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed (needed for error checking in later calls)
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        """Destroy the underlying CeedQFunctionContext."""

        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print().

           The libCEED view output is written through a C FILE * into a
           temporary file, which is then read back and returned as a string."""

        # libCEED call
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                # NOTE(review): relies on cffi's support for converting a
                # Python file object to FILE *
                stream = ffi.cast("FILE *", stream_file)

                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
           data if applicable.

           Args:
             *data: Numpy or Numba array to be used (assumed contiguous; the
                    address of its first element and its total byte size are
                    passed to libCEED)
             **memtype: memory type of the array being passed, default CEED_MEM_HOST
             **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # With USE_POINTER libCEED keeps using the caller's buffer, so hold a
        # Python reference to keep the array alive as long as this context
        if cmode == USE_POINTER:
            self._array_reference = data
        else:
            self._array_reference = None

        # Raw address of the first element: host arrays expose it through the
        # NumPy array interface, device arrays through the CUDA array interface
        if memtype == MEM_HOST:
            address = data.__array_interface__['data'][0]
        else:
            address = data.__cuda_array_interface__['data'][0]
        data_pointer = ffi.cast("void *", address)

        # libCEED call; the context size is expressed in bytes.
        # nbytes covers every element of the array (len() would only count
        # the first axis) and equals len(data) * sizeof(CeedScalar) for a
        # 1-D CeedScalar array.
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            data.nbytes,
            data_pointer)
        self._ceed._check_error(err_code)

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.
           Access must be released with restore_data().

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *data: Numpy array (host) or Numba device array (otherwise) of
                    float64 values viewing the context data"""

        # Retrieve the size of the context data; libCEED reports it in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        size_bytes = size_pointer[0]
        # Number of whole CeedScalar entries that fit in the context data
        num_entries = size_bytes // ffi.sizeof("CeedScalar")

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Create a writable buffer object over exactly size_bytes bytes
            buff = ffi.buffer(data_pointer[0], size_bytes)
            # Explicit count so a byte size that is not a multiple of the
            # scalar size does not make frombuffer raise
            return np.frombuffer(buff, dtype="float64", count=num_entries)
        else:
            # CUDA array interface
            # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
            import numba.cuda as nbcuda
            desc = {
                # shape must be a tuple
                'shape': (num_entries,),
                # little-endian float64, matching the device's byte order
                'typestr': '<f8',
                'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
                'version': 2
            }
            # return Numba array
            return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        Args:
          *shape: optional dimensions; when given, the returned array is
                  reshaped to this shape
          **memtype: memory type of the data being passed, default CEED_MEM_HOST

        Returns:
          np.array: writable view of QFunction Context

        The data is restored via restore_data() when the with-block exits,
        including when it exits because of an exception.
        """
        x = self.get_data(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            # Always release access obtained by get_data()
            self.restore_data()
182
183# ------------------------------------------------------------------------------
184