blob: d0a1a69d73e6e63661d09f3894d86ab282dfff92 [file] [log] [blame]
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2021 MariaDB Corporation Ab
package org.mariadb.jdbc.message.client;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.mariadb.jdbc.ServerPreparedStatement;
import org.mariadb.jdbc.client.Context;
import org.mariadb.jdbc.client.socket.Writer;
import org.mariadb.jdbc.client.util.Parameter;
import org.mariadb.jdbc.client.util.Parameters;
import org.mariadb.jdbc.export.MaxAllowedPacketException;
import org.mariadb.jdbc.export.Prepare;
import org.mariadb.jdbc.message.server.PrepareResultPacket;
/**
* batch execution. This relies on COM_STMT_BULK_EXECUTE see
* https://mariadb.com/kb/en/com_stmt_bulk_execute/
*/
public final class BulkExecutePacket implements RedoableWithPrepareClientMessage {
private List<Parameters> batchParameterList;
private final String command;
private final ServerPreparedStatement prep;
private Prepare prepareResult;
/**
* Constructor
*
* @param prepareResult prepare result
* @param batchParameterList batch parameter list
* @param command sql command
* @param prep object creator
*/
public BulkExecutePacket(
Prepare prepareResult,
List<Parameters> batchParameterList,
String command,
ServerPreparedStatement prep) {
this.batchParameterList = batchParameterList;
this.prepareResult = prepareResult;
this.command = command;
this.prep = prep;
}
public void saveParameters() {
List<Parameters> savedList = new ArrayList<>(batchParameterList.size());
for (Parameters parameterList : batchParameterList) {
savedList.add(parameterList.clone());
}
this.batchParameterList = savedList;
}
public int encode(Writer writer, Context context, Prepare newPrepareResult)
throws IOException, SQLException {
int statementId =
(newPrepareResult != null && newPrepareResult.getStatementId() != -1)
? newPrepareResult.getStatementId()
: (this.prepareResult != null ? this.prepareResult.getStatementId() : -1);
Iterator<Parameters> paramIterator = batchParameterList.iterator();
Parameters parameters = paramIterator.next();
int parameterCount = parameters.size();
@SuppressWarnings("rawtypes")
Parameter[] parameterHeaderType = new Parameter[parameterCount];
// set header type
for (int i = 0; i < parameterCount; i++) {
parameterHeaderType[i] = parameters.get(i);
}
byte[] lastCmdData = null;
int bulkPacketNo = 0;
// Implementation After writing a bunch of parameter to buffer is marked. then : - when writing
// next bunch of parameter, if buffer grow more than max_allowed_packet, send buffer up to mark,
// then create a new packet with current bunch of data - if a bunch of parameter data type
// changes
// send buffer up to mark, then create a new packet with new data type.
// Problem remains if a bunch of parameter is bigger than max_allowed_packet
main_loop:
while (true) {
bulkPacketNo++;
writer.initPacket();
writer.writeByte(0xfa); // COM_STMT_BULK_EXECUTE
writer.writeInt(statementId);
writer.writeShort((short) 128); // always SEND_TYPES_TO_SERVER
for (int i = 0; i < parameterCount; i++) {
writer.writeShort((short) parameterHeaderType[i].getBinaryEncodeType());
}
if (lastCmdData != null) {
if (writer.throwMaxAllowedLength(lastCmdData.length)) {
throw new MaxAllowedPacketException(
"query size is >= to max_allowed_packet", writer.getCmdLength() != 0);
}
writer.writeBytes(lastCmdData);
writer.mark();
lastCmdData = null;
if (!paramIterator.hasNext()) {
break;
}
parameters = paramIterator.next();
}
parameter_loop:
while (true) {
for (int i = 0; i < parameterCount; i++) {
Parameter param = parameters.get(i);
if (param.isNull()) {
writer.writeByte(0x01); // value is null
} else {
writer.writeByte(0x00); // value follow
param.encodeBinary(writer);
}
}
if (!writer.isMarked() && writer.hasFlushed()) {
// parameter were too big to fit in a MySQL packet
// need to finish the packet separately
writer.flush();
if (!paramIterator.hasNext()) {
break main_loop;
}
parameters = paramIterator.next();
// reset header type
for (int j = 0; j < parameterCount; j++) {
parameterHeaderType[j] = parameters.get(j);
}
break;
}
if (writer.isMarked() && writer.throwMaxAllowedLength(writer.pos())) {
// for max_allowed_packet < 16Mb
// packet length was ok at last mark, but won't with new data
writer.flushBufferStopAtMark();
writer.mark();
lastCmdData = writer.resetMark();
break;
}
writer.mark();
if (writer.bufIsDataAfterMark()) {
// flush has been done
lastCmdData = writer.resetMark();
break;
}
if (!paramIterator.hasNext()) {
break main_loop;
}
parameters = paramIterator.next();
// ensure type has not changed
for (int i = 0; i < parameterCount; i++) {
if (parameterHeaderType[i].getBinaryEncodeType()
!= parameters.get(i).getBinaryEncodeType()
&& !parameters.get(i).isNull()) {
writer.flush();
// reset header type
for (int j = 0; j < parameterCount; j++) {
parameterHeaderType[j] = parameters.get(j);
}
break parameter_loop;
}
}
}
}
writer.flush();
return bulkPacketNo;
}
public int batchUpdateLength() {
return batchParameterList.size();
}
public String getCommand() {
return command;
}
public ServerPreparedStatement prep() {
return prep;
}
public boolean binaryProtocol() {
return true;
}
public String description() {
return command;
}
public void setPrepareResult(PrepareResultPacket prepareResult) {
this.prepareResult = prepareResult;
}
}