// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
// Copyright (C) 2012 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_SPARSELU_SUPERNODAL_MATRIX_H
#define EIGEN_SPARSELU_SUPERNODAL_MATRIX_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {
namespace internal {

/** \ingroup SparseLU_Module
 * \brief a class to manipulate the L supernodal factor from the SparseLU factorization
 * 
 * This class contains the data needed to easily store
 * and manipulate the supernodes during the factorization and solution phases of Sparse LU.
 * Only the lower triangular matrix has supernodes.
 * 
 * NOTE : This class corresponds to the SCformat structure in SuperLU
 * 
 */
/* TODO
 * InnerIterator as for sparsematrix 
 * SuperInnerIterator to iterate through all supernodes 
 * Function for triangular solve
 */
template <typename Scalar_, typename StorageIndex_>
class MappedSuperNodalMatrix
{
  public:
    typedef Scalar_ Scalar;
    typedef StorageIndex_ StorageIndex;
    typedef Matrix<StorageIndex,Dynamic,1> IndexVector;
    typedef Matrix<Scalar,Dynamic,1> ScalarVector;
  public:
    MappedSuperNodalMatrix()
    {
      
    }
    MappedSuperNodalMatrix(Index m, Index n,  ScalarVector& nzval, IndexVector& nzval_colptr, IndexVector& rowind,
             IndexVector& rowind_colptr, IndexVector& col_to_sup, IndexVector& sup_to_col )
    {
      setInfos(m, n, nzval, nzval_colptr, rowind, rowind_colptr, col_to_sup, sup_to_col);
    }
    
    ~MappedSuperNodalMatrix()
    {
      
    }
    /**
     * Set appropriate pointers for the lower triangular supernodal matrix
     * These infos are available at the end of the numerical factorization
     * FIXME This class will be modified such that it can be use in the course 
     * of the factorization.
     */
    void setInfos(Index m, Index n, ScalarVector& nzval, IndexVector& nzval_colptr, IndexVector& rowind,
             IndexVector& rowind_colptr, IndexVector& col_to_sup, IndexVector& sup_to_col )
    {
      m_row = m;
      m_col = n; 
      m_nzval = nzval.data(); 
      m_nzval_colptr = nzval_colptr.data(); 
      m_rowind = rowind.data(); 
      m_rowind_colptr = rowind_colptr.data(); 
      m_nsuper = col_to_sup(n); 
      m_col_to_sup = col_to_sup.data(); 
      m_sup_to_col = sup_to_col.data(); 
    }
    
    /**
     * Number of rows
     */
    Index rows() const { return m_row; }
    
    /**
     * Number of columns
     */
    Index cols() const { return m_col; }
    
    /**
     * Return the array of nonzero values packed by column
     * 
     * The size is nnz
     */
    Scalar* valuePtr() {  return m_nzval; }
    
    const Scalar* valuePtr() const 
    {
      return m_nzval; 
    }
    /**
     * Return the pointers to the beginning of each column in \ref valuePtr()
     */
    StorageIndex* colIndexPtr()
    {
      return m_nzval_colptr; 
    }
    
    const StorageIndex* colIndexPtr() const
    {
      return m_nzval_colptr; 
    }
    
    /**
     * Return the array of compressed row indices of all supernodes
     */
    StorageIndex* rowIndex()  { return m_rowind; }
    
    const StorageIndex* rowIndex() const
    {
      return m_rowind; 
    }
    
    /**
     * Return the location in \em rowvaluePtr() which starts each column
     */
    StorageIndex* rowIndexPtr() { return m_rowind_colptr; }
    
    const StorageIndex* rowIndexPtr() const
    {
      return m_rowind_colptr; 
    }
    
    /** 
     * Return the array of column-to-supernode mapping 
     */
    StorageIndex* colToSup()  { return m_col_to_sup; }
    
    const StorageIndex* colToSup() const
    {
      return m_col_to_sup;       
    }
    /**
     * Return the array of supernode-to-column mapping
     */
    StorageIndex* supToCol() { return m_sup_to_col; }
    
    const StorageIndex* supToCol() const
    {
      return m_sup_to_col;
    }
    
    /**
     * Return the number of supernodes
     */
    Index nsuper() const
    {
      return m_nsuper; 
    }
    
    class InnerIterator; 
    template<typename Dest>
    void solveInPlace( MatrixBase<Dest>&X) const;
    template<bool Conjugate, typename Dest>
    void solveTransposedInPlace( MatrixBase<Dest>&X) const;

    
      
      
    
  protected:
    Index m_row; // Number of rows
    Index m_col; // Number of columns
    Index m_nsuper; // Number of supernodes
    Scalar* m_nzval; //array of nonzero values packed by column
    StorageIndex* m_nzval_colptr; //nzval_colptr[j] Stores the location in nzval[] which starts column j
    StorageIndex* m_rowind; // Array of compressed row indices of rectangular supernodes
    StorageIndex* m_rowind_colptr; //rowind_colptr[j] stores the location in rowind[] which starts column j
    StorageIndex* m_col_to_sup; // col_to_sup[j] is the supernode number to which column j belongs
    StorageIndex* m_sup_to_col; //sup_to_col[s] points to the starting column of the s-th supernode
    
