# Copyright (c) 2017, Lawrence Livermore National Security, LLC. Produced at
# the Lawrence Livermore National Laboratory. LLNL-CODE-734707. All Rights
# reserved. See files LICENSE and NOTICE for details.
#
# This file is part of CEED, a collection of benchmarks, miniapps, software
# libraries and APIs for efficient high-order finite element and spectral
# element discretizations for exascale applications. For more information and
# source code availability see http://github.com/ceed.
#
# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
# a collaborative effort of two U.S. Department of Energy organizations (Office
# of Science and the National Nuclear Security Administration) responsible for
# the planning and preparation of a capable exascale ecosystem, including
# software, applications, hardware, advanced system engineering and early
# testbed platforms, in support of the nation's exascale computing imperative.

from _ceed_cffi import ffi, lib
import tempfile
import numpy as np
from abc import ABC
from .ceed_constants import TRANSPOSE, NOTRANSPOSE

# ------------------------------------------------------------------------------


class Basis(ABC):
    """Ceed Basis: finite element basis objects."""

    # Attributes
    # Class-level defaults; subclass constructors replace them with the owning
    # Ceed context and a "CeedBasis *" handle allocated with ffi.new.
    _ceed = ffi.NULL
    _pointer = ffi.NULL

    # Representation
    def __repr__(self):
        return "<CeedBasis instance at " + hex(id(self)) + ">"

    # String conversion for print() to stdout
    def __str__(self):
        """View a Basis via print().

           Returns:
             the text written by CeedBasisView"""

        # libCEED call: CeedBasisView writes to a C stream, so route it through
        # a temporary file and read the text back
        with tempfile.NamedTemporaryFile() as key_file:
            with open(key_file.name, 'r+') as stream_file:
                stream = ffi.cast("FILE *", stream_file)

                err_code = lib.CeedBasisView(self._pointer[0], stream)
                self._ceed._check_error(err_code)

                stream_file.seek(0)
                out_string = stream_file.read()

        return out_string

    # Helpers
    @staticmethod
    def _scalar_pointer(array):
        """Return a "CeedScalar *" aliasing the data buffer of a numpy array.

           No copy is made, so the caller must keep `array` alive while the
           pointer is in use. The array is expected to be a C-contiguous
           float64 buffer, the dtype this module allocates for CeedScalar
           data."""
        return ffi.cast("CeedScalar *", array.__array_interface__['data'][0])

    @staticmethod
    def _raise_on_error(func_name, err_code):
        """Raise RuntimeError for a non-zero libCEED return code.

           Used by static routines that have no Ceed context whose
           _check_error could be consulted.

           Args:
             func_name: name of the libCEED function, used in the message
             err_code: integer returned by that function"""
        if err_code:
            raise RuntimeError(
                "{} failed with error code {}".format(func_name, err_code))

    # Apply Basis
    def apply(self, nelem, emode, u, v, tmode=NOTRANSPOSE):
        """Apply basis evaluation from nodes to quadrature points or vice versa.

           Args:
             nelem: the number of elements to apply the basis evaluation to;
                      the backend will specify the ordering in a
                      BlockedElemRestriction
             emode: basis evaluation mode
             u: input vector
             v: output vector
             **tmode: NOTRANSPOSE to evaluate from nodes to quadrature
                        points, TRANSPOSE to apply the transpose, mapping
                        from quadrature points to nodes; default NOTRANSPOSE"""

        # libCEED call
        err_code = lib.CeedBasisApply(self._pointer[0], nelem, tmode, emode,
                                      u._pointer[0], v._pointer[0])
        self._ceed._check_error(err_code)

    # Transpose a Basis
    @property
    def T(self):
        """Transpose a Basis."""

        return TransposeBasis(self)

    # Transpose a Basis
    @property
    def transpose(self):
        """Transpose a Basis (same as T)."""

        return TransposeBasis(self)

    # Get number of nodes
    def get_num_nodes(self):
        """Get total number of nodes (in dim dimensions) of a Basis.

           Returns:
             num_nodes: total number of nodes"""

        # Setup argument
        p_pointer = ffi.new("CeedInt *")

        # libCEED call
        err_code = lib.CeedBasisGetNumNodes(self._pointer[0], p_pointer)
        self._ceed._check_error(err_code)

        return p_pointer[0]

    # Get number of quadrature points
    def get_num_quadrature_points(self):
        """Get total number of quadrature points (in dim dimensions) of a Basis.

           Returns:
             num_qpts: total number of quadrature points"""

        # Setup argument
        q_pointer = ffi.new("CeedInt *")

        # libCEED call
        err_code = lib.CeedBasisGetNumQuadraturePoints(
            self._pointer[0], q_pointer)
        self._ceed._check_error(err_code)

        return q_pointer[0]

    # Gauss quadrature
    @staticmethod
    def gauss_quadrature(q):
        """Construct a Gauss-Legendre quadrature.

           Args:
             q: number of quadrature points (integrates polynomials of
                  degree 2*q-1 exactly)

           Returns:
             (qref1d, qweight1d): array of length q holding the abscissa on
                                    [-1, 1] and array of length q holding the
                                    weights

           Raises:
             RuntimeError: if CeedGaussQuadrature reports an error"""

        # Setup arguments
        qref1d = np.empty(q, dtype="float64")
        qweight1d = np.empty(q, dtype="float64")

        # libCEED call; this is a static method with no Ceed context, so the
        # return code is checked directly
        err_code = lib.CeedGaussQuadrature(q, Basis._scalar_pointer(qref1d),
                                           Basis._scalar_pointer(qweight1d))
        Basis._raise_on_error("CeedGaussQuadrature", err_code)

        return qref1d, qweight1d

    # Lobatto quadrature
    @staticmethod
    def lobatto_quadrature(q):
        """Construct a Gauss-Legendre-Lobatto quadrature.

           Args:
             q: number of quadrature points (integrates polynomials of
                  degree 2*q-3 exactly)

           Returns:
             (qref1d, qweight1d): array of length q holding the abscissa on
                                    [-1, 1] and array of length q holding the
                                    weights

           Raises:
             RuntimeError: if CeedLobattoQuadrature reports an error"""

        # Setup arguments
        qref1d = np.empty(q, dtype="float64")
        qweight1d = np.empty(q, dtype="float64")

        # libCEED call; this is a static method with no Ceed context, so the
        # return code is checked directly
        err_code = lib.CeedLobattoQuadrature(q, Basis._scalar_pointer(qref1d),
                                             Basis._scalar_pointer(qweight1d))
        Basis._raise_on_error("CeedLobattoQuadrature", err_code)

        return qref1d, qweight1d

    # QR factorization
    @staticmethod
    def qr_factorization(ceed, mat, tau, m, n):
        """Return QR Factorization of a matrix.

           Args:
             ceed: Ceed context currently in use
             *mat: Numpy array holding the row-major matrix to be factorized
                     in place
             *tau: Numpy array to hold the vector of length m of scaling
                     factors
             m: number of rows
             n: number of columns

           Returns:
             (mat, tau): the same arrays that were passed in, written in place"""

        # libCEED call; mat and tau are written through their data pointers
        err_code = lib.CeedQRFactorization(
            ceed._pointer[0], Basis._scalar_pointer(mat),
            Basis._scalar_pointer(tau), m, n)
        ceed._check_error(err_code)

        return mat, tau

    # Symmetric Schur decomposition
    @staticmethod
    def symmetric_schur_decomposition(ceed, mat, n):
        """Return symmetric Schur decomposition of a symmetric matrix
             via symmetric QR factorization.

           Args:
             ceed: Ceed context currently in use
             *mat: Numpy array holding the row-major matrix to be factorized
                     in place
             n: number of rows/columns

           Returns:
             lbda: Numpy array of length n holding eigenvalues"""

        # Setup arguments
        lbda = np.empty(n, dtype="float64")

        # libCEED call; mat is written through its data pointer
        err_code = lib.CeedSymmetricSchurDecomposition(
            ceed._pointer[0], Basis._scalar_pointer(mat),
            Basis._scalar_pointer(lbda), n)
        ceed._check_error(err_code)

        return lbda

    # Simultaneous Diagonalization
    @staticmethod
    def simultaneous_diagonalization(ceed, matA, matB, n):
        """Return Simultaneous Diagonalization of two matrices.

           Args:
             ceed: Ceed context currently in use
             *matA: Numpy array holding the row-major matrix to be factorized
                      with eigenvalues
             *matB: Numpy array holding the row-major matrix to be factorized
                      to identity
             n: number of rows/columns

           Returns:
             (x, lbda): Numpy array of length n*n holding the row-major
                          orthogonal matrix and Numpy array of length n holding
                          the generalized eigenvalues"""

        # Setup arguments
        lbda = np.empty(n, dtype="float64")
        x = np.empty(n * n, dtype="float64")

        # libCEED call
        err_code = lib.CeedSimultaneousDiagonalization(
            ceed._pointer[0], Basis._scalar_pointer(matA),
            Basis._scalar_pointer(matB), Basis._scalar_pointer(x),
            Basis._scalar_pointer(lbda), n)
        ceed._check_error(err_code)

        return x, lbda

    # Destructor
    def __del__(self):
        # No handle was allocated (the instance still carries the class-level
        # NULL default), so there is nothing to destroy
        if self._pointer == ffi.NULL:
            return

        # libCEED call
        err_code = lib.CeedBasisDestroy(self._pointer)
        self._ceed._check_error(err_code)

