blob: b0003e833b4cbfaf1923b8d729c7afd897b30ab0 [file] [log] [blame]
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2021 MariaDB Corporation Ab
package org.mariadb.jdbc.integration;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
import java.io.InputStream;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.Duration;
import java.time.Instant;
import java.util.Locale;
import java.util.Optional;
import java.util.Properties;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.extension.*;
import org.junit.jupiter.api.function.Executable;
import org.mariadb.jdbc.*;
import org.mariadb.jdbc.export.HaMode;
import org.mariadb.jdbc.export.SslMode;
import org.mariadb.jdbc.integration.tools.TcpProxy;
public class Common {
// Connection opened in beforeAll() from mDefUrl and shared by the tests; closed in afterEAll().
public static Connection sharedConn;
// Second shared connection, opened in beforeAll() from mDefUrl with useServerPrepStmts=true added.
public static Connection sharedConnBinary;
// Database coordinates and credentials, resolved once in the static initializer through
// get(...) under the names DB_HOST, DB_PORT, DB_USER, DB_PASSWORD and DB_DATABASE.
public static String hostname;
public static int port;
public static String user;
public static String password;
public static String database;
// Extra URL options appended to mDefUrl: the TLS options when TEST_REQUIRE_TLS is "1",
// otherwise the DB_OTHER setting.
public static String defaultOther;
// Proxy most recently created by createProxyCon(...); force-closed in afterEAll() when set.
public static TcpProxy proxy;
// Default JDBC URL assembled in the static initializer from the settings above.
public static String mDefUrl;
// Timestamp stored by Follow.beforeEach and read by Follow.testSuccessful to print elapsed time.
private static Instant initialTest;
// Resolves the test-database settings once, when the class is loaded, and assembles the default
// JDBC URL (mDefUrl). Every setting goes through get(...): TEST_<name> environment variable,
// then TEST_<name> system property, then the conf.properties classpath resource.
static {
  try (InputStream inputStream =
      Common.class.getClassLoader().getResourceAsStream("conf.properties")) {
    Properties prop = new Properties();
    // getResourceAsStream returns null when the resource is absent and Properties.load(null)
    // throws NullPointerException; in that case rely on environment / system properties only.
    if (inputStream != null) {
      prop.load(inputStream);
    }

    String val = System.getenv("TEST_REQUIRE_TLS");
    if ("1".equals(val)) {
      String cert = System.getenv("TEST_DB_SERVER_CERT");
      defaultOther = "&sslMode=verify-full&serverSslCert=" + cert;
    } else {
      defaultOther = get("DB_OTHER", prop);
    }
    // An unset DB_OTHER must add nothing to the URL; String.format would otherwise render a
    // null argument as the literal text "null".
    if (defaultOther == null) {
      defaultOther = "";
    }

    hostname = get("DB_HOST", prop);
    user = get("DB_USER", prop);
    port = Integer.parseInt(get("DB_PORT", prop));
    password = get("DB_PASSWORD", prop);
    database = get("DB_DATABASE", prop);

    // The password parameter is left out entirely when no password is configured.
    mDefUrl =
        password == null || password.isEmpty()
            ? String.format(
                "jdbc:mariadb://%s:%s/%s?user=%s%s", hostname, port, database, user, defaultOther)
            : String.format(
                "jdbc:mariadb://%s:%s/%s?user=%s&password=%s%s",
                hostname, port, database, user, password, defaultOther);
  } catch (IOException io) {
    // Best effort: report the failure and leave the settings unset.
    io.printStackTrace();
  }
}
/**
 * Looks up one test setting.
 *
 * <p>Sources are consulted in this order, the first non-null value winning: environment variable
 * {@code TEST_<name>}, system property {@code TEST_<name>}, then entry {@code <name>} of the
 * supplied properties.
 *
 * @param name setting name without the {@code TEST_} prefix
 * @param prop fallback properties
 * @return the resolved value, or {@code null} when no source defines it
 */
private static String get(String name, Properties prop) {
  final String prefixedKey = "TEST_" + name;
  String resolved = System.getenv(prefixedKey);
  if (resolved != null) {
    return resolved;
  }
  resolved = System.getProperty(prefixedKey);
  return resolved != null ? resolved : prop.getProperty(name);
}
/**
 * Opens the two shared connections: {@code sharedConn} from the default URL and
 * {@code sharedConnBinary} from the same URL with {@code useServerPrepStmts=true} appended.
 *
 * @throws Exception if either connection cannot be established
 */
@BeforeAll
public static void beforeAll() throws Exception {
  sharedConn = (Connection) DriverManager.getConnection(mDefUrl);

  // Append with '&' when the URL already carries a query string, otherwise start one with '?'.
  StringBuilder binaryUrl = new StringBuilder(mDefUrl);
  binaryUrl.append(mDefUrl.indexOf("?") > 0 ? "&" : "?").append("useServerPrepStmts=true");
  sharedConnBinary = (Connection) DriverManager.getConnection(binaryUrl.toString());
}
/**
 * Releases the shared connections and the proxy, if any.
 *
 * <p>Each resource is checked and closed independently, so a partially failed set-up (only one
 * connection opened) or an exception while closing one resource does not prevent the remaining
 * ones from being released.
 *
 * @throws SQLException if closing a connection fails
 */
@AfterAll
public static void afterEAll() throws SQLException {
  try {
    try {
      if (sharedConn != null) {
        sharedConn.close();
      }
    } finally {
      if (sharedConnBinary != null) {
        sharedConnBinary.close();
      }
    }
  } finally {
    if (proxy != null) {
      proxy.forceClose();
    }
  }
}
/**
 * Reports what the shared connection's server version says about the server flavour.
 *
 * @return the value of {@code isMariaDBServer()} on the shared connection's version
 */
public static boolean isMariaDBServer() {
  final boolean mariaDb = sharedConn.getContext().getVersion().isMariaDBServer();
  return mariaDb;
}
/**
 * Asks the shared connection's context whether a client capability is set.
 *
 * @param capability capability flag passed to {@code hasClientCapability}
 * @return the context's answer for that flag
 */
public static boolean hasCapability(long capability) {
  final boolean present = sharedConn.getContext().hasClientCapability(capability);
  return present;
}
/**
 * Tells whether long-running tests are enabled through the {@code RUN_LONG_TEST} environment
 * variable.
 *
 * @return {@code true} only when the variable is set to "true" (case-insensitive); {@code false}
 *     when it is unset or holds anything else
 */
public static boolean runLongTest() {
  // Boolean.parseBoolean already yields false for a null argument.
  return Boolean.parseBoolean(System.getenv("RUN_LONG_TEST"));
}
/**
 * Tells whether the tests run against an Xpand server.
 *
 * <p>When the {@code srv} environment variable is set it decides alone (it must equal "xpand").
 * Otherwise the shared connection's version must be a MariaDB server whose qualifier contains
 * "xpand", compared case-insensitively.
 *
 * @return {@code true} when the target is identified as Xpand
 */
public static boolean isXpand() {
  String srv = System.getenv("srv");
  if (srv != null) {
    return "xpand".equals(srv);
  }
  // Locale.ROOT keeps the case folding independent of the JVM's default locale.
  return sharedConn.getContext().getVersion().isMariaDBServer()
      && sharedConn
          .getContext()
          .getVersion()
          .getQualifier()
          .toLowerCase(Locale.ROOT)
          .contains("xpand");
}
/**
 * Checks the shared connection's server version against a minimum.
 *
 * @param major required major version
 * @param minor required minor version
 * @param patch required patch version
 * @return the result of {@code versionGreaterOrEqual(major, minor, patch)} on that version
 */
public static boolean minVersion(int major, int minor, int patch) {
  final boolean atLeast =
      sharedConn.getContext().getVersion().versionGreaterOrEqual(major, minor, patch);
  return atLeast;
}
/**
 * Checks whether the shared connection's server version is exactly the given one.
 *
 * @param major expected major version
 * @param minor expected minor version
 * @param patch expected patch version
 * @return {@code true} when major, minor and patch all match
 */
public static boolean exactVersion(int major, int minor, int patch) {
  // Components are compared in order; the first mismatch ends the check.
  if (sharedConn.getContext().getVersion().getMajorVersion() != major) {
    return false;
  }
  if (sharedConn.getContext().getVersion().getMinorVersion() != minor) {
    return false;
  }
  return sharedConn.getContext().getVersion().getPatchVersion() == patch;
}
/**
 * Opens a new connection using the default URL.
 *
 * @return a fresh connection; the caller is responsible for closing it
 * @throws SQLException if the connection cannot be established
 */
public static Connection createCon() throws SQLException {
  final java.sql.Connection opened = DriverManager.getConnection(mDefUrl);
  return (Connection) opened;
}
/**
 * Makes sure the helper tables {@code sequence_1_to_10} and {@code sequence_1_to_10000} exist
 * with the expected number of rows (10 and 10000).
 *
 * <p>A table whose row count cannot be read or differs from the expectation is dropped,
 * recreated with a single int column {@code t1} and refilled with 1..N. The statement and the
 * result sets used for the checks are closed before returning.
 *
 * @throws SQLException if creating or filling a table fails
 */
public static void createSequenceTables() throws SQLException {
  try (Statement stmt = sharedConn.createStatement()) {
    if (!sequenceTableHasRows(stmt, "sequence_1_to_10", 10)) {
      stmt.execute("DROP TABLE IF EXISTS sequence_1_to_10");
      stmt.execute("CREATE TABLE sequence_1_to_10 (t1 int)");
      stmt.execute(
          "insert into sequence_1_to_10 VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10)");
    }

    if (!sequenceTableHasRows(stmt, "sequence_1_to_10000", 10_000)) {
      stmt.execute("DROP TABLE IF EXISTS sequence_1_to_10000");
      stmt.execute("CREATE TABLE sequence_1_to_10000 (t1 int)");
      // The 10000 rows are sent as one batch over the server-prepared-statement connection.
      try (PreparedStatement prep =
          sharedConnBinary.prepareStatement("INSERT INTO sequence_1_to_10000 VALUES (?)")) {
        for (int i = 1; i <= 10_000; i++) {
          prep.setInt(1, i);
          prep.addBatch();
        }
        prep.executeBatch();
      }
    }
  }
}

/** Returns true when {@code table} can be counted and holds exactly {@code expectedRows} rows. */
private static boolean sequenceTableHasRows(Statement stmt, String table, int expectedRows) {
  try (ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + table)) {
    return rs.next() && rs.getInt(1) == expectedRows;
  } catch (SQLException ignored) {
    // Deliberately ignored: a failing count (e.g. missing table) means the caller rebuilds it.
    return false;
  }
}
/**
 * Starts a {@code TcpProxy} in front of the first host of the default URL and opens a connection
 * that goes through it.
 *
 * <p>The new proxy is stored in the static {@code proxy} field. The URL used is the default URL
 * with its host part replaced by {@code localhost:<proxy local port>}, the HA mode inserted
 * after {@code jdbc:mariadb:} when it is not {@code NONE}, and {@code opts} appended verbatim.
 *
 * @param mode high-availability mode to put in the URL
 * @param opts additional URL options, appended as-is (expected to start with {@code &})
 * @return a connection routed through the proxy
 * @throws SQLException if the proxy cannot be started or the connection fails
 */
