xref: /libCEED/python/ceed_qfunctioncontext.py (revision f5066b3615781dbcd74af2f846f96d7648d0187d)
1# Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
2# the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
3# reserved. See files LICENSE and NOTICE for details.
4#
5# This file is part of CEED, a collection of benchmarks, miniapps, software
6# libraries and APIs for efficient high-order finite element and spectral
7# element discretizations for exascale applications. For more information and
8# source code availability see http://github.com/ceed.
9#
10# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11# a collaborative effort of two U.S. Department of Energy organizations (Office
12# of Science and the National Nuclear Security Administration) responsible for
13# the planning and preparation of a capable exascale ecosystem, including
14# software, applications, hardware, advanced system engineering and early
15# testbed platforms, in support of the nation's exascale computing imperative.
16
17from _ceed_cffi import ffi, lib
18import tempfile
19import numpy as np
20import contextlib
21from .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES, scalar_types
22
23# ------------------------------------------------------------------------------
24
25
class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Constructor
    def __init__(self, ceed):
        """Create an empty libCEED QFunction Context.

           Args:
             ceed: Ceed object that owns the context"""
        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed; also used for error checking of every libCEED call
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print()."""

        # libCEED writes its view to a C FILE*, so route the output through a
        # temporary file and read it back as a Python string
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                # libCEED call
                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
           data if applicable.

           Args:
             *data: Numpy or Numba array to be used
             **memtype: memory type of the array being passed, default CEED_MEM_HOST
             **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # With USE_POINTER the caller's buffer is handed to libCEED directly,
        # so keep the Python array alive as long as this context refers to it
        if cmode == USE_POINTER:
            self._array_reference = data
        else:
            self._array_reference = None

        # Raw address of the array's buffer: host arrays expose it through the
        # NumPy array interface, device arrays through the CUDA array interface
        if memtype == MEM_HOST:
            address = data.__array_interface__['data'][0]
        else:
            address = data.__cuda_array_interface__['data'][0]
        data_pointer = ffi.cast("void *", address)

        # libCEED call; the context size is given in bytes, so use the full
        # byte size of the buffer (correct for any shape and dtype)
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            data.nbytes,
            data_pointer)
        self._ceed._check_error(err_code)

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *data: Numpy or Numba array"""

        # Retrieve the size of the context data, in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        size_bytes = size_pointer[0]

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # dtype matching libCEED's configured CeedScalar
        scalar_dtype = np.dtype(scalar_types[lib.CEED_SCALAR_TYPE])

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Create buffer object from returned pointer
            buff = ffi.buffer(data_pointer[0], size_bytes)
            # return Numpy array
            return np.frombuffer(buff, dtype=scalar_dtype)
        else:
            # CUDA array interface
            # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
            import numba.cuda as nbcuda
            desc = {
                # shape must be a tuple of ints: use floor division
                'shape': (size_bytes // ffi.sizeof("CeedScalar"),),
                # native-endian type string, same dtype as the host branch
                'typestr': scalar_dtype.str,
                # (device pointer, read-only flag)
                'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
                'version': 2
            }
            # return Numba array
            return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        # Always hand the data back to libCEED, even if reshaping or the
        # body of the `with` statement raises
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            self.restore_data()
185
186# ------------------------------------------------------------------------------
187