# ------------------------------------------------------------------------------


class BasisTensorH1(Basis):
    """Ceed Tensor H1 Basis: finite element tensor-product basis objects for
         H^1 discretizations."""

    # Constructor
    def __init__(self, ceed, dim, ncomp, P1d, Q1d, interp1d, grad1d,
                 qref1d, qweight1d):
        """Create a tensor-product H^1 basis via CeedBasisCreateTensorH1.

           Args:
             ceed: Ceed context currently in use
             dim: dimension argument forwarded to CeedBasisCreateTensorH1
             ncomp: number of components forwarded to CeedBasisCreateTensorH1
             P1d: number of nodes in one dimension
             Q1d: number of quadrature points in one dimension
             interp1d: array of 1D interpolation data
             grad1d: array of 1D gradient data
             qref1d: array of 1D quadrature point coordinates
             qweight1d: array of 1D quadrature weights

           Array arguments are converted to C-contiguous float64 before their
           raw data pointers are handed to libCEED. Arrays already in that
           layout are used as-is, without copying."""

        # Setup arguments
        self._pointer = ffi.new("CeedBasis *")

        self._ceed = ceed

        # Coerce so each raw data pointer really describes a contiguous
        # CeedScalar (float64) buffer. The coerced arrays are held in locals,
        # so they stay alive for the libCEED call below.
        # NOTE(review): assumes libCEED does not retain these pointers after
        # CeedBasisCreateTensorH1 returns; the original relied on the same.
        interp1d = np.ascontiguousarray(interp1d, dtype="float64")
        grad1d = np.ascontiguousarray(grad1d, dtype="float64")
        qref1d = np.ascontiguousarray(qref1d, dtype="float64")
        qweight1d = np.ascontiguousarray(qweight1d, dtype="float64")

        interp1d_pointer = ffi.cast(
            "CeedScalar *", interp1d.__array_interface__['data'][0])
        grad1d_pointer = ffi.cast(
            "CeedScalar *", grad1d.__array_interface__['data'][0])
        qref1d_pointer = ffi.cast(
            "CeedScalar *", qref1d.__array_interface__['data'][0])
        qweight1d_pointer = ffi.cast(
            "CeedScalar *", qweight1d.__array_interface__['data'][0])

        # libCEED call
        err_code = lib.CeedBasisCreateTensorH1(self._ceed._pointer[0], dim, ncomp,
                                               P1d, Q1d, interp1d_pointer,
                                               grad1d_pointer, qref1d_pointer,
                                               qweight1d_pointer, self._pointer)
        self._ceed._check_error(err_code)

