blob: 9a3681b356a8f2f14a2dbf5c6f92bcca7a70b3a6 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef CXX11_SRC_FIXEDPOINT_FIXEDPOINTTYPES_H_
#define CXX11_SRC_FIXEDPOINT_FIXEDPOINTTYPES_H_
#include <cmath>
#include <iostream>
namespace Eigen {
// The mantissa part of the fixed point representation. See
// go/tensorfixedpoint for details
//
// The types are forward-declared here because they are named before their
// definitions: the NumTraits specializations that follow refer to all five,
// and the narrow types declare a constructor taking a QInt32 before QInt32
// itself is defined.
struct QInt8;
struct QUInt8;
struct QInt16;
struct QUInt16;
struct QInt32;
// QInt8 takes its numeric traits from the generic traits of int8_t, the
// integer type it stores.
template <>
struct NumTraits<QInt8> : public GenericNumTraits<int8_t> {
};
// QUInt8 takes its numeric traits from the generic traits of uint8_t, the
// integer type it stores.
template <>
struct NumTraits<QUInt8> : public GenericNumTraits<uint8_t> {
};
// QInt16 takes its numeric traits from the generic traits of int16_t, the
// integer type it stores.
template <>
struct NumTraits<QInt16> : public GenericNumTraits<int16_t> {
};
// QUInt16 takes its numeric traits from the generic traits of uint16_t, the
// integer type it stores.
template <>
struct NumTraits<QUInt16> : public GenericNumTraits<uint16_t> {
};
// QInt32 takes its numeric traits from the generic traits of int32_t, the
// integer type it stores.
template <>
struct NumTraits<QInt32> : public GenericNumTraits<int32_t> {
};
namespace internal {
// Product traits for the (QInt32, double) scalar pair: the pair is flagged
// as Defined and its product type is QInt32.
template <>
struct scalar_product_traits<QInt32, double> {
  // Result type of multiplying a QInt32 by a double.
  typedef QInt32 ReturnType;
  enum {
    // Cost = NumTraits<T>::MulCost,
    Defined = 1
  };
};
}  // namespace internal
// Wrap the 8-bit int in a QInt8 struct instead of using a typedef to prevent
// the compiler from silently casting the mantissa to a bigger or a smaller
// representation.
// Quantized value stored in a signed 8-bit integer.
struct QInt8 {
  // Zero-initializes the stored value so that a default-constructed QInt8
  // never holds an indeterminate int8_t.
  QInt8() : value(0) {}
  // Wraps a raw int8_t unchanged.
  QInt8(const int8_t v) : value(v) {}
  // Conversion from QInt32. Only declared here, since QInt32 is still an
  // incomplete type at this point; the definition is out of line.
  QInt8(const QInt32 v);
  // Implicit conversion to int returns the stored value.
  operator int() const { return static_cast<int>(value); }
  // The raw stored integer.
  int8_t value;
};
// Quantized value stored in an unsigned 8-bit integer.
struct QUInt8 {
  // Zero-initializes the stored value so that a default-constructed QUInt8
  // never holds an indeterminate uint8_t.
  QUInt8() : value(0) {}
  // Wraps a raw uint8_t unchanged.
  QUInt8(const uint8_t v) : value(v) {}
  // Conversion from QInt32. Only declared here, since QInt32 is still an
  // incomplete type at this point; the definition is out of line.
  QUInt8(const QInt32 v);
  // Implicit conversion to int returns the stored value.
  operator int() const { return static_cast<int>(value); }
  // The raw stored integer.
  uint8_t value;
};
// Quantized value stored in a signed 16-bit integer.
struct QInt16 {
  // Zero-initializes the stored value so that a default-constructed QInt16
  // never holds an indeterminate int16_t.
  QInt16() : value(0) {}
  // Wraps a raw int16_t unchanged.
  QInt16(const int16_t v) : value(v) {}
  // Conversion from QInt32. Only declared here, since QInt32 is still an
  // incomplete type at this point; the definition is out of line.
  QInt16(const QInt32 v);
  // Implicit conversion to int returns the stored value.
  operator int() const { return static_cast<int>(value); }
  // The raw stored integer.
  int16_t value;
};
// Quantized value stored in an unsigned 16-bit integer.
struct QUInt16 {
  // Zero-initializes the stored value so that a default-constructed QUInt16
  // never holds an indeterminate uint16_t.
  QUInt16() : value(0) {}
  // Wraps a raw uint16_t unchanged.
  QUInt16(const uint16_t v) : value(v) {}
  // Conversion from QInt32. Only declared here, since QInt32 is still an
  // incomplete type at this point; the definition is out of line.
  QUInt16(const QInt32 v);
  // Implicit conversion to int returns the stored value.
  operator int() const { return static_cast<int>(value); }
  // The raw stored integer.
  uint16_t value;
};
// Quantized value stored in a signed 32-bit integer.
struct QInt32 {
  // Zero-initializes the stored value so that a default-constructed QInt32
  // never holds an indeterminate int32_t.
  QInt32() : value(0) {}
  // Widens a raw int8_t.
  QInt32(const int8_t v) : value(v) {}
  // Wraps a raw int32_t unchanged.
  QInt32(const int32_t v) : value(v) {}
  // Converts a raw uint32_t to the signed storage type with an explicit cast.
  QInt32(const uint32_t v) : value(static_cast<int32_t>(v)) {}
  // Widens the value held by a QInt8.
  QInt32(const QInt8 v) : value(v.value) {}
  // Rounds a float to an integer with lrint and stores the result.
  QInt32(const float v) : value(static_cast<int32_t>(lrint(v))) {}
#ifdef EIGEN_MAKING_DOCS
  // Workaround to fix build on PPC.
  // The cast makes the narrowing explicit, matching the uint32_t constructor.
  QInt32(unsigned long v) : value(static_cast<int32_t>(v)) {}
#endif
  // Implicit conversion to float returns the stored value as a float.
  operator float() const { return static_cast<float>(value); }
  // The raw stored integer.
  int32_t value;
};
// Saturating conversion from QInt32: values above 127 become 127, values
// below -128 become -128, and everything in between is kept as is.
EIGEN_STRONG_INLINE QInt8::QInt8(const QInt32 v) {
  const int32_t wide = v.value;
  if (wide > 127) {
    value = static_cast<int8_t>(127);
  } else if (wide < -128) {
    value = static_cast<int8_t>(-128);
  } else {
    value = static_cast<int8_t>(wide);
  }
}
// Saturating conversion from QInt32: values above 255 become 255, negative
// values become 0, and everything in between is kept as is.
EIGEN_STRONG_INLINE QUInt8::QUInt8(const QInt32 v) {
  const int32_t wide = v.value;
  if (wide > 255) {
    value = static_cast<uint8_t>(255);
  } else if (wide < 0) {
    value = static_cast<uint8_t>(0);
  } else {
    value = static_cast<uint8_t>(wide);
  }
}
// Saturating conversion from QInt32: values above 32767 become 32767, values
// below -32768 become -32768, and everything in between is kept as is.
EIGEN_STRONG_INLINE QInt16::QInt16(const QInt32 v) {
  const int32_t wide = v.value;
  if (wide > 32767) {
    value = static_cast<int16_t>(32767);
  } else if (wide < -32768) {
    value = static_cast<int16_t>(-32768);
  } else {
    value = static_cast<int16_t>(wide);
  }
}
// Saturating conversion from QInt32: values above 65535 become 65535,
// negative values become 0, and everything in between is kept as is.
EIGEN_STRONG_INLINE QUInt16::QUInt16(const QInt32 v) {
  const int32_t wide = v.value;
  if (wide > 65535) {
    value = static_cast<uint16_t>(65535);
  } else if (wide < 0) {
    value = static_cast<uint16_t>(0);
  } else {
    value = static_cast<uint16_t>(wide);
  }
}
// Basic widening 8-bit operations: This will be vectorized in future CLs.
// Multiplies two QInt8 values; both operands are widened to int32_t first
// and the product is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt8 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs * rhs);
}
// Multiplies a QInt8 by a QUInt8; both operands are widened to int32_t first
// and the product is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QUInt8 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs * rhs);
}
// Adds two QInt8 values; both operands are widened to int32_t first and the
// sum is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt8 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs + rhs);
}
// Subtracts b from a for two QInt8 values; both operands are widened to
// int32_t first and the difference is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt8 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs - rhs);
}
// Basic widening 16-bit operations: This will be vectorized in future CLs.
// Multiplies two QInt16 values; both operands are widened to int32_t first
// and the product is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt16 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs * rhs);
}
// Multiplies a QInt16 by a QUInt16; both operands are widened to int32_t
// first and the product is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QUInt16 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs * rhs);
}
// Adds two QInt16 values; both operands are widened to int32_t first and the
// sum is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt16 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs + rhs);
}
// Subtracts b from a for two QInt16 values; both operands are widened to
// int32_t first and the difference is returned as a QInt32.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt16 b) {
  const int32_t lhs = static_cast<int32_t>(a.value);
  const int32_t rhs = static_cast<int32_t>(b.value);
  return QInt32(lhs - rhs);
}
// Mixed QInt32 op QInt8 operations. This will be vectorized in future CLs.
// Adds a QInt8 to a QInt32; the 8-bit operand is widened to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt8 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value + widened);
}
// Adds a QInt32 to a QInt8; the 8-bit operand is widened to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt8 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened + b.value);
}
// Subtracts a QInt8 from a QInt32; the 8-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt8 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value - widened);
}
// Subtracts a QInt32 from a QInt8; the 8-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt8 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened - b.value);
}
// Multiplies a QInt32 by a QInt8; the 8-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt8 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value * widened);
}
// Multiplies a QInt8 by a QInt32; the 8-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt8 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened * b.value);
}
// Mixed QInt32 op QInt16 operations. This will be vectorized in future CLs.
// Adds a QInt16 to a QInt32; the 16-bit operand is widened to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt16 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value + widened);
}
// Adds a QInt32 to a QInt16; the 16-bit operand is widened to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt16 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened + b.value);
}
// Subtracts a QInt16 from a QInt32; the 16-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt16 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value - widened);
}
// Subtracts a QInt32 from a QInt16; the 16-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt16 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened - b.value);
}
// Multiplies a QInt32 by a QInt16; the 16-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt16 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value * widened);
}
// Multiplies a QInt16 by a QInt32; the 16-bit operand is widened to int32_t
// first.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt16 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened * b.value);
}
// Mixed QInt32 op QUInt8 operations. This will be vectorized in future CLs.
// Adds a QUInt8 to a QInt32; the unsigned 8-bit operand is converted to
// int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt8 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value + widened);
}
// Adds a QInt32 to a QUInt8; the unsigned 8-bit operand is converted to
// int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QUInt8 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened + b.value);
}
// Subtracts a QUInt8 from a QInt32; the unsigned 8-bit operand is converted
// to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt8 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value - widened);
}
// Subtracts a QInt32 from a QUInt8; the unsigned 8-bit operand is converted
// to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator-(const QUInt8 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened - b.value);
}
// Multiplies a QInt32 by a QUInt8; the unsigned 8-bit operand is converted
// to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt8 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value * widened);
}
// Multiplies a QUInt8 by a QInt32; the unsigned 8-bit operand is converted
// to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator*(const QUInt8 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened * b.value);
}
// Mixed QInt32 op QUInt16 operations. This will be vectorized in future CLs.
// Adds a QUInt16 to a QInt32; the unsigned 16-bit operand is converted to
// int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QUInt16 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value + widened);
}
// Adds a QInt32 to a QUInt16; the unsigned 16-bit operand is converted to
// int32_t first.
EIGEN_STRONG_INLINE QInt32 operator+(const QUInt16 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened + b.value);
}
// Subtracts a QUInt16 from a QInt32; the unsigned 16-bit operand is
// converted to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QUInt16 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value - widened);
}
// Subtracts a QInt32 from a QUInt16; the unsigned 16-bit operand is
// converted to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator-(const QUInt16 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened - b.value);
}
// Multiplies a QInt32 by a QUInt16; the unsigned 16-bit operand is converted
// to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QUInt16 b) {
  const int32_t widened = static_cast<int32_t>(b.value);
  return QInt32(a.value * widened);
}
// Multiplies a QUInt16 by a QInt32; the unsigned 16-bit operand is converted
// to int32_t first.
EIGEN_STRONG_INLINE QInt32 operator*(const QUInt16 a, const QInt32 b) {
  const int32_t widened = static_cast<int32_t>(a.value);
  return QInt32(widened * b.value);
}
// Basic arithmetic operations on QInt32, which behaves like an int32_t.
// Adds the stored integers of two QInt32 values.
EIGEN_STRONG_INLINE QInt32 operator+(const QInt32 a, const QInt32 b) {
  return QInt32(a.value + b.value);
}
// Subtracts the stored integer of b from that of a.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a, const QInt32 b) {
  return QInt32(a.value - b.value);
}
// Multiplies the stored integers of two QInt32 values.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const QInt32 b) {
  return QInt32(a.value * b.value);
}
// Divides the stored integer of a by that of b using integer division.
EIGEN_STRONG_INLINE QInt32 operator/(const QInt32 a, const QInt32 b) {
  return QInt32(a.value / b.value);
}
// Adds b to a in place and returns a reference to a.
EIGEN_STRONG_INLINE QInt32& operator+=(QInt32& a, const QInt32 b) {
  a.value = a.value + b.value;
  return a;
}
// Subtracts b from a in place and returns a reference to a.
EIGEN_STRONG_INLINE QInt32& operator-=(QInt32& a, const QInt32 b) {
  a.value = a.value - b.value;
  return a;
}
// Multiplies a by b in place and returns a reference to a.
EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const QInt32 b) {
  a.value = a.value * b.value;
  return a;
}
// Divides a by b in place (integer division) and returns a reference to a.
EIGEN_STRONG_INLINE QInt32& operator/=(QInt32& a, const QInt32 b) {
  a.value = a.value / b.value;
  return a;
}
// Unary minus: returns a QInt32 holding the negated stored integer.
EIGEN_STRONG_INLINE QInt32 operator-(const QInt32 a) {
  return QInt32(-a.value);
}
// Scaling QInt32 by double. We do the arithmetic in double because
// float only has a 24-bit significand (23 explicitly stored bits), so casting
// QInt32 to float might reduce accuracy by discarding up to 7 (least
// significant) bits.
// Scales a QInt32 by a double: the product is computed in double, rounded to
// an integer with lrint and cast to int32_t.
EIGEN_STRONG_INLINE QInt32 operator*(const QInt32 a, const double b) {
  const double scaled = static_cast<double>(a.value) * b;
  return static_cast<int32_t>(lrint(scaled));
}
// Scales a QInt32 by a double on the left: the product is computed in
// double, rounded to an integer with lrint and cast to int32_t.
EIGEN_STRONG_INLINE QInt32 operator*(const double a, const QInt32 b) {
  const double scaled = a * static_cast<double>(b.value);
  return static_cast<int32_t>(lrint(scaled));
}
// Scales a by a double in place: the product is computed in double, rounded
// to an integer with lrint and stored back as int32_t. Returns a reference
// to a.
EIGEN_STRONG_INLINE QInt32& operator*=(QInt32& a, const double b) {
  const double scaled = static_cast<double>(a.value) * b;
  a.value = static_cast<int32_t>(lrint(scaled));
  return a;
}
// Comparisons
// Two QInt8 values are equal when their stored integers are equal.
EIGEN_STRONG_INLINE bool operator==(const QInt8 a, const QInt8 b) {
  const bool same = (a.value == b.value);
  return same;
}
// Two QUInt8 values are equal when their stored integers are equal.
EIGEN_STRONG_INLINE bool operator==(const QUInt8 a, const QUInt8 b) {
  const bool same = (a.value == b.value);
  return same;
}
// Two QInt16 values are equal when their stored integers are equal.
EIGEN_STRONG_INLINE bool operator==(const QInt16 a, const QInt16 b) {
  const bool same = (a.value == b.value);
  return same;
}
// Two QUInt16 values are equal when their stored integers are equal.
EIGEN_STRONG_INLINE bool operator==(const QUInt16 a, const QUInt16 b) {
  const bool same = (a.value == b.value);
  return same;
}
// Two QInt32 values are equal when their stored integers are equal.
EIGEN_STRONG_INLINE bool operator==(const QInt32 a, const QInt32 b) {
  const bool same = (a.value == b.value);
  return same;
}
// Orders QInt8 values by their stored integers: true when a is smaller.
EIGEN_STRONG_INLINE bool operator<(const QInt8 a, const QInt8 b) {
  return b.value > a.value;
}
// Orders QUInt8 values by their stored integers: true when a is smaller.
EIGEN_STRONG_INLINE bool operator<(const QUInt8 a, const QUInt8 b) {
  return b.value > a.value;
}
// Orders QInt16 values by their stored integers: true when a is smaller.
EIGEN_STRONG_INLINE bool operator<(const QInt16 a, const QInt16 b) {
  return b.value > a.value;
}
// Orders QUInt16 values by their stored integers: true when a is smaller.
EIGEN_STRONG_INLINE bool operator<(const QUInt16 a, const QUInt16 b) {
  return b.value > a.value;
}
// Orders QInt32 values by their stored integers: true when a is smaller.
EIGEN_STRONG_INLINE bool operator<(const QInt32 a, const QInt32 b) {
  return b.value > a.value;
}
// Orders QInt8 values by their stored integers: true when a is larger.
EIGEN_STRONG_INLINE bool operator>(const QInt8 a, const QInt8 b) {
  return b.value < a.value;
}
// Orders QUInt8 values by their stored integers: true when a is larger.
EIGEN_STRONG_INLINE bool operator>(const QUInt8 a, const QUInt8 b) {
  return b.value < a.value;
}
// Orders QInt16 values by their stored integers: true when a is larger.
EIGEN_STRONG_INLINE bool operator>(const QInt16 a, const QInt16 b) {
  return b.value < a.value;
}
// Orders QUInt16 values by their stored integers: true when a is larger.
EIGEN_STRONG_INLINE bool operator>(const QUInt16 a, const QUInt16 b) {
  return b.value < a.value;
}
// Orders QInt32 values by their stored integers: true when a is larger.
EIGEN_STRONG_INLINE bool operator>(const QInt32 a, const QInt32 b) {
  return b.value < a.value;
}
// Streams a QInt8 as a number: the stored byte is converted to int before
// being written, so it is not printed as a character.
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt8 a) {
  const int printable = static_cast<int>(a.value);
  return os << printable;
}
// Streams a QUInt8 as a number: the stored byte is converted to int before
// being written, so it is not printed as a character.
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt8 a) {
  const int printable = static_cast<int>(a.value);
  return os << printable;
}
// Streams a QInt16 by writing its stored value converted to int.
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt16 a) {
  const int printable = static_cast<int>(a.value);
  return os << printable;
}
// Streams a QUInt16 by writing its stored value converted to int.
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QUInt16 a) {
  const int printable = static_cast<int>(a.value);
  return os << printable;
}
// Streams a QInt32 by writing its stored int32_t directly.
EIGEN_STRONG_INLINE std::ostream& operator<<(std::ostream& os, QInt32 a) {
  return os << a.value;
}
} // namespace Eigen
#endif // CXX11_SRC_FIXEDPOINT_FIXEDPOINTTYPES_H_