public Connection createProxyCon(HaMode mode, String opts) throws SQLException {
  Configuration conf = Configuration.parse(mDefUrl);
  HostAddress hostAddress = conf.addresses().get(0);
  try {
    proxy = new TcpProxy(hostAddress.host, hostAddress.port);
  } catch (IOException i) {
    throw new SQLException("proxy error", i);
  }

  // Only the authority ("//host:port/") is rewritten: replaceFirst leaves any later "//x/"
  // sequence inside an option value (for example a certificate path) untouched.
  String url = mDefUrl.replaceFirst("//([^/]*)/", "//localhost:" + proxy.getLocalPort() + "/");
  if (mode != HaMode.NONE) {
    url =
        url.replaceFirst(
            "jdbc:mariadb:", "jdbc:mariadb:" + mode.name().toLowerCase(Locale.ROOT) + ":");
  }
  if (conf.sslMode() == SslMode.VERIFY_FULL) {
    // NOTE(review): presumably downgraded because the proxy address (localhost) would not match
    // the server certificate's host name - confirm.
    url = url.replace("sslMode=verify-full", "sslMode=verify-ca");
  }
  return (Connection) DriverManager.getConnection(url + opts);
}
/**
 * Tells whether the server reports SSL support through {@code @@have_ssl}.
 *
 * <p>The statement and result set are closed before returning. The former diagnostic
 * {@code show variables like '%ssl%'} query was removed because its result was never read.
 *
 * @return {@code true} when {@code @@have_ssl} is "YES"; {@code false} for any other value or
 *     when the query fails
 * @throws SQLException if the statement cannot be created or closed
 */