  private :
};

/**
  * \brief InnerIterator class to iterate over nonzero values of the current column in the supernodal matrix L
  * 
  */
template<typename Scalar, typename StorageIndex>
class MappedSuperNodalMatrix<Scalar,StorageIndex>::InnerIterator
{
  public:
    /**
     * Position the iterator on the first stored entry of column \a outer of \a mat.
     *
     * Values are read from the value range of column \a outer, whereas row
     * indices are read from the row-index range of the first column of the
     * supernode that contains \a outer.
     */
    InnerIterator(const MappedSuperNodalMatrix& mat, Index outer)
      : m_matrix(mat),
        m_outer(outer),
        m_supno(mat.colToSup()[outer]),
        m_idval(mat.colIndexPtr()[outer]),
        m_startidval(m_idval),
        m_endidval(mat.colIndexPtr()[outer+1]),
        // m_supno is declared (hence initialized) before the two members below.
        m_idrow(mat.rowIndexPtr()[mat.supToCol()[m_supno]]),
        m_endidrow(mat.rowIndexPtr()[mat.supToCol()[m_supno]+1])
    {}

    /** Advance to the next stored entry: value and row cursors move together. */
    InnerIterator& operator++()
    {
      ++m_idval;
      ++m_idrow;
      return *this;
    }

    /** Value of the current entry. */
    Scalar value() const { return m_matrix.valuePtr()[m_idval]; }

    /** Writable reference to the current entry (constness of the matrix is cast away). */
    Scalar& valueRef() { return const_cast<Scalar&>(m_matrix.valuePtr()[m_idval]); }

    /** Row index of the current entry. */
    Index index() const { return m_matrix.rowIndex()[m_idrow]; }
    /** Same as index(). */
    Index row() const { return index(); }
    /** Column this iterator was created for. */
    Index col() const { return m_outer; }

    /** Supernode number of the column this iterator was created for. */
    Index supIndex() const { return m_supno; }

    /** True while both the value cursor and the row cursor are inside their ranges. */
    operator bool() const
    {
      const bool valueInRange = (m_idval >= m_startidval) && (m_idval < m_endidval);
      const bool rowInRange = (m_idrow < m_endidrow);
      return valueInRange && rowInRange;
    }

