# Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
# the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
# reserved. See files LICENSE and NOTICE for details.
#
# This file is part of CEED, a collection of benchmarks, miniapps, software
# libraries and APIs for efficient high-order finite element and spectral
# element discretizations for exascale applications. For more information and
# source code availability see http://github.com/ceed.
#
# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
# a collaborative effort of two U.S. Department of Energy organizations (Office
# of Science and the National Nuclear Security Administration) responsible for
# the planning and preparation of a capable exascale ecosystem, including
# software, applications, hardware, advanced system engineering and early
# testbed platforms, in support of the nation's exascale computing imperative.

import contextlib
import tempfile

import numpy as np

from _ceed_cffi import ffi, lib
from .ceed_constants import MEM_HOST, USE_POINTER, COPY_VALUES, scalar_types

# ------------------------------------------------------------------------------


class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data."""

    # Constructor
    def __init__(self, ceed):
        """Create an empty libCEED QFunction Context.

        Args:
          ceed: Ceed object whose handle (ceed._pointer[0]) is used to create
            the context and whose _check_error() reports libCEED errors"""
        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed; every later libCEED call (including __del__)
        # reports errors through self._ceed._check_error
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        """Destroy the underlying libCEED QFunction Context."""
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print().

        libCEED writes its view to a C FILE stream, so the output is routed
        through a temporary file and read back as a Python string."""

        # libCEED call
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
        data if applicable.

        Args:
          *data: Numpy or Numba array to be used; its whole buffer
            (data.nbytes bytes) is handed to libCEED
          **memtype: memory type of the array being passed, default CEED_MEM_HOST
          **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # With USE_POINTER libCEED works directly on the caller's buffer, so
        # hold a reference to keep the array from being garbage-collected
        if cmode == USE_POINTER:
            self._array_reference = data
        else:
            self._array_reference = None

        # Raw address of the array's buffer: host arrays expose it through
        # the NumPy array interface, device arrays through the CUDA one
        if memtype == MEM_HOST:
            address = data.__array_interface__['data'][0]
        else:
            address = data.__cuda_array_interface__['data'][0]
        data_pointer = ffi.cast("void *", address)

        # libCEED call; the size argument is in bytes. nbytes covers every
        # element of a multi-dimensional array and never exceeds the real
        # buffer, unlike len(data) * sizeof(CeedScalar).
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            data.nbytes,
            data_pointer)
        self._ceed._check_error(err_code)

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.

        The returned array is a view of libCEED's storage, not a copy; call
        restore_data() when done with it.

        Args:
          **memtype: memory type of the array being passed, default CEED_MEM_HOST

        Returns:
          *data: 1-D Numpy array (host) or Numba device array (device) of
            CeedScalar values"""

        # Retrieve the size of the context data, in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        size_in_bytes = size_pointer[0]

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # Return array created from buffer
        if memtype == MEM_HOST:
            # Wrap the returned pointer without copying
            buff = ffi.buffer(data_pointer[0], size_in_bytes)
            # return Numpy array
            return np.frombuffer(
                buff, dtype=scalar_types[lib.CEED_SCALAR_TYPE])

        # CUDA array interface
        # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
        # 'shape' must be a tuple of ints, and GPU data is little-endian, so
        # the typestr uses '<' (Numba rejects non-native byte order).
        import numba.cuda as nbcuda
        if lib.CEED_SCALAR_TYPE == lib.CEED_SCALAR_FP32:
            scalar_type_str = '<f4'
        else:
            scalar_type_str = '<f8'
        desc = {
            'shape': (size_in_bytes // ffi.sizeof("CeedScalar"),),
            'typestr': scalar_type_str,
            # (device pointer, read-only flag)
            'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
            'version': 2
        }
        # return Numba array
        return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Scratch pointer-to-pointer for libCEED; presumably it only resets
        # the pointer it is given, so a fresh one suffices -- TODO confirm
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        The data is always restored on exit, even if the with-body raises.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            self.restore_data()

# ------------------------------------------------------------------------------