public static boolean haveSsl() throws SQLException {
  try (Statement stmt = sharedConn.createStatement()) {
    try (ResultSet rs = stmt.executeQuery("select @@have_ssl")) {
      assertTrue(rs.next());
      return "YES".equals(rs.getString(1));
    } catch (SQLException e) {
      // The variable cannot be queried on this server: treat it as "no SSL".
      return false;
    }
  }
}
/**
 * Tells whether the JVM runs on Windows, judged by the {@code os.name} system property
 * containing "win" (case-insensitive).
 *
 * @return {@code true} when {@code os.name} contains "win" in any letter case
 */
public boolean isWindows() {
  // Locale.ROOT: under a Turkish default locale, "WINDOWS".toLowerCase() yields a dotless i and
  // would no longer contain "win".
  return System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("win");
}
/**
 * Aborts (skips) the current test, through a JUnit assumption, when the server's product
 * version is {@code major.minor}.
 *
 * <p>The match requires the {@code major.minor} prefix to be followed by a non-digit or the end
 * of the string, so that for instance (10, 1) does not also match 10.10.x or 10.11.x.
 *
 * @param major major version to skip
 * @param minor minor version to skip
 */
public void cancelForVersion(int major, int minor) {
  String dbVersion = sharedConn.getMetaData().getDatabaseProductVersion();
  String prefix = major + "." + minor;
  boolean sameRelease =
      dbVersion.startsWith(prefix)
          && (dbVersion.length() == prefix.length()
              || !Character.isDigit(dbVersion.charAt(prefix.length())));
  Assumptions.assumeFalse(sameRelease);
}
/**
 * Opens a new connection using the default URL with one extra option string.
 *
 * @param option URL option(s) in {@code key=value} form, joined to the default URL with '&'
 * @return a fresh connection; the caller is responsible for closing it
 * @throws SQLException if the connection cannot be established
 */
