blob: 8aefbffffd0f2644899cf939cd1b9d9c5425b448 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
namespace Eigen {
// Minimal pair of indices used in place of std::pair, which cannot be used in
// CUDA device code.
template <typename Index> struct IndexPair {
  Index first;
  Index second;

  constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) {}
  constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Index first_value, Index second_value)
      : first(first_value), second(second_value) {}

  // Overwrites both members with those of other.
  EIGEN_DEVICE_FUNC void set(IndexPair<Index> other) {
    first = other.first;
    second = other.second;
  }
};
} // end namespace Eigen
#if defined(EIGEN_HAS_CONSTEXPR) && defined(EIGEN_HAS_VARIADIC_TEMPLATES)
#define EIGEN_HAS_INDEX_LIST
namespace Eigen {
/** \internal
*
* \class TensorIndexList
* \ingroup CXX11_Tensor_Module
*
* \brief Set of classes used to encode a set of Tensor dimensions/indices.
*
* The indices in the list can be known at compile time or at runtime. A mix
* of static and dynamic indices can also be provided if needed. The tensor
* code will attempt to take advantage of the indices that are known at
* compile time to optimize the code it generates.
*
* This functionality requires a C++11-compliant compiler (constexpr and
* variadic template support). If your compiler is older, use arrays of indices instead.
*
* Several examples are provided in the cxx11_tensor_index_list.cpp file.
*
* \sa Tensor
*/
template <DenseIndex n>
struct type2index {
static const DenseIndex value = n;
constexpr operator DenseIndex() const { return n; }
void set(DenseIndex val) {
eigen_assert(val == n);
}
};
// This can be used with IndexPairList to get compile-time constant pairs,
// such as IndexPairList<type2indexpair<1,2>, type2indexpair<3,4>>().
template <DenseIndex f, DenseIndex s>
struct type2indexpair {
  static const DenseIndex first = f;
  static const DenseIndex second = s;
  // Converts to the runtime pair (f, s).
  constexpr EIGEN_DEVICE_FUNC operator IndexPair<DenseIndex>() const {
    return IndexPair<DenseIndex>(f, s);
  }
  // The pair is fixed by the type: set() only asserts that val equals (f, s).
  EIGEN_DEVICE_FUNC void set(const IndexPair<DenseIndex>& val) {
    eigen_assert(val.first == f);
    eigen_assert(val.second == s);
    // eigen_assert is typically compiled out in release builds; keep val
    // referenced so -Wunused-parameter stays quiet.
    static_cast<void>(val);
  }
};
namespace internal {
template <typename T>
void update_value(T& val, DenseIndex new_val) {
val = new_val;
}
template <DenseIndex n>
void update_value(type2index<n>& val, DenseIndex new_val) {
val.set(new_val);
}
template <typename T>
void update_value(T& val, IndexPair<DenseIndex> new_val) {
val = new_val;
}
template <DenseIndex f, DenseIndex s>
void update_value(type2indexpair<f, s>& val, IndexPair<DenseIndex> new_val) {
val.set(new_val);
}
// is_compile_time_constant<T>::value is true when T is type2index<idx> or
// type2indexpair<f, s>, optionally const-qualified and/or an lvalue
// reference; it is false for every other type (e.g. a plain DenseIndex).
template <typename T>
struct is_compile_time_constant {
  static constexpr bool value = false;
};

// type2index<idx>: the unqualified specialization holds the answer, and the
// const / reference variants inherit it.
template <DenseIndex idx>
struct is_compile_time_constant<type2index<idx> > {
  static constexpr bool value = true;
};
template <DenseIndex idx>
struct is_compile_time_constant<const type2index<idx> >
    : is_compile_time_constant<type2index<idx> > {};
template <DenseIndex idx>
struct is_compile_time_constant<type2index<idx>& >
    : is_compile_time_constant<type2index<idx> > {};
template <DenseIndex idx>
struct is_compile_time_constant<const type2index<idx>& >
    : is_compile_time_constant<type2index<idx> > {};

// type2indexpair<f, s>: same scheme.
template <DenseIndex f, DenseIndex s>
struct is_compile_time_constant<type2indexpair<f, s> > {
  static constexpr bool value = true;
};
template <DenseIndex f, DenseIndex s>
struct is_compile_time_constant<const type2indexpair<f, s> >
    : is_compile_time_constant<type2indexpair<f, s> > {};
template <DenseIndex f, DenseIndex s>
struct is_compile_time_constant<type2indexpair<f, s>& >
    : is_compile_time_constant<type2indexpair<f, s> > {};
template <DenseIndex f, DenseIndex s>
struct is_compile_time_constant<const type2indexpair<f, s>& >
    : is_compile_time_constant<type2indexpair<f, s> > {};
// tuple_coeff<Idx, ValueT> answers runtime-position queries on a std::tuple by
// compile-time recursion. Each member inspects slot Idx and otherwise hands the
// query down to tuple_coeff<Idx-1, ValueT>; recursion ends at the
// tuple_coeff<0, ValueT> specialization. ValueT is the runtime type the slots
// convert to (DenseIndex for IndexList, IndexPair<DenseIndex> for IndexPairList).
template <DenseIndex Idx, typename ValueT>
struct tuple_coeff {
  // Value of the slot at runtime position pos, searched from Idx downward.
  template <typename... T>
  static constexpr ValueT get(const DenseIndex pos, const std::tuple<T...>& tup) {
    return (pos != Idx) ? tuple_coeff<Idx-1, ValueT>::get(pos, tup) : std::get<Idx>(tup);
  }
  // Stores v into the slot at runtime position pos through update_value.
  template <typename... T>
  static void set(const DenseIndex pos, std::tuple<T...>& tup, const ValueT v) {
    if (pos != Idx) {
      tuple_coeff<Idx-1, ValueT>::set(pos, tup, v);
      return;
    }
    update_value(std::get<Idx>(tup), v);
  }
  // True when pos names a slot in [0, Idx] whose type is a compile-time constant.
  template <typename... T>
  static constexpr bool value_known_statically(const DenseIndex pos, const std::tuple<T...>& tup) {
    return (pos == Idx &&
            is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) ||
           tuple_coeff<Idx-1, ValueT>::value_known_statically(pos, tup);
  }
  // True when every slot in [0, Idx] has a compile-time-constant type.
  template <typename... T>
  static constexpr bool values_up_to_known_statically(const std::tuple<T...>& tup) {
    return is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value &&
           tuple_coeff<Idx-1, ValueT>::values_up_to_known_statically(tup);
  }
  // True when slots Idx and Idx-1 are both compile-time constants with
  // slot Idx > slot Idx-1, and the same holds recursively below. The value
  // comparison is evaluated only after both slot types are known to be static.
  template <typename... T>
  static constexpr bool values_up_to_statically_known_to_increase(const std::tuple<T...>& tup) {
    return is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value &&
           is_compile_time_constant<typename std::tuple_element<Idx-1, std::tuple<T...> >::type>::value &&
           std::get<Idx>(tup) > std::get<Idx-1>(tup) &&
           tuple_coeff<Idx-1, ValueT>::values_up_to_statically_known_to_increase(tup);
  }
};
// Recursion terminator for tuple_coeff: handles position 0, which is reached
// once every higher position has been ruled out. Parameters this level does
// not need are left unnamed to avoid -Wunused-parameter warnings.
template <typename ValueT>
struct tuple_coeff<0, ValueT> {
  // Returns element 0. The position is not checked here (gcc fails to compile
  // assertions in constexpr functions), so an out-of-range position that fell
  // through every higher level silently yields element 0.
  template <typename... T>
  static constexpr ValueT get(const DenseIndex /*i*/, const std::tuple<T...>& t) {
    return std::get<0>(t);
  }
  // Stores value into element 0 after asserting that i really is 0.
  template <typename... T>
  static void set(const DenseIndex i, std::tuple<T...>& t, const ValueT value) {
    eigen_assert(i == 0);
    // eigen_assert is typically compiled out in release builds; keep i
    // referenced so -Wunused-parameter stays quiet.
    static_cast<void>(i);
    update_value(std::get<0>(t), value);
  }
  // True only for i == 0 when element 0 has a compile-time-constant type.
  template <typename... T>
  static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& /*t*/) {
    return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value && (i == 0);
  }
  // Whether element 0 has a compile-time-constant type.
  template <typename... T>
  static constexpr bool values_up_to_known_statically(const std::tuple<T...>& /*t*/) {
    return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value;
  }
  // A single element is trivially increasing.
  template <typename... T>
  static constexpr bool values_up_to_statically_known_to_increase(const std::tuple<T...>& /*t*/) {
    return true;
  }
};
} // namespace internal
template<typename FirstType, typename... OtherTypes>
struct IndexList : std::tuple<FirstType, OtherTypes...> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::get(i, *this);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::set(i, *this, value);
}
constexpr IndexList(const std::tuple<FirstType, OtherTypes...>& other) : std::tuple<FirstType, OtherTypes...>(other) { }
constexpr IndexList() : std::tuple<FirstType, OtherTypes...>() { }
constexpr bool value_known_statically(const DenseIndex i) const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::value_known_statically(i, *this);
}
constexpr bool all_values_known_statically() const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::values_up_to_known_statically(*this);
}
constexpr bool values_statically_known_to_increase() const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::values_up_to_statically_known_to_increase(*this);
}
};
// Builds an IndexList whose entry types are deduced from the arguments, e.g.
// make_index_list(type2index<0>(), 3) mixes a compile-time and a runtime entry.
template<typename FirstType, typename... OtherTypes>
constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, OtherTypes... other_vals) {
  return IndexList<FirstType, OtherTypes...>(std::make_tuple(val1, other_vals...));
}
template<typename FirstType, typename... OtherTypes>
struct IndexPairList : std::tuple<FirstType, OtherTypes...> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr IndexPair<DenseIndex> operator[] (const DenseIndex i) const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, IndexPair<DenseIndex>>::get(i, *this);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const IndexPair<DenseIndex> value) {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...>>::value-1, IndexPair<DenseIndex> >::set(i, *this, value);
}
constexpr IndexPairList(const std::tuple<FirstType, OtherTypes...>& other) : std::tuple<FirstType, OtherTypes...>(other) { }
constexpr IndexPairList() : std::tuple<FirstType, OtherTypes...>() { }
constexpr bool value_known_statically(const DenseIndex i) const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::value_known_statically(i, *this);
}
};
namespace internal {
template<typename FirstType, typename... OtherTypes> size_t array_prod(const IndexList<FirstType, OtherTypes...>& sizes) {
size_t result = 1;
for (int i = 0; i < array_size<IndexList<FirstType, OtherTypes...> >::value; ++i) {
result *= sizes[i];
}
return result;
};
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
};
template<typename FirstType, typename... OtherTypes> struct array_size<const IndexList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
};
template<typename FirstType, typename... OtherTypes> struct array_size<IndexPairList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
};
template<typename FirstType, typename... OtherTypes> struct array_size<const IndexPairList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
};
// array_get<n>(list): entry n of an IndexList as a DenseIndex, selected at
// compile time. type2index entries convert through their constexpr operator.
template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
  return static_cast<DenseIndex>(std::get<n>(a));
}
template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(const IndexList<FirstType, OtherTypes...>& a) {
  return static_cast<DenseIndex>(std::get<n>(a));
}
// index_known_statically<T>()(i): whether entry i of an index container of
// type T is a compile-time constant. Only IndexList is specialized here; any
// other type answers false.
template <typename T>
struct index_known_statically {
  constexpr bool operator() (DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_known_statically<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i) const {
    // The answer depends only on the entry types, so a default-constructed
    // list is enough.
    return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_known_statically<const IndexList<FirstType, OtherTypes...> >
    : index_known_statically<IndexList<FirstType, OtherTypes...> > {};
// all_indices_known_statically<T>()(): whether every entry of T is a
// compile-time constant. Only IndexList is specialized here; any other type
// answers false.
template <typename T>
struct all_indices_known_statically {
  constexpr bool operator() () const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct all_indices_known_statically<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() () const {
    return IndexList<FirstType, OtherTypes...>().all_values_known_statically();
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct all_indices_known_statically<const IndexList<FirstType, OtherTypes...> >
    : all_indices_known_statically<IndexList<FirstType, OtherTypes...> > {};
// indices_statically_known_to_increase<T>()(): whether T's entries are known
// at compile time to be strictly increasing (each adjacent pair compile-time
// constant and ordered; a one-entry IndexList is trivially increasing). Only
// IndexList is specialized here; any other type answers false.
template <typename T>
struct indices_statically_known_to_increase {
  constexpr bool operator() () const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct indices_statically_known_to_increase<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() () const {
    return IndexList<FirstType, OtherTypes...>().values_statically_known_to_increase();
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct indices_statically_known_to_increase<const IndexList<FirstType, OtherTypes...> >
    : indices_statically_known_to_increase<IndexList<FirstType, OtherTypes...> > {};
// index_statically_eq<T>()(i, value): true only when entry i of T is a
// compile-time constant equal to value. Only IndexList is specialized here;
// any other type answers false.
template <typename T>
struct index_statically_eq {
  constexpr bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_eq<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
    // The && short-circuits, so entry i of the default-constructed list is
    // read only when its type is a compile-time constant.
    return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &&
           IndexList<FirstType, OtherTypes...>()[i] == value;
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_statically_eq<const IndexList<FirstType, OtherTypes...> >
    : index_statically_eq<IndexList<FirstType, OtherTypes...> > {};
// index_statically_ne<T>()(i, value): true only when entry i of T is a
// compile-time constant different from value. Only IndexList is specialized
// here; any other type answers false.
template <typename T>
struct index_statically_ne {
  constexpr bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_ne<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
    // The && short-circuits, so entry i of the default-constructed list is
    // read only when its type is a compile-time constant.
    return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &&
           IndexList<FirstType, OtherTypes...>()[i] != value;
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_statically_ne<const IndexList<FirstType, OtherTypes...> >
    : index_statically_ne<IndexList<FirstType, OtherTypes...> > {};
// index_statically_gt<T>()(i, value): true only when entry i of T is a
// compile-time constant greater than value. Only IndexList is specialized
// here; any other type answers false.
template <typename T>
struct index_statically_gt {
  constexpr bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_gt<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
    // The && short-circuits, so entry i of the default-constructed list is
    // read only when its type is a compile-time constant.
    return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &&
           IndexList<FirstType, OtherTypes...>()[i] > value;
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_statically_gt<const IndexList<FirstType, OtherTypes...> >
    : index_statically_gt<IndexList<FirstType, OtherTypes...> > {};
// index_statically_lt<T>()(i, value): true only when entry i of T is a
// compile-time constant less than value. Only IndexList is specialized here;
// any other type answers false.
template <typename T>
struct index_statically_lt {
  constexpr bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_lt<IndexList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
    // The && short-circuits, so entry i of the default-constructed list is
    // read only when its type is a compile-time constant.
    return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &&
           IndexList<FirstType, OtherTypes...>()[i] < value;
  }
};
// A const-qualified IndexList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_statically_lt<const IndexList<FirstType, OtherTypes...> >
    : index_statically_lt<IndexList<FirstType, OtherTypes...> > {};
// index_pair_first_statically_eq<T>()(i, value): true only when entry i of T
// is a compile-time pair whose first member equals value. Only IndexPairList
// is specialized here; any other type answers false.
template <typename T>
struct index_pair_first_statically_eq {
  constexpr bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_pair_first_statically_eq<IndexPairList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
    // The && short-circuits, so entry i of the default-constructed list is
    // read only when its type is a compile-time constant.
    return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &&
           IndexPairList<FirstType, OtherTypes...>()[i].first == value;
  }
};
// A const-qualified IndexPairList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_pair_first_statically_eq<const IndexPairList<FirstType, OtherTypes...> >
    : index_pair_first_statically_eq<IndexPairList<FirstType, OtherTypes...> > {};
// index_pair_second_statically_eq<T>()(i, value): true only when entry i of T
// is a compile-time pair whose second member equals value. Only IndexPairList
// is specialized here; any other type answers false.
template <typename T>
struct index_pair_second_statically_eq {
  constexpr bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename FirstType, typename... OtherTypes>
struct index_pair_second_statically_eq<IndexPairList<FirstType, OtherTypes...> > {
  constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
    // The && short-circuits, so entry i of the default-constructed list is
    // read only when its type is a compile-time constant.
    return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &&
           IndexPairList<FirstType, OtherTypes...>()[i].second == value;
  }
};
// A const-qualified IndexPairList answers exactly like the unqualified one.
template <typename FirstType, typename... OtherTypes>
struct index_pair_second_statically_eq<const IndexPairList<FirstType, OtherTypes...> >
    : index_pair_second_statically_eq<IndexPairList<FirstType, OtherTypes...> > {};
} // end namespace internal
} // end namespace Eigen
#else
namespace Eigen {
namespace internal {
// No C++11 support: without constexpr and variadic templates there is no
// IndexList / IndexPairList, so every static index query conservatively
// answers false. This branch must provide the same set of traits as the
// EIGEN_HAS_INDEX_LIST branch of this header, or code using them stops
// compiling when index lists are unavailable.
template <typename T>
struct index_known_statically {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{
    return false;
  }
};
template <typename T>
struct all_indices_known_statically {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() () const {
    return false;
  }
};
template <typename T>
struct indices_statically_known_to_increase {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() () const {
    return false;
  }
};
template <typename T>
struct index_statically_eq {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
    return false;
  }
};
template <typename T>
struct index_statically_ne {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
    return false;
  }
};
template <typename T>
struct index_statically_gt {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
    return false;
  }
};
template <typename T>
struct index_statically_lt {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
    return false;
  }
};
// Index-pair queries: IndexPair itself is available in both configurations,
// so these fallbacks are needed for callers that query pair containers.
template <typename T>
struct index_pair_first_statically_eq {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
template <typename T>
struct index_pair_second_statically_eq {
  EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const {
    return false;
  }
};
} // end namespace internal
} // end namespace Eigen
#endif
#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H