// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2021 MariaDB Corporation Ab

package org.mariadb.jdbc.client.socket.impl;

import com.sun.jna.LastErrorException;
import com.sun.jna.Native;
import com.sun.jna.Platform;
import com.sun.jna.Structure;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;

/** Unix IPC socket */
/**
 * Unix domain (IPC) socket built on JNA direct-mapped libc calls ({@code socket}, {@code connect},
 * {@code recv}, {@code send}, {@code close}). Not supported on Windows.
 *
 * <p>Most {@link Socket} option setters are no-ops for this transport; the target path is fixed at
 * construction time.
 */
public class UnixDomainSocket extends Socket {

  private static final int AF_UNIX = 1;
  private static final int SOCK_STREAM = 1;
  private static final int PROTOCOL = 0;

  static {
    // Bind the native methods below to libc; skipped on Windows where there is no such library.
    if (!Platform.isWindows() && !Platform.isWindowsCE()) {
      Native.register("c");
    }
  }

  /** Set to true by the first {@link #close()} call so the descriptor is only closed once. */
  private final AtomicBoolean closeLock = new AtomicBoolean();

  private final SockAddr sockaddr;
  private final int fd;
  private InputStream is;
  private OutputStream os;
  private boolean connected;

  /**
   * Creates an unconnected Unix domain stream socket targeting {@code path}.
   *
   * @param path unix path
   * @throws IOException if running on Windows or if the native {@code socket()} call fails
   */
  public UnixDomainSocket(String path) throws IOException {
    if (Platform.isWindows()) {
      throw new IOException("Unix domain sockets are not supported on Windows");
    }
    sockaddr = new SockAddr(path);
    try {
      fd = socket(AF_UNIX, SOCK_STREAM, PROTOCOL);
    } catch (LastErrorException lee) {
      throw new IOException("native socket() failed : " + formatError(lee), lee);
    }
  }

  /**
   * creates an endpoint for communication and returns a file descriptor that refers to that
   * endpoint. see https://man7.org/linux/man-pages/man2/socket.2.html
   *
   * @param domain domain
   * @param type type
   * @param protocol protocol
   * @return file descriptor
   * @throws LastErrorException if any error occurs
   */
  public static native int socket(int domain, int type, int protocol) throws LastErrorException;

  /**
   * Connect socket
   *
   * @param sockfd file descriptor
   * @param sockaddr socket address
   * @param addrlen address length
   * @return zero on success. -1 on error
   * @throws LastErrorException if error occurs
   */
  public static native int connect(int sockfd, SockAddr sockaddr, int addrlen)
      throws LastErrorException;

  /**
   * Receive a message from a socket. Data is always written starting at index 0 of {@code
   * buffer}.
   *
   * @param fd file descriptor
   * @param buffer buffer
   * @param count maximum number of bytes to receive
   * @param flags flag. see https://man7.org/linux/man-pages/man2/recvmsg.2.html
   * @return number of bytes received; 0 when the peer has performed an orderly shutdown; -1 on
   *     error
   * @throws LastErrorException if error occurs
   */
  public static native int recv(int fd, byte[] buffer, int count, int flags)
      throws LastErrorException;

  /**
   * Send a message to a socket. Data is always read starting at index 0 of {@code buffer}.
   *
   * @param fd file descriptor
   * @param buffer buffer
   * @param count number of bytes to send
   * @param flags flag. see https://man7.org/linux/man-pages/man2/sendmsg.2.html
   * @return number of bytes actually sent (may be less than {@code count}); -1 on error
   * @throws LastErrorException if error occurs
   */
  public static native int send(int fd, byte[] buffer, int count, int flags)
      throws LastErrorException;

  /**
   * Close socket
   *
   * @param fd file descriptor
   * @return zero on success. -1 on error
   * @throws LastErrorException if error occurs
   */
  public static native int close(int fd) throws LastErrorException;

  /**
   * return a description of the error code passed in the argument errnum.
   *
   * @param errno error pointer
   * @return error description
   */
  public static native String strerror(int errno);

  /** Describes a native failure via strerror, falling back to the exception message. */
  private static String formatError(LastErrorException lee) {
    try {
      return strerror(lee.getErrorCode());
    } catch (Throwable t) {
      return lee.getMessage();
    }
  }

  /** Validates an (offset, length) slice of {@code bytes} per the java.io stream contracts. */
  private static void checkBounds(byte[] bytes, int off, int len) {
    if (off < 0 || len < 0 || len > bytes.length - off) {
      throw new IndexOutOfBoundsException(
          "off=" + off + ", len=" + len + ", length=" + bytes.length);
    }
  }

  @Override
  public boolean isConnected() {
    return connected;
  }

  /**
   * Closes the native descriptor. Only the first call has any effect; later calls return
   * immediately.
   *
   * @throws IOException if the native {@code close()} fails
   */
  @Override
  public void close() throws IOException {
    if (closeLock.getAndSet(true)) {
      return;
    }
    // Cleared up front: the descriptor cannot be reused once close has been attempted.
    connected = false;
    try {
      close(fd);
    } catch (LastErrorException lee) {
      throw new IOException("native close() failed : " + formatError(lee), lee);
    }
  }

  /** Closes this socket, ignoring any failure (used on error paths). */
  private void closeQuietly() {
    try {
      close();
    } catch (IOException ignored) {
      // best effort: the caller is already reporting a more relevant failure
    }
  }

