xref: /libCEED/python/ceed_qfunctioncontext.py (revision 461525f5d1dd56e4cf57a8aa5aab285da54ebf95)
1# Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2# the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3# reserved. See files LICENSE and NOTICE for details.
4#
5# This file is part of CEED, a collection of benchmarks, miniapps, software
6# libraries and APIs for efficient high-order finite element and spectral
7# element discretizations for exascale applications. For more information and
8# source code availability see http://github.com/ceed.
9#
10# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11# a collaborative effort of two U.S. Department of Energy organizations (Office
12# of Science and the National Nuclear Security Administration) responsible for
13# the planning and preparation of a capable exascale ecosystem, including
14# software, applications, hardware, advanced system engineering and early
15# testbed platforms, in support of the nation's exascale computing imperative.
16
17from _ceed_cffi import ffi, lib
18import tempfile
19import numpy as np
20import contextlib
21from .ceed_constants import MEM_HOST, COPY_VALUES
22
23# ------------------------------------------------------------------------------
24
25
class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Attributes
    _ceed = ffi.NULL
    _pointer = ffi.NULL
    # Python-side reference to an array whose memory was handed to libCEED
    # without copying (cmode other than COPY_VALUES); holding it prevents the
    # buffer from being garbage collected while libCEED still points at it
    _array_reference = None

    # Constructor
    def __init__(self, ceed):
        """Create an empty libCEED QFunction Context.

           Args:
             ceed: Ceed object that owns the context"""

        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print()."""

        # libCEED writes its view to a C FILE*, so route it through a
        # temporary file and read the text back
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                # libCEED call
                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
           data if applicable.

           Args:
             *data: Numpy or Numba array to be used; its raw buffer is passed
                      to libCEED, so it should be contiguous
             **memtype: memory type of the array being passed, default CEED_MEM_HOST
             **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # Raw address of the array's first element, taken from the host or
        # device array interface depending on where the data lives
        if memtype == MEM_HOST:
            data_pointer = ffi.cast(
                "void *",
                data.__array_interface__['data'][0])
        else:
            data_pointer = ffi.cast(
                "void *",
                data.__cuda_array_interface__['data'][0])

        # libCEED call; the size argument is the buffer length in bytes.
        # nbytes covers every element of multi-dimensional arrays.
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            data.nbytes,
            data_pointer)
        self._ceed._check_error(err_code)

        # With COPY_VALUES libCEED keeps its own copy; otherwise it aliases
        # the caller's buffer, so keep the array alive alongside this object
        if cmode == COPY_VALUES:
            self._array_reference = None
        else:
            self._array_reference = data

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *data: Numpy or Numba array of float64 values viewing the context data;
                      call restore_data() when done"""

        # Retrieve the size of the context data, in bytes (the C prototype
        # uses a size_t out-parameter)
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)

        # Number of whole CeedScalar values contained in the context
        scalar_size = ffi.sizeof("CeedScalar")
        num_values = size_pointer[0] // scalar_size

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Create buffer object from returned pointer; its length is already
            # in bytes, trimmed to a whole number of CeedScalar values
            buff = ffi.buffer(data_pointer[0], num_values * scalar_size)
            # return Numpy array
            return np.frombuffer(buff, dtype="float64")
        else:
            # CUDA array interface
            # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
            import numba.cuda as nbcuda
            desc = {
                # shape must be a tuple
                'shape': (num_values,),
                # little-endian float64, matching device memory layout
                'typestr': '<f8',
                'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
                'version': 2
            }
            # return Numba array
            return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        # Always hand the data back to libCEED, even if reshaping or the
        # body of the with-statement raises
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            self.restore_data()
180
181# ------------------------------------------------------------------------------
182