public static Connection createCon(String option) throws SQLException {
  final String url = mDefUrl + "&" + option;
  return (Connection) DriverManager.getConnection(url);
}
/**
 * Opens a new connection from the default URL plus {@code option}, optionally overriding the
 * port of every configured host.
 *
 * @param option URL option(s), joined to the default URL with '&'
 * @param sslPort port to set on each host address, or {@code null} to keep the configured ports
 * @return a fresh connection obtained through {@code Driver.connect}
 * @throws SQLException if parsing the URL or connecting fails
 */
public static Connection createCon(String option, Integer sslPort) throws SQLException {
  final Configuration parsed = Configuration.parse(mDefUrl + "&" + option);
  if (sslPort == null) {
    return Driver.connect(parsed);
  }
  for (HostAddress address : parsed.addresses()) {
    address.port = sslPort;
  }
  return Driver.connect(parsed);
}
/**
 * Runs after every test and calls {@code isValid} on both shared connections; the results are
 * not inspected.
 *
 * @throws SQLException if a validity check throws
 */
@AfterEach
public void afterEach1() throws SQLException {
  // NOTE(review): java.sql.Connection#isValid documents its timeout in seconds, so 2000 is about
  // 33 minutes - confirm this is intended.
  final int timeout = 2000;
  sharedConn.isValid(timeout);
  sharedConnBinary.isValid(timeout);
}
/**
 * Reads the server's {@code @@max_allowed_packet} value through the shared connection.
 *
 * <p>The statement and result set are closed before returning.
 *
 * @return the value of {@code @@max_allowed_packet}
 * @throws SQLException if the query fails
 */
