xref: /libCEED/python/ceed_qfunctioncontext.py (revision 9ba83ac0e4b1fca39d6fa6737a318a9f0cbc172d)
# Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors
# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
#
# SPDX-License-Identifier: BSD-2-Clause
#
# This file is part of CEED:  http://github.com/ceed
7777ff853SJeremy L Thompson
8777ff853SJeremy L Thompsonfrom _ceed_cffi import ffi, lib
9777ff853SJeremy L Thompsonimport tempfile
10777ff853SJeremy L Thompsonimport numpy as np
11777ff853SJeremy L Thompsonimport contextlib
1280a9ef05SNatalie Beamsfrom .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES, scalar_types
13777ff853SJeremy L Thompson
# ------------------------------------------------------------------------------
15777ff853SJeremy L Thompson
16777ff853SJeremy L Thompson
class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Instance attributes:
    #   _pointer: cffi "CeedQFunctionContext *" holding the libCEED handle
    #   _ceed: owning Ceed object; supplies its handle and _check_error()
    #   _array_reference: set by set_data(); the caller's array when cmode is
    #     USE_POINTER (kept so its buffer stays alive), otherwise None

    # Constructor
    def __init__(self, ceed):
        """Create an empty QFunction Context.

           Args:
             ceed: Ceed object used to create the context and check errors"""

        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed, also needed by __del__ for error checking
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print()."""

        # CeedQFunctionContextView writes to a C FILE *, so the output is
        # routed through a temporary file and read back as a Python string
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                # libCEED call
                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
           data if applicable.

           Args:
             *data: Numpy or Numba array to be used; the context size is the
                    array's total size in bytes (data.nbytes)
             **memtype: memory type of the array being passed, default CEED_MEM_HOST
             **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # With USE_POINTER libCEED uses the caller's buffer directly, so keep
        # a Python reference to the array to keep that buffer alive
        if cmode == USE_POINTER:
            self._array_reference = data
        else:
            self._array_reference = None

        # Raw buffer address: host arrays expose it through the NumPy array
        # interface, device arrays through the CUDA array interface
        if memtype == MEM_HOST:
            address = data.__array_interface__['data'][0]
        else:
            address = data.__cuda_array_interface__['data'][0]
        data_pointer = ffi.cast("void *", address)

        # libCEED call; nbytes covers every dimension and the array's actual
        # element size (len(data) would only count the first axis)
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            data.nbytes,
            data_pointer)
        self._ceed._check_error(err_code)

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.
           Release the access with restore_data(), or use data() instead.

           Args:
             **memtype: memory type of the array being passed, default CEED_MEM_HOST

           Returns:
             *data: Numpy or Numba array of CeedScalar values"""

        # Retrieve the size of the context data, in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        num_bytes = size_pointer[0]

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        scalar_dtype = np.dtype(scalar_types[lib.CEED_SCALAR_TYPE])

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Create buffer object from returned pointer
            buff = ffi.buffer(data_pointer[0], num_bytes)
            # return Numpy array
            return np.frombuffer(buff, dtype=scalar_dtype)
        else:
            # CUDA array interface
            # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
            import numba.cuda as nbcuda
            desc = {
                # The interface requires a tuple of ints, hence // and the
                # trailing comma
                'shape': (num_bytes // ffi.sizeof("CeedScalar"),),
                # Native-endian type string of CeedScalar (e.g. '<f8'),
                # the same dtype the host branch uses
                'typestr': scalar_dtype.str,
                'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
                'version': 2
            }
            # return Numba array
            return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # The array from get_data() is not tracked here, so a fresh pointer
        # slot is handed to libCEED
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        # Restore even if the reshape or the with-body raises, so the
        # context is not left with its data checked out
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            self.restore_data()
176777ff853SJeremy L Thompson
# ------------------------------------------------------------------------------
178