xref: /libCEED/python/ceed_qfunctioncontext.py (revision 2459f3f1cd4d7d2e210e1c26d669bd2fde41a0b6)
1# Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors
2# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3#
4# SPDX-License-Identifier: BSD-2-Clause
5#
6# This file is part of CEED:  http://github.com/ceed
7
8from _ceed_cffi import ffi, lib
9import tempfile
10import numpy as np
11import contextlib
12from .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES, scalar_types
13
14# ------------------------------------------------------------------------------
15
16
class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Constructor
    def __init__(self, ceed):
        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed (kept so error checking and destruction can use it)
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print().

           The libCEED viewer writes to a C FILE stream; a temporary file is
           used as that stream and its contents are read back as the result."""

        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                # libCEED call
                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
           data if applicable.

           Args:
             *data: Numpy or Numba array to be used; its raw buffer (data.nbytes
                      bytes starting at the array's base address) is handed to
                      libCEED, so the array is expected to be contiguous
             **memtype: memory type of the array being passed, default CEED_MEM_HOST
             **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # With USE_POINTER libCEED borrows the caller's buffer, so keep the
        # Python array alive for as long as this context may reference it
        if cmode == USE_POINTER:
            self._array_reference = data
        else:
            self._array_reference = None

        # Obtain the buffer address from the interface matching the memory type
        if memtype == MEM_HOST:
            address = data.__array_interface__['data'][0]
        else:
            address = data.__cuda_array_interface__['data'][0]
        data_pointer = ffi.cast("void *", address)

        # libCEED call; the size argument is the buffer size in bytes
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            data.nbytes,
            data_pointer)
        self._ceed._check_error(err_code)

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *data: 1-D Numpy (host) or Numba (device) array of CeedScalar values
                      viewing the context data; call restore_data() when done"""

        # Retrieve the size of the context data, in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        size_bytes = size_pointer[0]

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Create buffer object from returned pointer
            buff = ffi.buffer(data_pointer[0], size_bytes)
            # return Numpy array
            return np.frombuffer(buff, dtype=scalar_types[lib.CEED_SCALAR_TYPE])

        # CUDA array interface
        # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        import numba.cuda as nbcuda
        # GPU memory is little-endian, so the typestr must say so explicitly
        if lib.CEED_SCALAR_TYPE == lib.CEED_SCALAR_FP32:
            scalar_type_str = '<f4'
        else:
            scalar_type_str = '<f8'
        desc = {
            # shape must be a tuple of ints: number of scalars, not bytes
            'shape': (size_bytes // ffi.sizeof("CeedScalar"),),
            'typestr': scalar_type_str,
            'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
            'version': 2,
        }
        # return Numba array
        return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        The data is always restored on exit, even if the body of the `with`
        statement (or the reshape) raises.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            self.restore_data()
176
177# ------------------------------------------------------------------------------
178