blob: df2ac83710010032d07c7755e0960a705ffc6781 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_SOLVERBASE_H
#define EIGEN_SOLVERBASE_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
// solve_assertion<T>::run<Flag>(solver, rhs) performs the precondition checks used by
// SolverBase::solve(). The primary template forwards to the solver's own
// _check_solve_assertion<Flag>(); the specializations below peel off the expression
// wrappers produced by SolverBase::transpose() and SolverBase::adjoint() so that the
// check runs against the underlying solver with the "transposed" flag set.
template <typename Derived>
struct solve_assertion {
  template <bool Flag, typename Rhs>
  static void run(const Derived& dec, const Rhs& rhs) {
    dec.template _check_solve_assertion<Flag>(rhs);
  }
};

// dec.transpose().solve(b): check against the wrapped solver as a transposed problem.
template <typename Derived>
struct solve_assertion<Transpose<Derived>> {
  using type = Transpose<Derived>;
  template <bool Flag, typename Rhs>
  static void run(const type& xpr, const Rhs& rhs) {
    // The incoming Flag is not consulted: the nested solver is always checked with true.
    using NestedChecker = internal::solve_assertion<internal::remove_all_t<Derived>>;
    NestedChecker::template run<true>(xpr.nestedExpression(), rhs);
  }
};

// dec.adjoint().solve(b) for complex scalars: a conjugated transpose. The conjugation is
// unwrapped and the remaining Transpose expression is handed to the specialization above.
template <typename Scalar, typename Derived>
struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived>>> {
  using type = CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived>>;
  template <bool Flag, typename Rhs>
  static void run(const type& xpr, const Rhs& rhs) {
    using TransposeChecker = internal::solve_assertion<Transpose<Derived>>;
    TransposeChecker::template run<true>(xpr.nestedExpression(), rhs);
  }
};
}  // end namespace internal
/** \class SolverBase
* \brief A base class for matrix decompositions and solvers
*
* \tparam Derived the actual type of the decomposition/solver.
*
* Any matrix decomposition inheriting this base class provides the following API:
*
* \code
* MatrixType A, b, x;
* DecompositionType dec(A);
* x = dec.solve(b); // solve A * x = b
* x = dec.transpose().solve(b); // solve A^T * x = b
* x = dec.adjoint().solve(b); // solve A' * x = b
* \endcode
*
* \warning Currently, any other usage of transpose() and adjoint() is not supported and will produce compilation
* errors.
*
* \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR,
* class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
*/
template <typename Derived>
class SolverBase : public EigenBase<Derived> {
 public:
  using Base = EigenBase<Derived>;
  using Scalar = typename internal::traits<Derived>::Scalar;
  using CoeffReturnType = Scalar;

  // internal::solve_assertion (including its specializations) must be able to call the
  // protected _check_solve_assertion() member declared below.
  template <typename Derived_>
  friend struct internal::solve_assertion;

  // Compile-time shape information, taken from internal::traits<Derived>; the derived
  // quantities are expressed in terms of the enumerators declared before them.
  enum {
    RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
    ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
    SizeAtCompileTime = (internal::size_of_xpr_at_compile_time<Derived>::ret),
    MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
    MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
    MaxSizeAtCompileTime = internal::size_at_compile_time(int(MaxRowsAtCompileTime), int(MaxColsAtCompileTime)),
    IsVectorAtCompileTime = int(MaxRowsAtCompileTime) == 1 || int(MaxColsAtCompileTime) == 1,
    NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : (bool(IsVectorAtCompileTime) ? 1 : 2)
  };

  /** Default constructor */
  SolverBase() {}
  ~SolverBase() {}

  using Base::derived;

  /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
   *
   * Before building the expression, the precondition checks of
   * internal::solve_assertion are run (via eigen_assert in _check_solve_assertion()).
   */
  template <typename Rhs>
  inline const Solve<Derived, Rhs> solve(const MatrixBase<Rhs>& b) const {
    using Checker = internal::solve_assertion<internal::remove_all_t<Derived>>;
    Checker::template run<false>(derived(), b);
    return Solve<Derived, Rhs>(derived(), b.derived());
  }

  /** \internal the return type of transpose() */
  using ConstTransposeReturnType = Transpose<const Derived>;

  /** \returns an expression of the transpose of the factored matrix.
   *
   * A typical usage is to solve for the transposed problem A^T x = b:
   * \code x = dec.transpose().solve(b); \endcode
   *
   * \sa adjoint(), solve()
   */
  inline const ConstTransposeReturnType transpose() const { return ConstTransposeReturnType(derived()); }

  /** \internal the return type of adjoint(): a conjugated transpose for complex scalars,
   * a plain transpose otherwise. */
  using AdjointReturnType =
      std::conditional_t<NumTraits<Scalar>::IsComplex,
                         CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, const ConstTransposeReturnType>,
                         const ConstTransposeReturnType>;

  /** \returns an expression of the adjoint of the factored matrix
   *
   * A typical usage is to solve for the adjoint problem A' x = b:
   * \code x = dec.adjoint().solve(b); \endcode
   *
   * For real scalar types, this function is equivalent to transpose().
   *
   * \sa transpose(), solve()
   */
  inline const AdjointReturnType adjoint() const { return AdjointReturnType(derived().transpose()); }

 protected:
  // Debug-only preconditions for solve(): the solver must be initialized, and b must have
  // as many rows as the factored matrix (its columns when Transpose_ is true).
  template <bool Transpose_, typename Rhs>
  void _check_solve_assertion(const Rhs& b) const {
    EIGEN_ONLY_USED_FOR_DEBUG(b);
    eigen_assert(derived().m_isInitialized && "Solver is not initialized.");
    eigen_assert((Transpose_ ? derived().cols() : derived().rows()) == b.rows() &&
                 "SolverBase::solve(): invalid number of rows of the right hand side matrix b");
  }
};
namespace internal {
// Selects SolverBase<Derived> as the generic expression base for MatrixXpr expressions
// whose storage kind is SolverStorage.
template <typename Derived>
struct generic_xpr_base<Derived, MatrixXpr, SolverStorage> {
  using type = SolverBase<Derived>;
};
}  // end namespace internal
} // end namespace Eigen
#endif // EIGEN_SOLVERBASE_H