blob: befb485e8ce128c0ac48f8fb9507ed98fe3decd9 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_SYMBOLIC_INDEX_H
#define EIGEN_SYMBOLIC_INDEX_H
// IWYU pragma: private
#include "../InternalHeaderCheck.h"
namespace Eigen {
/** \namespace Eigen::symbolic
* \ingroup Core_Module
*
* This namespace defines a set of classes and functions to build and evaluate symbolic expressions of scalar type
* Index. Here is a simple example:
*
* \code
* // First step, defines symbols:
* struct x_tag {}; static const symbolic::SymbolExpr<x_tag> x;
* struct y_tag {}; static const symbolic::SymbolExpr<y_tag> y;
* struct z_tag {}; static const symbolic::SymbolExpr<z_tag> z;
*
* // Defines an expression:
* auto expr = (x+3)/y+z;
*
* // And evaluate it: (c++14)
* std::cout << expr.eval(x=6,y=3,z=-13) << "\n";
*
* \endcode
*
* It is currently only used internally to define and manipulate the
* Eigen::placeholders::last and Eigen::placeholders::lastp1 symbols in
* Eigen::seq and Eigen::seqN.
*
*/
namespace symbolic {
template <typename Tag>
class Symbol;
template <typename Tag, typename Type>
class SymbolValue;
template <typename Arg0>
class NegateExpr;
template <typename Arg1, typename Arg2>
class AddExpr;
template <typename Arg1, typename Arg2>
class ProductExpr;
template <typename Arg1, typename Arg2>
class QuotientExpr;
template <typename IndexType = Index>
class ValueExpr;
/** \class BaseExpr
* \ingroup Core_Module
* Common base class of any symbolic expressions
*/
template <typename Derived_>
class BaseExpr {
public:
using Derived = Derived_;
constexpr const Derived& derived() const { return *static_cast<const Derived*>(this); }
/** Evaluate the expression given the \a values of the symbols.
*
* \param values defines the values of the symbols, as constructed by SymbolExpr::operator= operator.
*
*/
template <typename... Tags, typename... Types>
constexpr Index eval(const SymbolValue<Tags, Types>&... values) const {
return derived().eval_impl(values...);
}
/** Evaluate the expression at compile time given the \a values of the symbols.
*
* If a value is not known at compile-time, returns Eigen::Undefined.
*
*/
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time(const SymbolValue<Tags, Types>&...) {
return Derived::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
}
constexpr NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); }
constexpr AddExpr<Derived, ValueExpr<>> operator+(Index b) const {
return AddExpr<Derived, ValueExpr<>>(derived(), b);
}
constexpr AddExpr<Derived, ValueExpr<>> operator-(Index a) const {
return AddExpr<Derived, ValueExpr<>>(derived(), -a);
}
constexpr ProductExpr<Derived, ValueExpr<>> operator*(Index a) const {
return ProductExpr<Derived, ValueExpr<> >(derived(), a);
}
constexpr QuotientExpr<Derived, ValueExpr<>> operator/(Index a) const {
return QuotientExpr<Derived, ValueExpr<> >(derived(), a);
}
friend constexpr AddExpr<Derived, ValueExpr<>> operator+(Index a, const BaseExpr& b) {
return AddExpr<Derived, ValueExpr<> >(b.derived(), a);
}
friend constexpr AddExpr<NegateExpr<Derived>, ValueExpr<>> operator-(Index a, const BaseExpr& b) {
return AddExpr<NegateExpr<Derived>, ValueExpr<> >(-b.derived(), a);
}
friend constexpr ProductExpr<ValueExpr<>, Derived> operator*(Index a, const BaseExpr& b) {
return ProductExpr<ValueExpr<>, Derived>(a, b.derived());
}
friend constexpr QuotientExpr<ValueExpr<>, Derived> operator/(Index a, const BaseExpr& b) {
return QuotientExpr<ValueExpr<>, Derived>(a, b.derived());
}
template <int N>
constexpr AddExpr<Derived, ValueExpr<internal::FixedInt<N>>> operator+(internal::FixedInt<N>) const {
return AddExpr<Derived, ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >());
}
template <int N>
constexpr AddExpr<Derived, ValueExpr<internal::FixedInt<-N>>> operator-(internal::FixedInt<N>) const {
return AddExpr<Derived, ValueExpr<internal::FixedInt<-N> > >(derived(), ValueExpr<internal::FixedInt<-N> >());
}
template <int N>
constexpr ProductExpr<Derived, ValueExpr<internal::FixedInt<N>>> operator*(internal::FixedInt<N>) const {
return ProductExpr<Derived, ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >());
}
template <int N>
constexpr QuotientExpr<Derived, ValueExpr<internal::FixedInt<N>>> operator/(internal::FixedInt<N>) const {
return QuotientExpr<Derived, ValueExpr<internal::FixedInt<N> > >(derived(), ValueExpr<internal::FixedInt<N> >());
}
template <int N>
friend constexpr AddExpr<Derived, ValueExpr<internal::FixedInt<N>>> operator+(internal::FixedInt<N>,
const BaseExpr& b) {
return AddExpr<Derived, ValueExpr<internal::FixedInt<N> > >(b.derived(), ValueExpr<internal::FixedInt<N> >());
}
template <int N>
friend constexpr AddExpr<NegateExpr<Derived>, ValueExpr<internal::FixedInt<N>>> operator-(internal::FixedInt<N>,
const BaseExpr& b) {
return AddExpr<NegateExpr<Derived>, ValueExpr<internal::FixedInt<N> > >(-b.derived(),
ValueExpr<internal::FixedInt<N> >());
}
template <int N>
friend constexpr ProductExpr<ValueExpr<internal::FixedInt<N>>, Derived> operator*(internal::FixedInt<N>,
const BaseExpr& b) {
return ProductExpr<ValueExpr<internal::FixedInt<N> >, Derived>(ValueExpr<internal::FixedInt<N> >(), b.derived());
}
template <int N>
friend constexpr QuotientExpr<ValueExpr<internal::FixedInt<N>>, Derived> operator/(internal::FixedInt<N>,
const BaseExpr& b) {
return QuotientExpr<ValueExpr<internal::FixedInt<N> >, Derived>(ValueExpr<internal::FixedInt<N> >(), b.derived());
}
template <typename OtherDerived>
constexpr AddExpr<Derived, OtherDerived> operator+(const BaseExpr<OtherDerived>& b) const {
return AddExpr<Derived, OtherDerived>(derived(), b.derived());
}
template <typename OtherDerived>
constexpr AddExpr<Derived, NegateExpr<OtherDerived>> operator-(const BaseExpr<OtherDerived>& b) const {
return AddExpr<Derived, NegateExpr<OtherDerived> >(derived(), -b.derived());
}
template <typename OtherDerived>
constexpr ProductExpr<Derived, OtherDerived> operator*(const BaseExpr<OtherDerived>& b) const {
return ProductExpr<Derived, OtherDerived>(derived(), b.derived());
}
template <typename OtherDerived>
constexpr QuotientExpr<Derived, OtherDerived> operator/(const BaseExpr<OtherDerived>& b) const {
return QuotientExpr<Derived, OtherDerived>(derived(), b.derived());
}
};
template <typename T>
struct is_symbolic {
// BaseExpr has no conversion ctor, so we only have to check whether T can be statically cast to its base class
// BaseExpr<T>.
enum { value = internal::is_convertible<T, BaseExpr<T> >::value };
};
// A simple wrapper around an integral value to provide the eval method.
// We could also use a free-function symbolic_eval...
template <typename IndexType>
class ValueExpr : BaseExpr<ValueExpr<IndexType>> {
public:
constexpr ValueExpr() = default;
constexpr ValueExpr(IndexType val) : value_(val) {}
template <typename... Tags, typename... Types>
constexpr IndexType eval_impl(const SymbolValue<Tags, Types>&...) const {
return value_;
}
template <typename... Tags, typename... Types>
static constexpr IndexType eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
return IndexType(Undefined);
}
protected:
IndexType value_;
};
// Specialization for compile-time value,
// It is similar to ValueExpr(N) but this version helps the compiler to generate better code.
template <int N>
class ValueExpr<internal::FixedInt<N>> : public BaseExpr<ValueExpr<internal::FixedInt<N>>> {
public:
constexpr ValueExpr() = default;
constexpr ValueExpr(internal::FixedInt<N>) {}
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&...) const {
return Index(N);
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
return Index(N);
}
};
/** Represents the actual value of a symbol identified by its tag
*
* It is the return type of SymbolValue::operator=, and most of the time this is only way it is used.
*/
template <typename Tag, typename Type>
class SymbolValue : public BaseExpr<SymbolValue<Tag, Type>> {};
template <typename Tag>
class SymbolValue<Tag, Index> : public BaseExpr<SymbolValue<Tag, Index>> {
public:
constexpr SymbolValue() = default;
/** Default constructor from the value \a val */
constexpr SymbolValue(Index val) : value_(val) {}
/** \returns the stored value of the symbol */
constexpr Index value() const { return value_; }
/** \returns the stored value of the symbol at compile time, or Undefined if not known. */
static constexpr Index value_at_compile_time() { return Index(Undefined); }
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&...) const {
return value();
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
return value_at_compile_time();
}
protected:
Index value_;
};
template <typename Tag, int N>
class SymbolValue<Tag, internal::FixedInt<N>> : public BaseExpr<SymbolValue<Tag, internal::FixedInt<N>>> {
public:
constexpr SymbolValue() = default;
/** Default constructor from the value \a val */
constexpr SymbolValue(internal::FixedInt<N>){};
/** \returns the stored value of the symbol */
constexpr Index value() const { return static_cast<Index>(N); }
/** \returns the stored value of the symbol at compile time, or Undefined if not known. */
static constexpr Index value_at_compile_time() { return static_cast<Index>(N); }
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&...) const {
return value();
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
return value_at_compile_time();
}
};
// Find and return a symbol value based on the tag.
template <typename Tag, typename... Types>
struct EvalSymbolValueHelper;
// Empty base case, symbol not found.
template <typename Tag>
struct EvalSymbolValueHelper<Tag> {
static constexpr Index eval_impl() {
eigen_assert(false && "Symbol not found.");
return Index(Undefined);
}
static constexpr Index eval_at_compile_time_impl() { return Index(Undefined); }
};
// We found a symbol value matching the provided Tag!
template <typename Tag, typename Type, typename... OtherTypes>
struct EvalSymbolValueHelper<Tag, SymbolValue<Tag, Type>, OtherTypes...> {
static constexpr Index eval_impl(const SymbolValue<Tag, Type>& symbol, const OtherTypes&...) {
return symbol.value();
}
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tag, Type>& symbol, const OtherTypes&...) {
return symbol.value_at_compile_time();
}
};
// No symbol value in first value, recursive search starting with next.
template <typename Tag, typename T1, typename... OtherTypes>
struct EvalSymbolValueHelper<Tag, T1, OtherTypes...> {
static constexpr Index eval_impl(const T1&, const OtherTypes&... values) {
return EvalSymbolValueHelper<Tag, OtherTypes...>::eval_impl(values...);
}
static constexpr Index eval_at_compile_time_impl(const T1&, const OtherTypes&...) {
return EvalSymbolValueHelper<Tag, OtherTypes...>::eval_at_compile_time_impl(OtherTypes{}...);
}
};
/** Expression of a symbol uniquely identified by the template parameter type \c tag */
template <typename tag>
class SymbolExpr : public BaseExpr<SymbolExpr<tag> > {
public:
/** Alias to the template parameter \c tag */
typedef tag Tag;
constexpr SymbolExpr() = default;
/** Associate the value \a val to the given symbol \c *this, uniquely identified by its \c Tag.
*
* The returned object should be passed to ExprBase::eval() to evaluate a given expression with this specified
* runtime-time value.
*/
constexpr SymbolValue<Tag, Index> operator=(Index val) const { return SymbolValue<Tag, Index>(val); }
template <int N>
constexpr SymbolValue<Tag, internal::FixedInt<N>> operator=(internal::FixedInt<N>) const {
return SymbolValue<Tag, internal::FixedInt<N>>{internal::FixedInt<N>{}};
}
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&... values) const {
return EvalSymbolValueHelper<Tag, SymbolValue<Tags, Types>...>::eval_impl(values...);
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
return EvalSymbolValueHelper<Tag, SymbolValue<Tags, Types>...>::eval_at_compile_time_impl(
SymbolValue<Tags, Types>{}...);
}
};
template <typename Arg0>
class NegateExpr : public BaseExpr<NegateExpr<Arg0> > {
public:
constexpr NegateExpr() = default;
constexpr NegateExpr(const Arg0& arg0) : m_arg0(arg0) {}
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&... values) const {
return -m_arg0.eval_impl(values...);
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
constexpr Index v = Arg0::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
return (v == Undefined) ? Undefined : -v;
}
protected:
Arg0 m_arg0;
};
template <typename Arg0, typename Arg1>
class AddExpr : public BaseExpr<AddExpr<Arg0, Arg1> > {
public:
constexpr AddExpr() = default;
constexpr AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&... values) const {
return m_arg0.eval_impl(values...) + m_arg1.eval_impl(values...);
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
constexpr Index v0 = Arg0::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
constexpr Index v1 = Arg1::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
return (v0 == Undefined || v1 == Undefined) ? Undefined : v0 + v1;
}
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
template <typename Arg0, typename Arg1>
class ProductExpr : public BaseExpr<ProductExpr<Arg0, Arg1> > {
public:
constexpr ProductExpr() = default;
constexpr ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&... values) const {
return m_arg0.eval_impl(values...) * m_arg1.eval_impl(values...);
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
constexpr Index v0 = Arg0::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
constexpr Index v1 = Arg1::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
return (v0 == Undefined || v1 == Undefined) ? Undefined : v0 * v1;
}
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
template <typename Arg0, typename Arg1>
class QuotientExpr : public BaseExpr<QuotientExpr<Arg0, Arg1> > {
public:
constexpr QuotientExpr() = default;
constexpr QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template <typename... Tags, typename... Types>
constexpr Index eval_impl(const SymbolValue<Tags, Types>&... values) const {
return m_arg0.eval_impl(values...) / m_arg1.eval_impl(values...);
}
template <typename... Tags, typename... Types>
static constexpr Index eval_at_compile_time_impl(const SymbolValue<Tags, Types>&...) {
constexpr Index v0 = Arg0::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
constexpr Index v1 = Arg1::eval_at_compile_time_impl(SymbolValue<Tags, Types>{}...);
return (v0 == Undefined || v1 == Undefined) ? Undefined : v0 / v1;
}
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
} // end namespace symbolic
} // end namespace Eigen
#endif // EIGEN_SYMBOLIC_INDEX_H