# Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
# the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
# reserved. See files LICENSE and NOTICE for details.
#
# This file is part of CEED, a collection of benchmarks, miniapps, software
# libraries and APIs for efficient high-order finite element and spectral
# element discretizations for exascale applications. For more information and
# source code availability see http://github.com/ceed.
#
# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
# a collaborative effort of two U.S. Department of Energy organizations (Office
# of Science and the National Nuclear Security Administration) responsible for
# the planning and preparation of a capable exascale ecosystem, including
# software, applications, hardware, advanced system engineering and early
# testbed platforms, in support of the nation's exascale computing imperative.

from _ceed_cffi import ffi, lib
import tempfile
import numpy as np
import contextlib
from .ceed_constants import MEM_HOST, COPY_VALUES

# ------------------------------------------------------------------------------


class QFunctionContext():
    """Ceed QFunction Context: stores Ceed QFunction user context data.

    Wraps a libCEED ``CeedQFunctionContext`` handle created on the given Ceed
    object and destroyed when this Python object is garbage collected."""

    # Attributes
    _ceed = ffi.NULL
    _pointer = ffi.NULL
    # Python array handed to set_data() without copying; kept referenced so
    # its memory stays valid while libCEED may still be using it.
    _data_reference = None

    # Constructor
    def __init__(self, ceed):
        """Create an empty QFunction Context on `ceed`.

        Args:
          ceed: Ceed object owning this context"""
        # CeedQFunctionContext object
        self._pointer = ffi.new("CeedQFunctionContext *")

        # Reference to Ceed
        self._ceed = ceed

        # libCEED call
        err_code = lib.CeedQFunctionContextCreate(
            self._ceed._pointer[0], self._pointer)
        self._ceed._check_error(err_code)

    # Destructor
    def __del__(self):
        # libCEED call
        err_code = lib.CeedQFunctionContextDestroy(self._pointer)
        self._ceed._check_error(err_code)

    # Representation
    def __repr__(self):
        return "<CeedQFunctionContext instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a QFunction Context via print()."""

        # CeedQFunctionContextView writes to a C FILE *; route it through a
        # temporary file so the text can be read back as a Python string.
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                # libCEED call
                err_code = lib.CeedQFunctionContextView(
                    self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Set QFunction Context's data
    def set_data(self, data, memtype=MEM_HOST, cmode=COPY_VALUES):
        """Set the data used by a QFunction Context, freeing any previously allocated
        data if applicable.

        Args:
          *data: Numpy or Numba array to be used; its raw buffer is passed to
                   libCEED, so it is assumed to be contiguous
          **memtype: memory type of the array being passed, default CEED_MEM_HOST
          **cmode: copy mode for the array, default CEED_COPY_VALUES"""

        # Host arrays expose their address through the NumPy array interface,
        # device arrays through the CUDA array interface.
        if memtype == MEM_HOST:
            interface = data.__array_interface__
        else:
            interface = data.__cuda_array_interface__
        data_pointer = ffi.cast("void *", interface['data'][0])

        # libCEED takes the context size in bytes; nbytes covers every element
        # of the array (len() would only count the first axis).
        size_bytes = data.nbytes

        # libCEED call
        err_code = lib.CeedQFunctionContextSetData(
            self._pointer[0],
            memtype,
            cmode,
            size_bytes,
            data_pointer)
        self._ceed._check_error(err_code)

        # When the values are not copied, libCEED may keep referring to this
        # memory, so keep the Python array alive alongside the context.
        self._data_reference = None if cmode == COPY_VALUES else data

    # Get QFunction Context's data
    def get_data(self, memtype=MEM_HOST):
        """Get read/write access to a QFunction Context via the specified memory type.

        The returned array views the context memory as float64 values; call
        restore_data() when done.

        Args:
          **memtype: memory type of the array being passed, default CEED_MEM_HOST

        Returns:
          *data: Numpy or Numba array"""

        # Retrieve the size of the context data, in bytes
        size_pointer = ffi.new("size_t *")
        err_code = lib.CeedQFunctionContextGetContextSize(
            self._pointer[0], size_pointer)
        self._ceed._check_error(err_code)
        size_bytes = size_pointer[0]

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextGetData(
            self._pointer[0], memtype, data_pointer)
        self._ceed._check_error(err_code)

        # Return array created from buffer
        if memtype == MEM_HOST:
            # The size is already a byte count; do not scale it again.
            buff = ffi.buffer(data_pointer[0], size_bytes)
            # return Numpy array
            return np.frombuffer(buff, dtype="float64")
        else:
            # CUDA array interface
            # https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
            import numba.cuda as nbcuda
            length = size_bytes // ffi.sizeof("CeedScalar")
            desc = {
                # 'shape' must be a tuple
                'shape': (length,),
                # native (little-endian) float64
                'typestr': '<f8',
                'data': (int(ffi.cast("intptr_t", data_pointer[0])), False),
                'version': 2
            }
            # return Numba array
            return nbcuda.from_cuda_array_interface(desc)

    # Restore the QFunction Context's data
    def restore_data(self):
        """Restore an array obtained using get_data()."""

        # Setup the pointer's pointer
        data_pointer = ffi.new("CeedScalar **")

        # libCEED call
        err_code = lib.CeedQFunctionContextRestoreData(
            self._pointer[0], data_pointer)
        self._ceed._check_error(err_code)

    @contextlib.contextmanager
    def data(self, *shape, memtype=MEM_HOST):
        """Context manager for array access.

        The data is restored when the block exits, including when it exits
        through an exception.

        Args:
          shape (tuple): shape of returned numpy.array
          **memtype: memory type of the data being passed, default CEED_MEM_HOST


        Returns:
          np.array: writable view of QFunction Context
        """
        x = self.get_data(memtype=memtype)
        try:
            if shape:
                x = x.reshape(shape)
            yield x
        finally:
            self.restore_data()

# ------------------------------------------------------------------------------