// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
#define EIGEN_CXX11_TENSOR_TENSOR_IO_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

// Forward declaration of the print-format description defined right below.
struct TensorIOFormat;

namespace internal {
// Forward declaration: the TensorWithFormat stream operators below call
// TensorPrinter<Evaluator, rank>::run, which is only defined further down in
// this header. `rank` is the number of tensor dimensions being printed.
template <typename Tensor, std::size_t rank>
struct TensorPrinter;
}  // namespace internal

/** Describes how a tensor is rendered as text.
 *
 * The per-level vectors `prefix`, `suffix` and `separator` are indexed by nesting
 * level k: level 0 is a single coefficient, level 1 a row, level 2 a matrix, and so
 * on. When a tensor has more levels than a vector has entries, the printer reuses
 * the vector's last entry.
 *
 * - `tenPrefix` / `tenSuffix` wrap the whole output.
 * - `fill` pads aligned columns.
 * - `precision` is StreamPrecision, FullPrecision or an explicit digit count.
 * - `flags` may contain DontAlignCols.
 * - `legacy_bit` selects the legacy matrix-reshape printing.
 */
struct TensorIOFormat {
  /// Fully custom format: per-level separators, prefixes and suffixes plus whole-tensor prefix/suffix.
  TensorIOFormat(const std::vector<std::string>& _separator, const std::vector<std::string>& _prefix,
                 const std::vector<std::string>& _suffix, int _precision = StreamPrecision, int _flags = 0,
                 const std::string& _tenPrefix = "", const std::string& _tenSuffix = "", const char _fill = ' ')
      : tenPrefix(_tenPrefix),
        tenSuffix(_tenSuffix),
        prefix(_prefix),
        suffix(_suffix),
        separator(_separator),
        fill(_fill),
        precision(_precision),
        flags(_flags) {
    init_spacer();
  }

  /// Default-style format: rows wrapped in "[" / "]", coefficients separated by ", ", rows by newlines.
  TensorIOFormat(int _precision = StreamPrecision, int _flags = 0, const std::string& _tenPrefix = "",
                 const std::string& _tenSuffix = "", const char _fill = ' ')
      : tenPrefix(_tenPrefix), tenSuffix(_tenSuffix), fill(_fill), precision(_precision), flags(_flags) {
    // default values of prefix, suffix and separator
    prefix = {"", "["};
    suffix = {"", "]"};
    separator = {", ", "\n"};

    init_spacer();
  }

  /// Recomputes `spacer`, the indentation the printer emits after a line break.
  ///
  /// spacer[0] holds one blank per character of `tenPrefix` after its last newline.
  /// spacer[k] (k >= 1) is the same for `prefix[k]`.
  ///
  /// `spacer` always ends up with at least one entry, so the printer can index it
  /// safely. All entries are empty when DontAlignCols is set or when there is
  /// nothing to indent. The vector is rebuilt from scratch, so calling this again
  /// after editing the prefixes is safe.
  void init_spacer() {
    spacer.assign(prefix.empty() ? std::size_t(1) : prefix.size(), std::string());
    if (flags & DontAlignCols) return;
    spacer[0] = trailing_blanks(tenPrefix);
    for (std::size_t k = 1; k < prefix.size(); k++) {
      spacer[k] = trailing_blanks(prefix[k]);
    }
  }

  /// NumPy-like output: "[[1 2]\n [3 4]]".
  static inline const TensorIOFormat Numpy() {
    std::vector<std::string> prefix = {"", "["};
    std::vector<std::string> suffix = {"", "]"};
    std::vector<std::string> separator = {" ", "\n"};
    return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "[", "]");
  }

  /// Plain output: coefficients separated by blanks, rows and matrices by newlines, no brackets.
  static inline const TensorIOFormat Plain() {
    std::vector<std::string> separator = {" ", "\n", "\n", ""};
    std::vector<std::string> prefix = {""};
    std::vector<std::string> suffix = {""};
    return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "", "", ' ');
  }

  /// C++ initializer-list-like output using braces.
  static inline const TensorIOFormat Native() {
    std::vector<std::string> separator = {", ", ",\n", "\n"};
    std::vector<std::string> prefix = {"", "{"};
    std::vector<std::string> suffix = {"", "}"};
    return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "{", "}", ' ');
  }

  /// Legacy output: the tensor is printed as a dim(0) x (product of remaining dims) matrix.
  static inline const TensorIOFormat Legacy() {
    TensorIOFormat LegacyFormat(StreamPrecision, 0, "", "", ' ');
    LegacyFormat.legacy_bit = true;
    return LegacyFormat;
  }

  std::string tenPrefix;
  std::string tenSuffix;
  std::vector<std::string> prefix;
  std::vector<std::string> suffix;
  std::vector<std::string> separator;
  char fill;
  int precision;
  int flags;
  std::vector<std::string> spacer{};
  bool legacy_bit = false;

 private:
  // One blank for every character following the last '\n' in `s` (all of `s` if it has no newline).
  static std::string trailing_blanks(const std::string& s) {
    const std::size_t last_newline = s.rfind('\n');
    const std::size_t start = (last_newline == std::string::npos) ? 0 : last_newline + 1;
    return std::string(s.size() - start, ' ');
  }
};

