blob: 1898894f664d4fa1932d2b1a2d77d61ce80171e7 [file] [log] [blame]
// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2021 MariaDB Corporation Ab
package org.mariadb.jdbc.client.socket.impl;
import com.sun.jna.LastErrorException;
import com.sun.jna.Native;
import com.sun.jna.Platform;
import com.sun.jna.Structure;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Unix domain (IPC) stream socket backed by the C library through JNA direct mapping.
 *
 * <p>The native file descriptor is created by the constructor and the target path is fixed at
 * construction time; {@link #connect(SocketAddress, int)} ignores both of its arguments. Socket
 * option setters and the half-close methods are accepted but have no effect.
 */
public class UnixDomainSocket extends Socket {
  private static final int AF_UNIX = 1;
  private static final int SOCK_STREAM = 1;
  private static final int PROTOCOL = 0;

  /**
   * Upper bound for the temporary arrays used when the caller supplies a non-zero offset. The
   * native {@code recv}/{@code send} mappings take a plain {@code byte[]} and therefore always
   * operate from index 0, so offset access has to go through a copy.
   */
  private static final int MAX_COPY_CHUNK = 16 * 1024;

  static {
    if (!Platform.isWindows() && !Platform.isWindowsCE()) {
      Native.register("c");
    }
  }

  private final AtomicBoolean closeLock = new AtomicBoolean();
  private final SockAddr sockaddr;
  private final int fd;
  private InputStream is;
  private OutputStream os;
  private boolean connected;

  /**
   * Constructor. Creates the native socket; the connection itself is established by {@link
   * #connect(SocketAddress, int)}.
   *
   * @param path unix path
   * @throws IOException if running on Windows or if the native {@code socket()} call fails
   */
  public UnixDomainSocket(String path) throws IOException {
    if (Platform.isWindows()) {
      throw new IOException("Unix domain sockets are not supported on Windows");
    }
    sockaddr = new SockAddr(path);
    try {
      fd = socket(AF_UNIX, SOCK_STREAM, PROTOCOL);
    } catch (LastErrorException lee) {
      throw new IOException("native socket() failed : " + formatError(lee), lee);
    }
  }

  /**
   * creates an endpoint for communication and returns a file descriptor that refers to that
   * endpoint. see https://man7.org/linux/man-pages/man2/socket.2.html
   *
   * @param domain domain
   * @param type type
   * @param protocol protocol
   * @return file descriptor
   * @throws LastErrorException if any error occurs
   */
  public static native int socket(int domain, int type, int protocol) throws LastErrorException;

  /**
   * Connect socket
   *
   * @param sockfd file descriptor
   * @param sockaddr socket address
   * @param addrlen address length
   * @return zero on success. -1 on error
   * @throws LastErrorException if error occurs
   */
  public static native int connect(int sockfd, SockAddr sockaddr, int addrlen)
      throws LastErrorException;

  /**
   * Receive a message from a socket. Data is always stored starting at index 0 of {@code buffer}.
   *
   * @param fd file descriptor
   * @param buffer buffer
   * @param count maximum number of bytes to receive
   * @param flags flag. see https://man7.org/linux/man-pages/man2/recvmsg.2.html
   * @return number of bytes received, 0 when the peer has performed an orderly shutdown, -1 on
   *     error
   * @throws LastErrorException if error occurs
   */
  public static native int recv(int fd, byte[] buffer, int count, int flags)
      throws LastErrorException;

  /**
   * Send a message to a socket. Data is always taken starting at index 0 of {@code buffer}.
   *
   * @param fd file descriptor
   * @param buffer buffer
   * @param count number of bytes to send
   * @param flags flag. see https://man7.org/linux/man-pages/man2/sendmsg.2.html
   * @return number of bytes actually sent (may be less than {@code count}), -1 on error
   * @throws LastErrorException if error occurs
   */
  public static native int send(int fd, byte[] buffer, int count, int flags)
      throws LastErrorException;

  /**
   * Close socket
   *
   * @param fd file descriptor
   * @return zero on success. -1 on error
   * @throws LastErrorException if error occurs
   */
  public static native int close(int fd) throws LastErrorException;

  /**
   * return a description of the error code passed in the argument errnum.
   *
   * @param errno error number
   * @return error description
   */
  public static native String strerror(int errno);

  /**
   * Builds a readable message for a native failure, falling back to the exception's own message
   * when {@code strerror} cannot be called.
   */
  private static String formatError(LastErrorException lee) {
    try {
      return strerror(lee.getErrorCode());
    } catch (Throwable ignored) {
      // Deliberately broad: a failing native lookup must not mask the original error.
      return lee.getMessage();
    }
  }

  @Override
  public boolean isConnected() {
    return connected;
  }

  /**
   * Closes the native descriptor. Only the first call has any effect.
   *
   * @throws IOException if the native {@code close()} call fails
   */
  @Override
  public void close() throws IOException {
    if (!closeLock.compareAndSet(false, true)) {
      return;
    }
    try {
      close(fd);
    } catch (LastErrorException lee) {
      throw new IOException("native close() failed : " + formatError(lee), lee);
    } finally {
      // The descriptor is unusable from here on whether or not close() reported an error.
      connected = false;
    }
  }

  /**
   * Connects to the path given at construction time and creates the I/O streams. On failure the
   * descriptor is closed before the exception is thrown.
   *
   * @param endpoint ignored
   * @param timeout ignored
   * @throws IOException if the native {@code connect()} call fails
   */
  @Override
  public void connect(SocketAddress endpoint, int timeout) throws IOException {
    int ret;
    try {
      ret = connect(fd, sockaddr, sockaddr.size());
    } catch (LastErrorException lee) {
      closeQuietly();
      throw new IOException("native connect() failed : " + formatError(lee), lee);
    }
    if (ret != 0) {
      // Read the error text before closing: close() may overwrite the last error.
      String reason = strerror(Native.getLastError());
      closeQuietly();
      throw new IOException(reason);
    }
    connected = true;
    is = new UnixSocketInputStream();
    os = new UnixSocketOutputStream();
  }

  /** Best-effort close used on failure paths where the primary error is reported instead. */
  private void closeQuietly() {
    try {
      close();
    } catch (IOException ignored) {
      // The caller is already throwing the error that matters.
    }
  }

  /**
   * Returns the input stream created by {@link #connect(SocketAddress, int)}.
   *
   * @return input stream, or {@code null} if the socket has not been connected
   */
  @Override
  public InputStream getInputStream() {
    return is;
  }

  /**
   * Returns the output stream created by {@link #connect(SocketAddress, int)}.
   *
   * @return output stream, or {@code null} if the socket has not been connected
   */
  @Override
  public OutputStream getOutputStream() {
    return os;
  }

  @Override
  public void setTcpNoDelay(boolean b) {
    // do nothing
  }

  @Override
  public void setKeepAlive(boolean b) {
    // do nothing
  }

  @Override
  public void setSoLinger(boolean b, int i) {
    // do nothing
  }

  @Override
  public void setSoTimeout(int timeout) {
    // do nothing
  }

  @Override
  public void shutdownInput() {
    // do nothing
  }

  @Override
  public void shutdownOutput() {
    // do nothing
  }

  /** Socket address */
  public static class SockAddr extends Structure {
    /** socket family */
    public short sun_family = AF_UNIX;

    /** pathname, NUL terminated */
    public byte[] sun_path;

    /**
     * Constructor.
     *
     * @param sunPath path
     */
    public SockAddr(String sunPath) {
      // NOTE(review): uses the platform default charset, as before — confirm this matches the
      // encoding the file system expects for non-ASCII paths.
      byte[] arr = sunPath.getBytes();
      // One extra byte, left at zero, terminates the C string.
      sun_path = new byte[arr.length + 1];
      System.arraycopy(arr, 0, sun_path, 0, arr.length);
      allocateMemory();
    }

    @Override
    protected java.util.List<String> getFieldOrder() {
      return Arrays.asList("sun_family", "sun_path");
    }
  }

  /** Input stream reading from the enclosing socket's descriptor through {@code recv}. */
  class UnixSocketInputStream extends InputStream {
    /**
     * Reads up to {@code len} bytes into {@code bytesEntry} starting at {@code off}.
     *
     * @param bytesEntry destination buffer
     * @param off index in {@code bytesEntry} of the first byte stored
     * @param len maximum number of bytes to read
     * @return number of bytes read, 0 if {@code len} is 0, or -1 once the peer has closed the
     *     connection
     * @throws IOException if the native call fails
     * @throws IndexOutOfBoundsException if {@code off}/{@code len} do not fit the buffer
     */
    @Override
    public int read(byte[] bytesEntry, int off, int len) throws IOException {
      if (off < 0 || len < 0 || len > bytesEntry.length - off) {
        throw new IndexOutOfBoundsException();
      }
      if (len == 0) {
        return 0;
      }
      // recv() can only fill an array from index 0, so a non-zero offset goes through a
      // bounded scratch array; returning fewer bytes than requested is allowed.
      boolean direct = off == 0;
      int wanted = direct ? len : Math.min(len, MAX_COPY_CHUNK);
      byte[] target = direct ? bytesEntry : new byte[wanted];
      int received;
      try {
        received = recv(fd, target, wanted, 0);
      } catch (LastErrorException lee) {
        throw new IOException("native read() failed : " + formatError(lee), lee);
      }
      if (received <= 0) {
        // recv() yields 0 on orderly shutdown; InputStream signals end of stream with -1.
        return -1;
      }
      if (!direct) {
        System.arraycopy(target, 0, bytesEntry, off, received);
      }
      return received;
    }

    /**
     * Reads a single byte.
     *
     * @return the byte as an unsigned value, or -1 at end of stream
     * @throws IOException if the native call fails
     */
    @Override
    public int read() throws IOException {
      byte[] single = new byte[1];
      int bytesRead = read(single, 0, 1);
      if (bytesRead <= 0) {
        return -1;
      }
      return single[0] & 0xff;
    }

    @Override
    public int read(byte[] bytes) throws IOException {
      return read(bytes, 0, bytes.length);
    }
  }

  /** Output stream writing to the enclosing socket's descriptor through {@code send}. */
  class UnixSocketOutputStream extends OutputStream {
    /**
     * Writes exactly {@code len} bytes of {@code bytesEntry} starting at {@code off}, issuing
     * further {@code send} calls when the kernel accepts only part of the data.
     *
     * @param bytesEntry source buffer
     * @param off index of the first byte to write
     * @param len number of bytes to write
     * @throws IOException if the native call fails or makes no progress
     * @throws IndexOutOfBoundsException if {@code off}/{@code len} do not fit the buffer
     */
    @Override
    public void write(byte[] bytesEntry, int off, int len) throws IOException {
      if (off < 0 || len < 0 || len > bytesEntry.length - off) {
        throw new IndexOutOfBoundsException();
      }
      int position = off;
      int remaining = len;
      while (remaining > 0) {
        byte[] chunk;
        int chunkLength;
        if (position == 0) {
          chunk = bytesEntry;
          chunkLength = remaining;
        } else {
          // send() always starts at index 0, so offset data is copied out in bounded pieces.
          chunkLength = Math.min(remaining, MAX_COPY_CHUNK);
          chunk = Arrays.copyOfRange(bytesEntry, position, position + chunkLength);
        }
        int sent;
        try {
          sent = send(fd, chunk, chunkLength, 0);
        } catch (LastErrorException lee) {
          throw new IOException("native write() failed : " + formatError(lee), lee);
        }
        if (sent <= 0) {
          // No progress: fail rather than loop forever.
          throw new IOException("can't write " + len + " bytes");
        }
        position += sent;
        remaining -= sent;
      }
    }

    /**
     * Writes the low-order byte of {@code value}.
     *
     * @param value byte to write
     * @throws IOException if the native call fails
     */
    @Override
    public void write(int value) throws IOException {
      write(new byte[] {(byte) value}, 0, 1);
    }
  }
}