# ------------------------------------------------------------------------------


class BasisTensorH1Lagrange(Basis):
    """Ceed Tensor H1 Lagrange Basis: finite element tensor-product Lagrange basis
         objects for H^1 discretizations."""

    # Constructor
    def __init__(self, ceed, dim, ncomp, P, Q, qmode):
        """Create a tensor-product Lagrange basis via
             CeedBasisCreateTensorH1Lagrange.

           Args:
             ceed: Ceed context currently in use
             dim: dimension argument forwarded to libCEED
             ncomp: number of components forwarded to libCEED
             P: node-count argument forwarded to libCEED
             Q: quadrature-point-count argument forwarded to libCEED
             qmode: quadrature mode forwarded to libCEED"""

        # Handle slot that libCEED fills in with the new basis
        basis_handle = ffi.new("CeedBasis *")
        self._pointer = basis_handle
        self._ceed = ceed

        # libCEED call
        status = lib.CeedBasisCreateTensorH1Lagrange(
            ceed._pointer[0], dim, ncomp, P, Q, qmode, basis_handle)
        self._ceed._check_error(status)

# ------------------------------------------------------------------------------


class BasisH1(Basis):
    """Ceed H1 Basis: finite element non tensor-product basis for H^1 discretizations."""

    # Constructor
    def __init__(self, ceed, topo, ncomp, nnodes,
                 nqpts, interp, grad, qref, qweight):
        """Create a non tensor-product H^1 basis via CeedBasisCreateH1.

           Args:
             ceed: Ceed context currently in use
             topo: topology argument forwarded to CeedBasisCreateH1
             ncomp: number of components forwarded to CeedBasisCreateH1
             nnodes: total number of nodes
             nqpts: total number of quadrature points
             interp: array of interpolation data
             grad: array of gradient data
             qref: array of quadrature point coordinates
             qweight: array of quadrature weights

           Array arguments are converted to C-contiguous float64 before their
           raw data pointers are handed to libCEED. Arrays already in that
           layout are used as-is, without copying.

           Raises:
             whatever ceed._check_error raises when CeedBasisCreateH1 fails"""

        # Setup arguments
        self._pointer = ffi.new("CeedBasis *")

        self._ceed = ceed

        # Coerce so each raw data pointer really describes a contiguous
        # CeedScalar (float64) buffer. The coerced arrays are held in locals,
        # so they stay alive for the libCEED call below.
        # NOTE(review): assumes libCEED does not retain these pointers after
        # CeedBasisCreateH1 returns; the original relied on the same.
        interp = np.ascontiguousarray(interp, dtype="float64")
        grad = np.ascontiguousarray(grad, dtype="float64")
        qref = np.ascontiguousarray(qref, dtype="float64")
        qweight = np.ascontiguousarray(qweight, dtype="float64")

        interp_pointer = ffi.cast(
            "CeedScalar *", interp.__array_interface__['data'][0])
        grad_pointer = ffi.cast(
            "CeedScalar *", grad.__array_interface__['data'][0])
        qref_pointer = ffi.cast(
            "CeedScalar *", qref.__array_interface__['data'][0])
        qweight_pointer = ffi.cast(
            "CeedScalar *", qweight.__array_interface__['data'][0])

        # libCEED call
        err_code = lib.CeedBasisCreateH1(self._ceed._pointer[0], topo, ncomp,
                                         nnodes, nqpts, interp_pointer,
                                         grad_pointer, qref_pointer,
                                         qweight_pointer, self._pointer)
        # Surface creation failures instead of leaving a half-built basis
        self._ceed._check_error(err_code)

# ------------------------------------------------------------------------------


class TransposeBasis():
    """Transpose Ceed Basis: transpose of finite element tensor-product basis objects."""

    # Attributes
    # Wrapped Basis whose transposed action this object applies
    _basis = None

    # Constructor
    def __init__(self, basis):
        """Wrap `basis` so that apply() evaluates its transpose.

           Args:
             basis: Basis to be applied in TRANSPOSE mode"""
        self._basis = basis

    # Representation
    def __repr__(self):
        return "<Transpose CeedBasis instance at {}>".format(hex(id(self)))

    # Apply Transpose Basis
    def apply(self, nelem, emode, u, v):
        """Apply basis evaluation from quadrature points to nodes.

           Delegates to the wrapped Basis.apply with tmode=TRANSPOSE.

           Args:
             nelem: the number of elements to apply the basis evaluation to;
                      the backend will specify the ordering in a
                      Blocked ElemRestriction
             emode: basis evaluation mode
             u: input vector
             v: output vector"""
        wrapped = self._basis
        wrapped.apply(nelem, emode, u, v, tmode=TRANSPOSE)

# ------------------------------------------------------------------------------