// Pairs a tensor expression with the TensorIOFormat used to stream it.
// The primary template is only declared; one of the specializations below is used.
template <typename T, int Layout, int rank>
class TensorWithFormat;
// Specializations: Layout=RowMajor (any rank), Layout=ColMajor (any rank >= 1, converted
// to a row-major expression before printing) and Layout=ColMajor with rank=0.
template <typename T, int rank>
class TensorWithFormat<T, RowMajor, rank> {
 public:
  TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}

  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank>& wf) {
    // Evaluate the expression if needed
    typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
    TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
    Evaluator tensor(eval, DefaultDevice());
    tensor.evalSubExprsIfNeeded(NULL);
    internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
    // Cleanup.
    tensor.cleanup();
    return os;
  }

 protected:
  T t_tensor;
  TensorIOFormat t_format;
};

template <typename T, int rank>
class TensorWithFormat<T, ColMajor, rank> {
 public:
  TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}

  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank>& wf) {
    // Switch to RowMajor storage and print afterwards
    typedef typename T::Index IndexType;
    std::array<IndexType, rank> shuffle;
    std::array<IndexType, rank> id;
    std::iota(id.begin(), id.end(), IndexType(0));
    std::copy(id.begin(), id.end(), shuffle.rbegin());
    auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);

    // Evaluate the expression if needed
    typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
    TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
    Evaluator tensor(eval, DefaultDevice());
    tensor.evalSubExprsIfNeeded(NULL);
    internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
    // Cleanup.
    tensor.cleanup();
    return os;
  }

 protected:
  T t_tensor;
  TensorIOFormat t_format;
};

template <typename T>
class TensorWithFormat<T, ColMajor, 0> {
 public:
  TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}

  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0>& wf) {
    // Evaluate the expression if needed
    typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
    TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
    Evaluator tensor(eval, DefaultDevice());
    tensor.evalSubExprsIfNeeded(NULL);
    internal::TensorPrinter<Evaluator, 0>::run(os, tensor, wf.t_format);
    // Cleanup.
    tensor.cleanup();
    return os;
  }

 protected:
  T t_tensor;
  TensorIOFormat t_format;
};

namespace internal {
template <typename Tensor, std::size_t rank>
struct TensorPrinter {
  static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
    typedef std::remove_const_t<typename Tensor::Scalar> Scalar;
    typedef typename Tensor::Index IndexType;
    static const int layout = Tensor::Layout;
    // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
    // (dim(1)*dim(2)*...*dim(rank-1)).
    if (fmt.legacy_bit) {
      const IndexType total_size = internal::array_prod(_t.dimensions());
      if (total_size > 0) {
        const IndexType first_dim = Eigen::internal::array_get<0>(_t.dimensions());
        Map<const Array<Scalar, Dynamic, Dynamic, layout>> matrix(_t.data(), first_dim, total_size / first_dim);
        s << matrix;
        return;
      }
    }

    eigen_assert(layout == RowMajor);
    typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
                                   is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
                               int,
                               std::conditional_t<is_same<Scalar, std::complex<char>>::value ||
                                                      is_same<Scalar, std::complex<unsigned char>>::value ||
                                                      is_same<Scalar, std::complex<numext::int8_t>>::value ||
                                                      is_same<Scalar, std::complex<numext::uint8_t>>::value,
                                                  std::complex<int>, const Scalar&>>
        PrintType;

    const IndexType total_size = array_prod(_t.dimensions());

