blob: b1928c4a1630de94396227c222da1c02fb300dd8 [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
#define EIGEN_CXX11_TENSOR_TENSOR_IO_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
namespace Eigen {
struct TensorIOFormat;
namespace internal {
template <typename Tensor, std::size_t rank, typename Format, typename EnableIf = void>
struct TensorPrinter;
}
template <typename Derived_>
struct TensorIOFormatBase {
  using Derived = Derived_;
  // Shared state for all tensor output formats: per-rank separators, prefixes
  // and suffixes, a whole-tensor prefix/suffix, the fill character used for
  // column alignment, and the stream precision/flags.
  TensorIOFormatBase(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
                     const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
                     const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
      : tenPrefix(tenPrefix),
        tenSuffix(tenSuffix),
        prefix(prefix),
        suffix(suffix),
        separator(separator),
        fill(fill),
        precision(precision),
        flags(flags) {
    init_spacer();
  }

  // Precompute one spacer string per rank: as many blanks as there are
  // characters after the last '\n' of the corresponding prefix (the tensor
  // prefix for rank 0), so continuation lines align under the opening text.
  // Skipped entirely when column alignment is disabled.
  void init_spacer() {
    if ((flags & DontAlignCols)) return;
    // Number of characters following the last newline in s (the whole string
    // when s contains no newline at all).
    auto trailing_width = [](const std::string& s) -> std::size_t {
      const std::size_t newline_pos = s.rfind('\n');
      return (newline_pos == std::string::npos) ? s.size() : s.size() - newline_pos - 1;
    };
    spacer.resize(prefix.size());
    spacer[0] = std::string(trailing_width(tenPrefix), ' ');
    for (std::size_t k = 1; k < prefix.size(); k++) {
      spacer[k] = std::string(trailing_width(prefix[k]), ' ');
    }
  }

  std::string tenPrefix;
  std::string tenSuffix;
  std::vector<std::string> prefix;
  std::vector<std::string> suffix;
  std::vector<std::string> separator;
  char fill;
  int precision;
  int flags;
  std::vector<std::string> spacer{};
};
// numpy-style output: elements separated by a blank, rows by a newline, and
// each dimension wrapped in square brackets ("[ ... ]").
struct TensorIOFormatNumpy : public TensorIOFormatBase<TensorIOFormatNumpy> {
  using Base = TensorIOFormatBase<TensorIOFormatNumpy>;
  TensorIOFormatNumpy()
      : Base(/*separator=*/{" ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
             /*flags=*/0, /*tenPrefix=*/"[", /*tenSuffix=*/"]") {}
};
// C++-initializer-list-style output: comma-separated elements with each
// dimension wrapped in curly braces ("{ ... }"), suitable for pasting back
// into source code.
struct TensorIOFormatNative : public TensorIOFormatBase<TensorIOFormatNative> {
  using Base = TensorIOFormatBase<TensorIOFormatNative>;
  TensorIOFormatNative()
      : Base(/*separator=*/{", ", ",\n", "\n"}, /*prefix=*/{"", "{"}, /*suffix=*/{"", "}"},
             /*precision=*/StreamPrecision, /*flags=*/0, /*tenPrefix=*/"{", /*tenSuffix=*/"}") {}
};
// Unadorned output: blank-separated values and newline-separated rows, with
// no brackets or tensor prefix/suffix at all. This is the default format used
// by operator<< on a TensorBase.
struct TensorIOFormatPlain : public TensorIOFormatBase<TensorIOFormatPlain> {
  using Base = TensorIOFormatBase<TensorIOFormatPlain>;
  TensorIOFormatPlain()
      : Base(/*separator=*/{" ", "\n", "\n", ""}, /*prefix=*/{""}, /*suffix=*/{""}, /*precision=*/StreamPrecision,
             /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
};
// Backwards-compatible format. NOTE(review): for rank != 0 the dedicated
// TensorPrinter specialization below reshapes the tensor to a matrix and
// prints it via Eigen's matrix operator<<, so most of these settings only
// affect the rank-0 path.
struct TensorIOFormatLegacy : public TensorIOFormatBase<TensorIOFormatLegacy> {
  using Base = TensorIOFormatBase<TensorIOFormatLegacy>;
  TensorIOFormatLegacy()
      : Base(/*separator=*/{", ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
             /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
};
// User-facing, fully configurable format. The static factory functions return
// the predefined formats declared above.
struct TensorIOFormat : public TensorIOFormatBase<TensorIOFormat> {
  using Base = TensorIOFormatBase<TensorIOFormat>;
  // See TensorIOFormatBase for the meaning of each argument.
  TensorIOFormat(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
                 const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
                 const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
      : Base(separator, prefix, suffix, precision, flags, tenPrefix, tenSuffix, fill) {}
  static inline const TensorIOFormatNumpy Numpy() { return TensorIOFormatNumpy{}; }
  static inline const TensorIOFormatPlain Plain() { return TensorIOFormatPlain{}; }
  static inline const TensorIOFormatNative Native() { return TensorIOFormatNative{}; }
  static inline const TensorIOFormatLegacy Legacy() { return TensorIOFormatLegacy{}; }
};
template <typename T, int Layout, int rank, typename Format>
class TensorWithFormat;
// specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
template <typename T, int rank, typename Format>
class TensorWithFormat<T, RowMajor, rank, Format> {
public:
TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank, Format>& wf) {
// Evaluate the expression if needed
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
Evaluator tensor(eval, DefaultDevice());
tensor.evalSubExprsIfNeeded(NULL);
internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
// Cleanup.
tensor.cleanup();
return os;
}
protected:
T t_tensor;
Format t_format;
};
template <typename T, int rank, typename Format>
class TensorWithFormat<T, ColMajor, rank, Format> {
public:
TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank, Format>& wf) {
// Switch to RowMajor storage and print afterwards
typedef typename T::Index IndexType;
std::array<IndexType, rank> shuffle;
std::array<IndexType, rank> id;
std::iota(id.begin(), id.end(), IndexType(0));
std::copy(id.begin(), id.end(), shuffle.rbegin());
auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);
// Evaluate the expression if needed
typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
Evaluator tensor(eval, DefaultDevice());
tensor.evalSubExprsIfNeeded(NULL);
internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
// Cleanup.
tensor.cleanup();
return os;
}
protected:
T t_tensor;
Format t_format;
};
template <typename T, typename Format>
class TensorWithFormat<T, ColMajor, 0, Format> {
public:
TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0, Format>& wf) {
// Evaluate the expression if needed
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
Evaluator tensor(eval, DefaultDevice());
tensor.evalSubExprsIfNeeded(NULL);
internal::TensorPrinter<Evaluator, 0, Format>::run(os, tensor, wf.t_format);
// Cleanup.
tensor.cleanup();
return os;
}
protected:
T t_tensor;
Format t_format;
};
namespace internal {
// Default scalar printer: defer to the scalar's own operator<<; the format
// object is unused in this generic case.
template <typename Scalar, typename Format, typename EnableIf = void>
struct ScalarPrinter {
  static void run(std::ostream& stream, const Scalar& scalar, const Format& fmt) { stream << scalar; }
};
// Complex scalars in numpy notation: "a+bj" / "a-bj".
// The '+' is only emitted when the imaginary part does not print its own '-'
// sign; the previous unconditional '+' produced malformed output such as
// "1+-2j" for negative imaginary parts.
template <typename Scalar>
struct ScalarPrinter<Scalar, TensorIOFormatNumpy, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNumpy& fmt) {
    using RealScalar = typename NumTraits<Scalar>::Real;
    const RealScalar imag_part = numext::imag(scalar);
    stream << numext::real(scalar);
    // Negative (and only negative) values render their own sign.
    if (!(imag_part < RealScalar(0))) stream << "+";
    stream << imag_part << "j";
  }
};
// Complex scalars in C++ initializer notation: "{real, imag}".
template <typename Scalar>
struct ScalarPrinter<Scalar, TensorIOFormatNative, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNative& fmt) {
    stream << "{" << numext::real(scalar) << ", " << numext::imag(scalar) << "}";
  }
};
// Generic printer for rank >= 1 tensors: walks the evaluated (row-major) data
// linearly and emits the per-rank prefixes, suffixes, separators and spacers
// dictated by `fmt`.
template <typename Tensor, std::size_t rank, typename Format, typename EnableIf>
struct TensorPrinter {
  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
    typedef typename Tensor::Index IndexType;
    eigen_assert(Tensor::Layout == RowMajor);
    // 8-bit (possibly complex) types are promoted so they print as numbers
    // rather than characters; everything else is forwarded by const ref.
    typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
                                   is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
                               int,
                               std::conditional_t<is_same<Scalar, std::complex<char>>::value ||
                                                      is_same<Scalar, std::complex<unsigned char>>::value ||
                                                      is_same<Scalar, std::complex<numext::int8_t>>::value ||
                                                      is_same<Scalar, std::complex<numext::uint8_t>>::value,
                                                  std::complex<int>, const Scalar&>>
        PrintType;
    const IndexType total_size = array_prod(tensor.dimensions());
    // Resolve the requested precision (0 means: leave the stream untouched).
    std::streamsize explicit_precision;
    if (fmt.precision == StreamPrecision) {
      explicit_precision = 0;
    } else if (fmt.precision == FullPrecision) {
      if (NumTraits<Scalar>::IsInteger) {
        explicit_precision = 0;
      } else {
        explicit_precision = significant_decimals_impl<Scalar>::run();
      }
    } else {
      explicit_precision = fmt.precision;
    }
    std::streamsize old_precision = 0;
    if (explicit_precision) old_precision = s.precision(explicit_precision);
    // With column alignment enabled, compute the widest formatted element so
    // every element can be right-aligned to that width.
    IndexType width = 0;
    bool align_cols = !(fmt.flags & DontAlignCols);
    if (align_cols) {
      for (IndexType i = 0; i < total_size; i++) {
        std::stringstream sstr;
        sstr.copyfmt(s);
        ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
        width = std::max<IndexType>(width, IndexType(sstr.str().length()));
      }
    }
    // sub_sizes[k] = number of coefficients in one rank-k sub-block, i.e. the
    // product of the k innermost dimensions (empty product == 1 for k == 0).
    // Precomputed once in IndexType arithmetic: the previous implementation
    // recomputed these products per coefficient with std::accumulate seeded
    // with an `int`, which was O(total_size * rank^2) and accumulated in
    // `int`, risking overflow for large tensors.
    std::array<IndexType, rank> sub_sizes{};
    {
      IndexType running_product = 1;
      auto rev_dim = tensor.dimensions().rbegin();
      for (std::size_t k = 0; k < rank; k++) {
        sub_sizes[k] = running_product;
        running_product *= static_cast<IndexType>(*rev_dim);
        ++rev_dim;
      }
    }
    s << fmt.tenPrefix;
    for (IndexType i = 0; i < total_size; i++) {
      // Is the ith coefficient the end/begin of a coeff (always true, k==0),
      // of a row, of a matrix, ... ?
      std::array<bool, rank> is_at_end{};
      std::array<bool, rank> is_at_begin{};
      for (std::size_t k = 0; k < rank; k++) {
        is_at_end[k] = ((i + 1) % sub_sizes[k]) == 0;
        is_at_begin[k] = (i % sub_sizes[k]) == 0;
      }
      // Does any separator that applies here introduce a line break?
      bool is_at_begin_after_newline = false;
      for (std::size_t k = 0; k < rank; k++) {
        if (is_at_begin[k]) {
          std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
          if (fmt.separator[separator_index].find('\n') != std::string::npos) {
            is_at_begin_after_newline = true;
          }
        }
      }
      bool is_at_end_before_newline = false;
      for (std::size_t k = 0; k < rank; k++) {
        if (is_at_end[k]) {
          std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
          if (fmt.separator[separator_index].find('\n') != std::string::npos) {
            is_at_end_before_newline = true;
          }
        }
      }
      std::stringstream suffix, prefix, separator;
      // Closing suffixes for every sub-block that ends at this coefficient.
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
        if (is_at_end[k]) {
          suffix << fmt.suffix[suffix_index];
        }
      }
      // Separators; when a newline-bearing separator applies, plain ones are
      // suppressed so only the line break is emitted.
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
        if (is_at_end[k] &&
            (!is_at_end_before_newline || fmt.separator[separator_index].find('\n') != std::string::npos)) {
          separator << fmt.separator[separator_index];
        }
      }
      // Indentation after a line break so columns line up under the prefixes.
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
        if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
          prefix << fmt.spacer[spacer_index];
        }
      }
      // Opening prefixes, outermost first.
      for (int k = rank - 1; k >= 0; k--) {
        std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
        if (is_at_begin[k]) {
          prefix << fmt.prefix[prefix_index];
        }
      }
      s << prefix.str();
      // Format the scalar into a temporary string so the width/fill handling
      // below cannot disturb the caller's stream state.
      std::stringstream sstr;
      sstr.copyfmt(s);
      ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
      std::string scalar_str = sstr.str();
      const IndexType scalar_width = static_cast<IndexType>(scalar_str.length());
      if (width && scalar_width < width) {
        // Right-align by padding with the format's fill character.
        s << std::string(static_cast<std::size_t>(width - scalar_width), fmt.fill);
      }
      s << scalar_str;
      s << suffix.str();
      if (i < total_size - 1) {
        s << separator.str();
      }
    }
    s << fmt.tenSuffix;
    if (explicit_precision) s.precision(old_precision);
  }
};
template <typename Tensor, std::size_t rank>
struct TensorPrinter<Tensor, rank, TensorIOFormatLegacy, std::enable_if_t<rank != 0>> {
using Format = TensorIOFormatLegacy;
using Scalar = std::remove_const_t<typename Tensor::Scalar>;
static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
typedef typename Tensor::Index IndexType;
// backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
// (dim(1)*dim(2)*...*dim(rank-1)).
const IndexType total_size = internal::array_prod(tensor.dimensions());
if (total_size > 0) {
const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
Map<const Array<Scalar, Dynamic, Dynamic, Tensor::Layout>> matrix(tensor.data(), first_dim, total_size / first_dim);
s << matrix;
return;
}
}
};
template <typename Tensor, typename Format>
struct TensorPrinter<Tensor, 0, Format> {
  // Rank-0 tensor: emit the single coefficient between the tensor
  // prefix/suffix, honoring the format's precision request.
  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
    using Scalar = std::remove_const_t<typename Tensor::Scalar>;
    // Resolve the requested precision (0 means: leave the stream untouched).
    std::streamsize requested_precision = 0;
    if (fmt.precision == FullPrecision) {
      requested_precision =
          NumTraits<Scalar>::IsInteger ? std::streamsize(0) : significant_decimals_impl<Scalar>::run();
    } else if (fmt.precision != StreamPrecision) {
      requested_precision = fmt.precision;
    }
    std::streamsize saved_precision = 0;
    if (requested_precision) saved_precision = s.precision(requested_precision);
    s << fmt.tenPrefix;
    ScalarPrinter<Scalar, Format>::run(s, tensor.coeff(0), fmt);
    s << fmt.tenSuffix;
    if (requested_precision) s.precision(saved_precision);
  }
};
} // end namespace internal
// Default stream output for any readable tensor expression: use the plain,
// unadorned format.
template <typename T>
std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
  return s << t.format(TensorIOFormat::Plain());
}
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H