  /**
   * Connects to the path given to the constructor and creates the socket streams. On failure the
   * socket is closed.
   *
   * @param endpoint ignored; the constructor path is used instead
   * @param timeout ignored
   * @throws IOException if the native {@code connect()} fails
   */
  @Override
  public void connect(SocketAddress endpoint, int timeout) throws IOException {
    int ret;
    try {
      ret = connect(fd, sockaddr, sockaddr.size());
    } catch (LastErrorException lee) {
      closeQuietly();
      throw new IOException("native connect() failed : " + formatError(lee), lee);
    }
    if (ret != 0) {
      // Also release the descriptor on this path so it does not leak.
      closeQuietly();
      throw new IOException(strerror(Native.getLastError()));
    }
    is = new UnixSocketInputStream();
    os = new UnixSocketOutputStream();
    connected = true;
  }

  /**
   * Returns the socket input stream.
   *
   * @return the input stream, or {@code null} if {@link #connect} has not completed successfully
   */
  @Override
  public InputStream getInputStream() {
    return is;
  }

  /**
   * Returns the socket output stream.
   *
   * @return the output stream, or {@code null} if {@link #connect} has not completed successfully
   */
  @Override
  public OutputStream getOutputStream() {
    return os;
  }

  @Override
  public void setTcpNoDelay(boolean b) {
    // do nothing
  }

  @Override
  public void setKeepAlive(boolean b) {
    // do nothing
  }

  @Override
  public void setSoLinger(boolean b, int i) {
    // do nothing
  }

  /** No-op: reads on this socket are not time-limited by this setting. */
  @Override
  public void setSoTimeout(int timeout) {
    // do nothing
  }

  @Override
  public void shutdownInput() {
    // do nothing
  }

  @Override
  public void shutdownOutput() {
    // do nothing
  }

  /** Socket address (layout mapped by JNA as a short family followed by the path bytes). */
  public static class SockAddr extends Structure {
    /** socket family */
    public short sun_family = AF_UNIX;
    /** NUL-terminated pathname */
    public byte[] sun_path;

    /**
     * Constructor.
     *
     * @param sunPath path, encoded with the platform default charset
     */
    public SockAddr(String sunPath) {
      byte[] arr = sunPath.getBytes();
      // One extra zero byte terminates the path.
      sun_path = Arrays.copyOf(arr, arr.length + 1);
      allocateMemory();
    }

    @Override
    protected java.util.List<String> getFieldOrder() {
      return Arrays.asList("sun_family", "sun_path");
    }
  }

  /** Input stream reading from the native descriptor via {@code recv}. */
  class UnixSocketInputStream extends InputStream {

    /**
     * Reads up to {@code len} bytes into {@code bytesEntry[off..off+len)}.
     *
     * @return number of bytes read, 0 if {@code len} is 0, or -1 at end of stream
     * @throws IOException if the native {@code recv()} fails
     */
    @Override
    public int read(byte[] bytesEntry, int off, int len) throws IOException {
      checkBounds(bytesEntry, off, len);
      if (len == 0) {
        return 0;
      }
      // recv() writes at the start of the array it is given, so a non-zero offset requires a
      // scratch buffer that is then copied into place.
      byte[] target = off == 0 ? bytesEntry : new byte[len];
      int received;
      try {
        received = recv(fd, target, len, 0);
      } catch (LastErrorException lee) {
        throw new IOException("native read() failed : " + formatError(lee), lee);
      }
      if (received < 0) {
        throw new IOException("native read() failed : " + strerror(Native.getLastError()));
      }
      if (received == 0) {
        // recv() reports orderly peer shutdown with 0; InputStream signals end of stream with -1.
        return -1;
      }
      if (target != bytesEntry) {
        System.arraycopy(target, 0, bytesEntry, off, received);
      }
      return received;
    }

    @Override
    public int read() throws IOException {
      byte[] single = new byte[1];
      int count = read(single, 0, 1);
      return count <= 0 ? -1 : single[0] & 0xff;
    }

    @Override
    public int read(byte[] bytes) throws IOException {
      return read(bytes, 0, bytes.length);
    }
  }

  /** Output stream writing to the native descriptor via {@code send}. */
  class UnixSocketOutputStream extends OutputStream {

    /**
     * Writes all of {@code bytesEntry[off..off+len)}, retrying until every byte has been sent.
     *
     * @throws IOException if the native {@code send()} fails or makes no progress
     */
    @Override
    public void write(byte[] bytesEntry, int off, int len) throws IOException {
      checkBounds(bytesEntry, off, len);
      int written = 0;
      while (written < len) {
        int start = off + written;
        int remaining = len - written;
        // send() reads from the start of the array it is given, so any non-zero position needs
        // a copy of the pending slice.
        byte[] chunk =
            start == 0 ? bytesEntry : Arrays.copyOfRange(bytesEntry, start, start + remaining);
        int sent;
        try {
          sent = send(fd, chunk, remaining, 0);
        } catch (LastErrorException lee) {
          throw new IOException("native write() failed : " + formatError(lee), lee);
        }
        if (sent <= 0) {
          throw new IOException("can't write " + len + " bytes");
        }
        written += sent;
      }
    }

    /** Writes the low-order 8 bits of {@code value}. */
    @Override
    public void write(int value) throws IOException {
      write(new byte[] {(byte) value}, 0, 1);
    }
  }
}
