# Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors
# All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
#
# SPDX-License-Identifier: BSD-2-Clause
#
# This file is part of CEED:  http://github.com/ceed

from _ceed_cffi import ffi, lib
import tempfile
import numpy as np
import contextlib
from .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES, scalar_types

# ------------------------------------------------------------------------------


class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Constructor
    def __init__(self, ceed):
        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print().

        Returns:
            str: the text written by CeedQFunctionContextView"""

        # CeedQFunctionContextView writes to a C FILE stream, so the output is
        # routed through a temporary file and read back as text
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                # libCEED call
                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously
        allocated data if applicable.

        Args:
          *data: Numpy (host) or Numba (device) array to be used
          **memtype: memory type of the array being passed, default CEED_MEM_HOST
          **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # With USE_POINTER libCEED borrows the caller's buffer, so keep a
        # reference to the array so Python does not free it underneath libCEED
        if cmode == USE_POINTER:
            self._array_reference = data
        else:
            self._array_reference = None

        # Raw address of the array: host arrays expose it through the NumPy
        # array interface, device arrays through the CUDA array interface
        if memtype == MEM_HOST:
            address = data.__array_interface__['data'][0]
        else:
            address = data.__cuda_array_interface__['data'][0]
        data_pointer = ffi.cast("void *", address)

        # Total size in bytes; nbytes covers every dimension and the actual
        # element size, whereas len(data) only counts the first axis
        size = data.nbytes

        # libCEED call
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            size,
            data_pointer)
        self._ceed._check_error(err_code)

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.

        The returned array must be released with restore_data().

        Args:
          **memtype: memory type of the array being passed, default CEED_MEM_HOST

        Returns:
          *data: Numpy array (MEM_HOST) or Numba device array (otherwise),
                 viewed as CeedScalar entries"""

        # Retrieve the size of the context data in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        size = size_pointer[0]

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Create buffer object from returned pointer
            buff = ffi.buffer(data_pointer[0], size)
            # return Numpy array
            return np.frombuffer(buff, dtype=scalar_types[lib.CEED_SCALAR_TYPE])

        # CUDA array interface
        # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        import numba.cuda as nbcuda
        # Device memory is little-endian; Numba rejects non-native byte orders
        if lib.CEED_SCALAR_TYPE == lib.CEED_SCALAR_FP32:
            scalar_type_str = '<f4'
        else:
            scalar_type_str = '<f8'
        desc = {
            # shape must be a tuple of ints: the number of CeedScalar entries
            'shape': (size // ffi.sizeof("CeedScalar"),),
            'typestr': scalar_type_str,
            'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
            'version': 2
        }
        # return Numba array
        return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        The data is restored when the with-block exits, including when it
        exits via an exception.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST

        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            # Always hand access back to libCEED, even if the body raised
            self.restore_data()

# ------------------------------------------------------------------------------