  protected:
    const MappedSuperNodalMatrix& m_matrix; // Supernodal lower triangular matrix
    const Index m_outer;                    // Current column
    const Index m_supno;                    // Current SuperNode number
    Index m_idval;                          // Index to browse the values in the current column
    const Index m_startidval;               // Start of the column value
    const Index m_endidval;                 // End of the column value
    Index m_idrow;                          // Index to browse the row indices
    Index m_endidrow;                       // End index of row indices of the current column
};

/**
 * \brief Solve with the supernode triangular matrix
 *
 * The supernodes are visited in increasing order (k = 0 .. nsuper()). For each
 * of them, the rows of \a X matching the supernode columns are solved against
 * the unit lower triangular diagonal block, then the contribution of the rows
 * below that block is subtracted from the rows of \a X they refer to.
 *
 * \param X on entry the right-hand side(s), one per column; overwritten in place.
 */
template<typename Scalar, typename Index_>
template<typename Dest>
void MappedSuperNodalMatrix<Scalar,Index_>::solveInPlace( MatrixBase<Dest>&X) const
{
    // Strided, read-only view on a dense panel stored inside the packed values.
    typedef Map<const Matrix<Scalar,Dynamic,Dynamic, ColMajor>, 0, OuterStride<> > SupernodePanel;

    // Keep the full Index width: converting through int would truncate large sizes.
    const Index n    = Index(X.rows());
    const Index nrhs = Index(X.cols());
    const Scalar * Lval = valuePtr();                 // Nonzero values
    Matrix<Scalar,Dynamic,Dest::ColsAtCompileTime, ColMajor> work(n, nrhs);     // working vector
    work.setZero();
    for (Index k = 0; k <= nsuper(); k ++)
    {
      const Index fsupc = supToCol()[k];                    // First column of the current supernode
      const Index istart = rowIndexPtr()[fsupc];            // Pointer index to the subscript of the current column
      const Index nsupr = rowIndexPtr()[fsupc+1] - istart;  // Number of rows in the current supernode
      const Index nsupc = supToCol()[k+1] - fsupc;          // Number of columns in the current supernode
      const Index nrow = nsupr - nsupc;                     // Number of rows in the non-diagonal part of the supernode

      if (nsupc == 1 )
      {
        // Single-column supernode: plain sparse column update.
        for (Index j = 0; j < nrhs; j++)
        {
          InnerIterator it(*this, fsupc);
          ++it; // Skip the diagonal element
          for (; it; ++it)
          {
            const Index irow = it.row();
            X(irow, j) -= X(fsupc, j) * it.value();
          }
        }
      }
      else
      {
        // The supernode has more than one column
        const Index luptr = colIndexPtr()[fsupc];
        const Index lda = colIndexPtr()[fsupc+1] - luptr;

        // Triangular solve with the nsupc x nsupc diagonal block
        SupernodePanel diagBlock( &(Lval[luptr]), nsupc, nsupc, OuterStride<>(lda) );
        typename Dest::RowsBlockXpr U = X.derived().middleRows(fsupc, nsupc);
        U = diagBlock.template triangularView<UnitLower>().solve(U);

        // Matrix-vector product with the nrow x nsupc panel below the diagonal block
        SupernodePanel lowerPanel( &(Lval[luptr+nsupc]), nrow, nsupc, OuterStride<>(lda) );
        work.topRows(nrow).noalias() = lowerPanel * U;

        //Begin Scatter
        for (Index j = 0; j < nrhs; j++)
        {
          Index iptr = istart + nsupc;
          for (Index i = 0; i < nrow; i++)
          {
            const Index irow = rowIndex()[iptr];
            X(irow, j) -= work(i, j); // Scatter operation
            work(i, j) = Scalar(0);   // Leave the workspace clean for the next supernode
            iptr++;
          }
        }
      }
    }
}

/**
 * \brief Solve with the transpose or the adjoint of the supernode triangular matrix
 *
 * The supernodes are visited in decreasing order (k = nsuper() .. 0). For each
 * of them, the rows of \a X referenced below the diagonal block are gathered,
 * their contribution through the transposed (or adjoint) off-diagonal panel is
 * subtracted from the rows of \a X matching the supernode columns, and these
 * rows are finally solved against the transposed (or adjoint) diagonal block.
 *
 * \tparam Conjugate if true the adjoint is used, otherwise the plain transpose.
 * \param X on entry the right-hand side(s), one per column; overwritten in place.
 */
template<typename Scalar, typename Index_>
template<bool Conjugate, typename Dest>
void MappedSuperNodalMatrix<Scalar,Index_>::solveTransposedInPlace( MatrixBase<Dest>&X) const
{
  using numext::conj;

  // Strided, read-only view on a dense panel stored inside the packed values.
  typedef Map<const Matrix<Scalar,Dynamic,Dynamic, ColMajor>, 0, OuterStride<> > SupernodePanel;

  // Keep the full Index width: converting through int would truncate large sizes.
  const Index n    = Index(X.rows());
  const Index nrhs = Index(X.cols());
  const Scalar * Lval = valuePtr();                 // Nonzero values
  Matrix<Scalar,Dynamic,Dest::ColsAtCompileTime, ColMajor> work(n, nrhs);     // working vector
  work.setZero();
  for (Index k = nsuper(); k >= 0; k--)
  {
    const Index fsupc = supToCol()[k];                    // First column of the current supernode
    const Index istart = rowIndexPtr()[fsupc];            // Pointer index to the subscript of the current column
    const Index nsupr = rowIndexPtr()[fsupc+1] - istart;  // Number of rows in the current supernode
    const Index nsupc = supToCol()[k+1] - fsupc;          // Number of columns in the current supernode
    const Index nrow = nsupr - nsupc;                     // Number of rows in the non-diagonal part of the supernode

    if (nsupc == 1 )
    {
      // Single-column supernode: plain sparse dot-product update.
      for (Index j = 0; j < nrhs; j++)
      {
        InnerIterator it(*this, fsupc);
        ++it; // Skip the diagonal element
        for (; it; ++it)
        {
          const Index irow = it.row();
          X(fsupc,j) -= X(irow, j) * (Conjugate?conj(it.value()):it.value());
        }
      }
    }
    else
    {
      // The supernode has more than one column
      const Index luptr = colIndexPtr()[fsupc];
      const Index lda = colIndexPtr()[fsupc+1] - luptr;

      //Begin Gather
      for (Index j = 0; j < nrhs; j++)
      {
        Index iptr = istart + nsupc;
        for (Index i = 0; i < nrow; i++)
        {
          const Index irow = rowIndex()[iptr];
          work(i, j) = X(irow, j); // Gather operation
          iptr++;
        }
      }

      // Matrix-vector product with transposed submatrix
      SupernodePanel lowerPanel( &(Lval[luptr+nsupc]), nrow, nsupc, OuterStride<>(lda) );
      typename Dest::RowsBlockXpr U = X.derived().middleRows(fsupc, nsupc);
      if(Conjugate)
        U = U - lowerPanel.adjoint() * work.topRows(nrow);
      else
        U = U - lowerPanel.transpose() * work.topRows(nrow);

      // Triangular solve (of transposed diagonal block)
      SupernodePanel diagBlock( &(Lval[luptr]), nsupc, nsupc, OuterStride<>(lda) );
      if(Conjugate)
        U = diagBlock.adjoint().template triangularView<UnitUpper>().solve(U);
      else
        U = diagBlock.transpose().template triangularView<UnitUpper>().solve(U);

    }

  }
}


} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_SPARSELU_SUPERNODAL_MATRIX_H
