// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2021 MariaDB Corporation Ab

package org.mariadb.jdbc.message.client;

import java.io.IOException;
import java.io.InputStream;
import java.sql.SQLException;
import java.util.concurrent.locks.ReentrantLock;
import org.mariadb.jdbc.BasePreparedStatement;
import org.mariadb.jdbc.ServerPreparedStatement;
import org.mariadb.jdbc.Statement;
import org.mariadb.jdbc.client.Completion;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.ReadableByteBuf;
import org.mariadb.jdbc.client.socket.Reader;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.util.Parameter;
import org.mariadb.jdbc.client.util.Parameters;
import org.mariadb.jdbc.export.ExceptionFactory;
import org.mariadb.jdbc.export.Prepare;
import org.mariadb.jdbc.message.ClientMessage;
import org.mariadb.jdbc.message.server.CachedPrepareResultPacket;
import org.mariadb.jdbc.message.server.ErrorPacket;
import org.mariadb.jdbc.message.server.PrepareResultPacket;
import org.mariadb.jdbc.plugin.codec.ByteArrayCodec;

/**
 * Send a client COM_STMT_PREPARE + COM_STMT_EXECUTE packets see
 * https://mariadb.com/kb/en/com_stmt_prepare/
 */
public final class PrepareExecutePacket implements RedoableWithPrepareClientMessage {
  private final String sql;
  private Parameters parameters;
  private final ServerPreparedStatement prep;
  private PrepareResultPacket prepareResult;
  private InputStream localInfileInputStream;
  /**
   * Construct prepare packet
   *
   * @param sql sql
   * @param parameters parameter
   * @param prep prepare
   * @param localInfileInputStream local infile input stream
   */
  public PrepareExecutePacket(
      String sql,
      Parameters parameters,
      ServerPreparedStatement prep,
      InputStream localInfileInputStream) {
    this.sql = sql;
    this.parameters = parameters;
    this.prep = prep;
    this.localInfileInputStream = localInfileInputStream;
    this.prepareResult = null;
  }

  @Override
  public int encode(Writer writer, Context context, Prepare newPrepareResult)
      throws IOException, SQLException {
    int statementId = -1;
    if (newPrepareResult == null) {

      writer.initPacket();
      writer.writeByte(0x16);
      writer.writeString(this.sql);
      writer.flushPipeline();
    } else {
      statementId = newPrepareResult.getStatementId();
    }
    int parameterCount = parameters.size();

    // send long data value in separate packet
    for (int i = 0; i < parameterCount; i++) {
      Parameter p = parameters.get(i);
      if (!p.isNull() && p.canEncodeLongData()) {
        new LongDataPacket(statementId, p, i).encode(writer, context);
      }
    }

    writer.initPacket();
    writer.writeByte(0x17);
    writer.writeInt(statementId);
    writer.writeByte(0x00); // NO CURSOR
    writer.writeInt(1); // Iteration pos

    if (parameterCount > 0) {

      // create null bitmap and reserve place in writer
      int nullCount = (parameterCount + 7) / 8;
      byte[] nullBitsBuffer = new byte[nullCount];
      int initialPos = writer.pos();
      writer.pos(initialPos + nullCount);

      // Send Parameter type flag
      writer.writeByte(0x01);

      // Store types of parameters in first package that is sent to the server.
      for (int i = 0; i < parameterCount; i++) {
        Parameter p = parameters.get(i);
        writer.writeByte(p.getBinaryEncodeType());
        writer.writeByte(0);
        if (p.isNull()) {
          nullBitsBuffer[i / 8] |= (1 << (i % 8));
        }
      }

      // write nullBitsBuffer in reserved place
      writer.writeBytesAtPos(nullBitsBuffer, initialPos);

      // send not null parameter, not long data
      for (int i = 0; i < parameterCount; i++) {
        Parameter p = parameters.get(i);
        if (!p.isNull() && !p.canEncodeLongData()) {
          p.encodeBinary(writer);
        }
      }
    }

    writer.flush();
    return (newPrepareResult == null) ? 2 : 1;
  }

  @Override
  public Completion readPacket(
      Statement stmt,
      int fetchSize,
      long maxRows,
      int resultSetConcurrency,
      int resultSetType,
      boolean closeOnCompletion,
      Reader reader,
      Writer writer,
      Context context,
      ExceptionFactory exceptionFactory,
      ReentrantLock lock,
      boolean traceEnable,
      ClientMessage message)
      throws IOException, SQLException {
    if (this.prepareResult == null) {
      ReadableByteBuf buf = reader.readReusablePacket(traceEnable);
      // *********************************************************************************************************
      // * ERROR response
      // *********************************************************************************************************
      if (buf.getUnsignedByte()
          == 0xff) { // force current status to in transaction to ensure rollback/commit, since
        // command may
        // have issue a transaction
        ErrorPacket errorPacket = new ErrorPacket(buf, context);
        throw exceptionFactory
            .withSql(this.description())
            .create(
                errorPacket.getMessage(), errorPacket.getSqlState(), errorPacket.getErrorCode());
      }
      if (context.getConf().useServerPrepStmts()
          && context.getConf().cachePrepStmts()
          && sql.length() < 8192) {
        PrepareResultPacket prepare = new CachedPrepareResultPacket(buf, reader, context);
        PrepareResultPacket previousCached =
            (PrepareResultPacket)
                context
                    .getPrepareCache()
                    .put(
                        sql,
                        prepare,
                        stmt instanceof ServerPreparedStatement
                            ? (ServerPreparedStatement) stmt
                            : null);
        if (stmt != null) {
          ((BasePreparedStatement) stmt)
              .setPrepareResult(previousCached != null ? previousCached : prepare);
        }
        this.prepareResult = previousCached != null ? previousCached : prepare;
        return this.prepareResult;
      }
      PrepareResultPacket prepareResult = new PrepareResultPacket(buf, reader, context);
      if (stmt != null) {
        ((BasePreparedStatement) stmt).setPrepareResult(prepareResult);
      }
      this.prepareResult = prepareResult;
      return prepareResult;
    } else {
      return RedoableWithPrepareClientMessage.super.readPacket(
          stmt,
          fetchSize,
          maxRows,
          resultSetConcurrency,
          resultSetType,
          closeOnCompletion,
          reader,
          writer,
          context,
          exceptionFactory,
          lock,
          traceEnable,
          message);
    }
  }

  public void saveParameters() {
    this.parameters = this.parameters.clone();
  }

  @Override
  public void ensureReplayable(Context context) throws IOException, SQLException {
    int parameterCount = parameters.size();
    for (int i = 0; i < parameterCount; i++) {
      Parameter p = parameters.get(i);
      if (!p.isNull() && p.canEncodeLongData()) {
        this.parameters.set(
            i, new org.mariadb.jdbc.codec.Parameter<>(ByteArrayCodec.INSTANCE, p.encodeData()));
      }
    }
  }

  public boolean canSkipMeta() {
    return true;
  }

  @Override
  public String description() {
    return "PREPARE + EXECUTE " + sql;
  }

  public int batchUpdateLength() {
    return 1;
  }

  public String getCommand() {
    return sql;
  }

  public InputStream getLocalInfileInputStream() {
    return localInfileInputStream;
  }

  public ServerPreparedStatement prep() {
    return prep;
  }

  public boolean binaryProtocol() {
    return true;
  }

  public boolean validateLocalFileName(String fileName, Context context) {
    return ClientMessage.validateLocalFileName(sql, parameters, fileName, context);
  }

  public void setPrepareResult(PrepareResultPacket prepareResult) {
    this.prepareResult = prepareResult;
  }
}