    std::streamsize explicit_precision;
    if (fmt.precision == StreamPrecision) {
      explicit_precision = 0;
    } else if (fmt.precision == FullPrecision) {
      if (NumTraits<Scalar>::IsInteger) {
        explicit_precision = 0;
      } else {
        explicit_precision = significant_decimals_impl<Scalar>::run();
      }
    } else {
      explicit_precision = fmt.precision;
    }

    std::streamsize old_precision = 0;
    if (explicit_precision) old_precision = s.precision(explicit_precision);

    IndexType width = 0;

    bool align_cols = !(fmt.flags & DontAlignCols);
    if (align_cols) {
      // compute the largest width
      for (IndexType i = 0; i < total_size; i++) {
        std::stringstream sstr;
        sstr.copyfmt(s);
        sstr << static_cast<PrintType>(_t.data()[i]);
        width = std::max<IndexType>(width, IndexType(sstr.str().length()));
      }
    }
    std::streamsize old_width = s.width();
    char old_fill_character = s.fill();

    s << fmt.tenPrefix;
    for (IndexType i = 0; i < total_size; i++) {
      std::array<bool, rank> is_at_end{};
      std::array<bool, rank> is_at_begin{};

      // is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
      for (std::size_t k = 0; k < rank; k++) {
        if ((i + 1) % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
                                       std::multiplies<IndexType>())) ==
            0) {
          is_at_end[k] = true;
        }
      }

      // is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
      for (std::size_t k = 0; k < rank; k++) {
        if (i % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
                                 std::multiplies<IndexType>())) ==
            0) {
          is_at_begin[k] = true;
        }
      }

      // do we have a line break?
      bool is_at_begin_after_newline = false;
      for (std::size_t k = 0; k < rank; k++) {
        if (is_at_begin[k]) {
          std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
          if (fmt.separator[separator_index].find('\n') != std::string::npos) {
            is_at_begin_after_newline = true;
          }
        }
      }

      bool is_at_end_before_newline = false;
      for (std::size_t k = 0; k < rank; k++) {
        if (is_at_end[k]) {
          std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
          if (fmt.separator[separator_index].find('\n') != std::string::npos) {
            is_at_end_before_newline = true;
          }
        }
      }

      std::stringstream suffix, prefix, separator;
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
        if (is_at_end[k]) {
          suffix << fmt.suffix[suffix_index];
        }
      }
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
        if (is_at_end[k] &&
            (!is_at_end_before_newline || fmt.separator[separator_index].find('\n') != std::string::npos)) {
          separator << fmt.separator[separator_index];
        }
      }
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
        if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
          prefix << fmt.spacer[spacer_index];
        }
      }
      for (int k = rank - 1; k >= 0; k--) {
        std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
        if (is_at_begin[k]) {
          prefix << fmt.prefix[prefix_index];
        }
      }

      s << prefix.str();
      if (width) {
        s.fill(fmt.fill);
        s.width(width);
        s << std::right;
      }
      s << _t.data()[i];
      s << suffix.str();
      if (i < total_size - 1) {
        s << separator.str();
      }
    }
    s << fmt.tenSuffix;
    if (explicit_precision) s.precision(old_precision);
    if (width) {
      s.fill(old_fill_character);
      s.width(old_width);
    }
  }
};

template <typename Tensor>
struct TensorPrinter<Tensor, 0> {
  static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
    typedef typename Tensor::Scalar Scalar;

    std::streamsize explicit_precision;
    if (fmt.precision == StreamPrecision) {
      explicit_precision = 0;
    } else if (fmt.precision == FullPrecision) {
      if (NumTraits<Scalar>::IsInteger) {
        explicit_precision = 0;
      } else {
        explicit_precision = significant_decimals_impl<Scalar>::run();
      }
    } else {
      explicit_precision = fmt.precision;
    }

    std::streamsize old_precision = 0;
    if (explicit_precision) old_precision = s.precision(explicit_precision);

    s << fmt.tenPrefix << _t.coeff(0) << fmt.tenSuffix;
    if (explicit_precision) s.precision(old_precision);
  }
};

}  // end namespace internal
template <typename T>
std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
  // Plain streaming uses TensorIOFormat::Plain(): blank-separated coefficients,
  // newline-separated rows/matrices, no brackets.
  return s << t.format(TensorIOFormat::Plain());
}
}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_IO_H