public static int getMaxAllowedPacket() throws SQLException {
  try (java.sql.Statement st = sharedConn.createStatement();
      ResultSet rs = st.executeQuery("select @@max_allowed_packet")) {
    assertTrue(rs.next());
    return rs.getInt(1);
  }
}
/**
 * Asserts that {@code executable} throws an exception of {@code expectedType} whose message
 * contains {@code expected}.
 *
 * <p>An exception without a message fails the assertion (instead of raising a
 * NullPointerException).
 *
 * @param expectedType exception type that must be thrown
 * @param executable code expected to throw
 * @param expected text the exception message must contain
 */
public static void assertThrowsContains(
    Class<? extends Exception> expectedType, Executable executable, String expected) {
  Exception e = Assertions.assertThrows(expectedType, executable);
  String message = e.getMessage();
  Assertions.assertTrue(
      message != null && message.contains(expected), "real message:" + message);
}
// Registers a Follow instance as a JUnit extension on every test instance; Follow prints each
// test's outcome to standard output.
@RegisterExtension public Extension watcher = new Follow();
/**
 * JUnit extension that prints one line per test outcome to standard output.
 *
 * <p>{@code beforeEach} stores the current instant in the enclosing class's {@code initialTest}
 * field; {@code testSuccessful} prints the milliseconds elapsed since that instant.
 */
private static class Follow implements TestWatcher, BeforeEachCallback {

  /** Records the instant at which the upcoming test starts. */
  @Override
  public void beforeEach(ExtensionContext extensionContext) {
    initialTest = Instant.now();
  }

  /** Prints the test name followed by the time elapsed since {@code beforeEach}. */
  @Override
  public void testSuccessful(ExtensionContext context) {
    final String testName = context.getDisplayName();
    final long elapsedMillis = Duration.between(initialTest, Instant.now()).toMillis();
    System.out.format(" ✓ %s: %sms%n", testName, elapsedMillis);
  }

  /** Prints the name of the failed test; the cause is not printed. */
  @Override
  public void testFailed(ExtensionContext context, Throwable cause) {
    final String testName = context.getDisplayName();
    System.out.format(" ✗ Failed %s: %n", testName);
  }

  /** Prints the name of the aborted test; the cause is not printed. */
  @Override
  public void testAborted(ExtensionContext context, Throwable cause) {
    final String testName = context.getDisplayName();
    System.out.format(" - Aborted %s: %n", testName);
  }

  /** Prints the name of the disabled test and the reason, or "No reason" when none is given. */
  @Override
  public void testDisabled(ExtensionContext context, Optional<String> reason) {
    final String testName = context.getDisplayName();
    final String why = reason.orElse("No reason");
    System.out.format(" - Disabled %s: with reason :- %s%n", testName, why);
  }
}
}