xref: /petsc/src/binding/petsc4py/src/petsc4py/PETSc/dlpack.pxi (revision 552edb6364df478b294b3111f33a8f37ca096b20)
1# DLPack interface
2
cdef extern from "Python.h":
    # Subset of the CPython PyCapsule API used to hand DLPack tensors to and
    # from Python.

    # Destructor callback type accepted by PyCapsule_New.
    ctypedef void (*PyCapsule_Destructor)(object)

    # Creation.  The return type is `object`, so Cython treats a NULL result
    # as a raised exception.
    object PyCapsule_New(void* pointer, const char* name, PyCapsule_Destructor destructor)

    # Inspection.
    int PyCapsule_CheckExact(object o)
    bint PyCapsule_IsValid(object capsule, const char* name)

    # Access and renaming.  The `except` clauses let Cython propagate the
    # Python error these functions set on failure.
    void* PyCapsule_GetPointer(object capsule, const char* name) except? NULL
    int PyCapsule_SetName(object capsule, const char* name) except -1
10
# Fixed-width integer typedefs are provided by <stdint.h> (C99), not by
# <stdlib.h>.  Declaring them from the correct header avoids relying on a
# transitive include.  The Cython-side base types only approximate the C
# types; the generated C code refers to the typedef names themselves.
# NOTE(review): int64_t is approximated as `signed long` while uint64_t uses
# `unsigned long long`; `long` is 32-bit on LLP64 (Windows).  Kept as-is to
# avoid clashing with other declarations of these names -- confirm.
cdef extern from "<stdint.h>" nogil:
    ctypedef signed long int64_t
    ctypedef unsigned long long uint64_t
    ctypedef unsigned char uint8_t
    ctypedef unsigned short uint16_t

# C heap allocation routines.
cdef extern from "<stdlib.h>" nogil:
    void free(void* ptr)
    void* malloc(size_t size)
18
cdef struct DLDataType:
    # Element-type descriptor of a tensor; presumably mirrors DLPack's
    # DLDataType.  Field order defines the C layout and must not change.
    uint8_t code      # type class, presumably a DLDataTypeCode value
    uint8_t bits      # bits per lane (per DLPack)
    uint16_t lanes    # lanes per element (per DLPack; 1 for scalars)
23
cdef enum PetscDLDeviceType:
    # Device kinds a tensor's data can live on.  The numeric values presumably
    # match DLPack's DLDeviceType codes and must not be renumbered.  The
    # commented-out entries are DLPack device kinds that are not declared here.
    kDLCPU = <unsigned int>1
    kDLCUDA = <unsigned int>2
    kDLCUDAHost = <unsigned int>3
    # kDLOpenCL = <unsigned int>4
    # kDLVulkan = <unsigned int>7
    # kDLMetal = <unsigned int>8
    # kDLVPI = <unsigned int>9
    kDLROCM = <unsigned int>10
    kDLROCMHost = <unsigned int>11
    # kDLExtDev = <unsigned int>12
    kDLCUDAManaged = <unsigned int>13
    # kDLOneAPI = <unsigned int>14
37
ctypedef struct DLContext:
    # Device holding a tensor's data (presumably the older DLPack name for
    # DLDevice -- confirm against the targeted DLPack version).  Field order
    # defines the C layout and must not change.
    PetscDLDeviceType device_type
    int device_id     # presumably the device index within device_type
41
cdef enum DLDataTypeCode:
    # Type-class codes stored in DLDataType.code.  The values presumably match
    # DLPack's codes and must not be renumbered.
    kDLInt = <unsigned int>0
    kDLUInt = <unsigned int>1
    kDLFloat = <unsigned int>2
46
cdef struct DLTensor:
    # Tensor descriptor; presumably mirrors DLPack's DLTensor.  Field order
    # defines the C layout and must not change.
    void* data              # base pointer of the data buffer
    DLContext ctx           # device on which `data` lives
    int ndim                # number of dimensions
    DLDataType dtype        # element type
    int64_t* shape          # extents, presumably ndim entries (per DLPack)
    int64_t* strides        # per-dimension strides; per DLPack in elements, NULL meaning compact row-major -- confirm
    uint64_t byte_offset    # byte offset from `data` to the first element (per DLPack)
55
# Callback type of DLManagedTensor.del_obj.  manager_deleter invokes it with
# the address of the tensor's manager_ctx field (passed as a void*) and
# ignores the returned int.
ctypedef int (*dlpack_manager_del_obj)(void*) noexcept nogil
57
cdef struct DLManagedTensor:
    # Owning wrapper around a DLTensor, carried inside a 'dltensor' PyCapsule.
    # The first three fields presumably mirror DLPack's DLManagedTensor
    # (dl_tensor, manager_ctx, deleter); del_obj is an extra trailing field
    # used by this module.  Field order defines the C layout and must not change.
    DLTensor dl_tensor
    void* manager_ctx                                          # owner context; manager_deleter is a no-op when NULL
    void (*manager_deleter)(DLManagedTensor*) noexcept nogil   # release callback (DLPack "deleter" slot)
    dlpack_manager_del_obj del_obj                             # optional; called by manager_deleter to release manager_ctx
63
cdef void pycapsule_deleter(object dltensor) noexcept:
    # PyCapsule destructor for exported DLPack tensors.
    #
    # Only a capsule that still carries the name 'dltensor' owns its
    # DLManagedTensor, and only that tensor is released here.  A consumer that
    # takes ownership renames the capsule (per DLPack convention, to
    # 'used_dltensor'), so the validity check fails and the tensor is left
    # for the consumer to release.
    cdef DLManagedTensor* managed
    if not PyCapsule_IsValid(dltensor, b'dltensor'):
        return
    managed = <DLManagedTensor *>PyCapsule_GetPointer(dltensor, b'dltensor')
    manager_deleter(managed)
70
cdef void manager_deleter(DLManagedTensor* tensor) noexcept nogil:
    # Release a DLManagedTensor: its shape buffer, its manager context
    # (through del_obj, when set) and finally the struct itself.
    #
    # Tensors whose manager_ctx is NULL are left untouched.
    # NOTE(review): in that case neither `tensor` nor its shape buffer is
    # freed -- confirm this is intended rather than a leak.
    # NOTE(review): dl_tensor.strides is never freed here; presumably it
    # shares the shape allocation -- confirm against the producer.
    if tensor.manager_ctx is not NULL:
        free(tensor.dl_tensor.shape)
        if tensor.del_obj is not NULL:
            # del_obj receives the address of manager_ctx (a void**),
            # not manager_ctx itself.
            tensor.del_obj(&tensor.manager_ctx)
        free(tensor)
78
79# --------------------------------------------------------------------
80