blob: d85bb7d9283e74b62e394bed1751a52d322109fe [file] [log] [blame]
/*****************************************************************************\
* conn.c - connection API definitions
*****************************************************************************
* Copyright (C) SchedMD LLC.
*
* This file is part of Slurm, a resource management program.
* For details, see <https://slurm.schedmd.com/>.
* Please also read the included file: DISCLAIMER.
*
* Slurm is free software; you can redistribute it and/or modify it under
* the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 2 of the License, or (at your option)
* any later version.
*
* In addition, as a special exception, the copyright holders give permission
* to link the code of portions of this program with the OpenSSL library under
* certain conditions as described in each individual source file, and
* distribute linked combinations including the two. You must obey the GNU
* General Public License in all respects for all of the code used other than
* OpenSSL. If you modify file(s) with this exception, you may extend this
* exception to your version of the file(s), but you are not obligated to do
* so. If you do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source files in
* the program, then also delete it here.
*
* Slurm is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along
* with Slurm; if not, write to the Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
\*****************************************************************************/
#include <pthread.h>
#include <stdlib.h>
#include <string.h>
#include "src/common/fd.h"
#include "src/common/macros.h"
#include "src/common/plugin.h"
#include "src/common/plugrack.h"
#include "src/common/read_config.h"
#include "src/common/slurm_protocol_api.h"
#include "src/common/util-net.h"
#include "src/common/xassert.h"
#include "src/common/xmalloc.h"
#include "src/common/xstring.h"
#include "src/interfaces/certmgr.h"
#include "src/interfaces/conn.h"
/*
 * Table of entry points for the loaded plugin. conn_g_init() hands the
 * address of one instance of this struct to plugin_context_create() together
 * with syms[], so every member here corresponds, by position, to the symbol
 * name at the same index in syms[]. Do not reorder, insert or remove a member
 * without making the matching change to syms[].
 */
typedef struct {
/* "plugin_id": a pointer to a uint32_t object, not a function pointer */
uint32_t (*plugin_id);
/* "tls_p_load_ca_cert" */
int (*load_ca_cert)(char *cert_file);
/* "tls_p_get_own_public_cert" */
char *(*get_own_public_cert)(void);
/* "tls_p_load_own_cert" */
int (*load_own_cert)(char *cert, uint32_t cert_len, char *key,
uint32_t key_len);
/* "tls_p_load_self_signed_cert" */
int (*load_self_signed_cert)(void);
/* "tls_p_own_cert_loaded" */
bool (*own_cert_loaded)(void);
/* "tls_p_create_conn"; declared as returning void *, used as conn_t * */
void *(*create_conn)(const conn_args_t *conn_args);
/* "tls_p_destroy_conn" */
void (*destroy_conn)(conn_t *conn, bool close_fds);
/* "tls_p_send" */
ssize_t (*send)(conn_t *conn, const void *buf, size_t n);
/* "tls_p_sendv" */
ssize_t (*sendv)(conn_t *conn, const struct iovec *bufs, int count);
/* "tls_p_peek" */
uint32_t (*peek)(conn_t *conn);
/* "tls_p_recv" */
ssize_t (*recv)(conn_t *conn, void *buf, size_t n);
/* "tls_p_get_delay" */
timespec_t (*get_delay)(conn_t *conn);
/* "tls_p_negotiate_conn" (note: member name drops the _conn suffix) */
int (*negotiate)(conn_t *conn);
/* "tls_p_is_client_authenticated" */
bool (*is_client_authenticated)(conn_t *conn);
/* "tls_p_get_conn_fd" */
int (*get_conn_fd)(conn_t *conn);
/* "tls_p_set_conn_fds" */
int (*set_conn_fds)(conn_t *conn, int input_fd, int output_fd);
/* "tls_p_set_conn_callbacks" */
int (*set_conn_callbacks)(conn_t *conn, conn_callbacks_t *callbacks);
/* "tls_p_shutdown_conn" */
int (*shutdown_conn)(conn_t *conn);
} conn_ops_t;
/*
 * File-local alias for struct tls_conn_s. The struct itself is not defined in
 * this file, and nothing in the visible part of this file uses the alias.
 */
typedef struct tls_conn_s conn_s;
/*
 * Symbol names looked up in the plugin. conn_g_init() passes this array and
 * the conn_ops_t table to plugin_context_create() side by side, so these
 * strings must be kept in exactly the same order as the fields declared for
 * conn_ops_t: entry N here names the symbol for the Nth member there.
 */
static const char *syms[] = {
"plugin_id",
"tls_p_load_ca_cert",
"tls_p_get_own_public_cert",
"tls_p_load_own_cert",
"tls_p_load_self_signed_cert",
"tls_p_own_cert_loaded",
"tls_p_create_conn",
"tls_p_destroy_conn",
"tls_p_send",
"tls_p_sendv",
"tls_p_peek",
"tls_p_recv",
"tls_p_get_delay",
"tls_p_negotiate_conn",
"tls_p_is_client_authenticated",
"tls_p_get_conn_fd",
"tls_p_set_conn_fds",
"tls_p_set_conn_callbacks",
"tls_p_shutdown_conn",
};
/* Plugin entry points; its address is given to plugin_context_create() */
static conn_ops_t ops;
/* Plugin context: created in conn_g_init(), destroyed in conn_g_fini() */
static plugin_context_t *g_context = NULL;
/*
 * Interface state: PLUGIN_NOT_INITED until conn_g_init() succeeds in loading
 * a plugin (PLUGIN_INITED) or finds no configured type (PLUGIN_NOOP).
 */
static plugin_init_t plugin_inited = PLUGIN_NOT_INITED;
/* Write-locked by conn_g_init() and conn_g_fini() while they change state */
static pthread_rwlock_t context_lock = PTHREAD_RWLOCK_INITIALIZER;
/*
 * Set to true by conn_g_init() when slurm_conf.tls_type contains "s2n".
 * Nothing in this file sets it back to false.
 */
static bool tls_enabled_bool = false;
/*
 * Translate a connection mode into a printable name.
 * IN mode - connection mode to describe
 * RET pointer to a string literal ("null", "server" or "client"), or
 *     "INVALID" for any value not handled below. Never NULL; the caller
 *     must not modify or free it.
 */
extern char *conn_mode_to_str(conn_mode_t mode)
{
	/* Returned unchanged when no case below matches */
	char *name = "INVALID";

	switch (mode) {
	case CONN_NULL:
		name = "null";
		break;
	case CONN_SERVER:
		name = "server";
		break;
	case CONN_CLIENT:
		name = "client";
		break;
	}

	return name;
}
/*
 * Report the TLS-enabled flag recorded by conn_g_init().
 * RET true if conn_g_init() loaded a plugin while slurm_conf.tls_type
 *     contained "s2n", otherwise false
 */
extern bool conn_tls_enabled(void)
{
	const bool enabled = tls_enabled_bool;

	return enabled;
}
/*
 * Load the plugin named by slurm_conf.tls_type and, when TLS is enabled,
 * load the certificates this process needs.
 *
 * The whole function runs under the context_lock write lock. If the
 * interface has already left the PLUGIN_NOT_INITED state, the call does
 * nothing and returns SLURM_SUCCESS. If no type is configured, the state
 * becomes PLUGIN_NOOP and no plugin is loaded.
 *
 * NOTE: the state is set to PLUGIN_INITED before any certificate is loaded,
 * so a certificate failure returns SLURM_ERROR but leaves the plugin loaded;
 * a later call will then return SLURM_SUCCESS without retrying.
 *
 * RET SLURM_SUCCESS, or SLURM_ERROR if the plugin context could not be
 *     created or a required certificate could not be loaded
 */
extern int conn_g_init(void)
{
	int rc = SLURM_SUCCESS;
	char *plugin_type = "tls";
	char *tls_type = NULL;

	slurm_rwlock_wrlock(&context_lock);

	if (plugin_inited != PLUGIN_NOT_INITED)
		goto done;

	tls_type = xstrdup(slurm_conf.tls_type);
	if (!tls_type) {
		plugin_inited = PLUGIN_NOOP;
		goto done;
	}

	/* sizeof(syms) is the array size in bytes, not the element count */
	g_context = plugin_context_create(plugin_type, tls_type, (void **) &ops,
					  syms, sizeof(syms));
	if (!g_context) {
		error("cannot create %s context for %s",
		      plugin_type, tls_type);
		rc = SLURM_ERROR;
		plugin_inited = PLUGIN_NOT_INITED;
		goto done;
	}

	if (xstrstr(slurm_conf.tls_type, "s2n"))
		tls_enabled_bool = true;

	/*
	 * Must be set before the conn_g_load_*() calls below, which assert
	 * that the plugin is initialized.
	 */
	plugin_inited = PLUGIN_INITED;

	if (tls_enabled_bool) {
		/* Load CA cert now, wait until later in configless */
		if (slurm_conf.last_update && conn_g_load_ca_cert(NULL)) {
			error("Could not load trusted certificates for s2n");
			rc = SLURM_ERROR;
			goto done;
		}

		/* Load own cert from file */
		if ((running_in_slurmctld() || running_in_slurmdbd() ||
		     running_in_slurmrestd() || running_in_slurmd() ||
		     running_in_sackd()) &&
		    slurm_conf.last_update &&
		    conn_g_load_own_cert(NULL, 0, NULL, 0)) {
			error("Could not load own TLS certificate from file");
			rc = SLURM_ERROR;
			goto done;
		}

		/*
		 * Load self-signed certificate in client commands and
		 * slurmstepd. slurmstepd is not in the list of daemons that
		 * load a certificate from file above, so it has to be
		 * admitted here explicitly rather than excluded.
		 */
		if ((!running_in_daemon() || running_in_slurmstepd()) &&
		    conn_g_load_self_signed_cert()) {
			error("Could not load self-signed TLS certificate");
			rc = SLURM_ERROR;
			goto done;
		}
	}

done:
	xfree(tls_type);
	slurm_rwlock_unlock(&context_lock);
	return rc;
}
/*
 * Unload the plugin, if one is loaded, and return the interface to the
 * PLUGIN_NOT_INITED state. Runs under the context_lock write lock.
 * tls_enabled_bool is left as it was.
 * RET result of plugin_context_destroy(), or SLURM_SUCCESS if there was no
 *     context to destroy
 */
extern int conn_g_fini(void)
{
	int rc = SLURM_SUCCESS;

	slurm_rwlock_wrlock(&context_lock);

	if (g_context != NULL)
		rc = plugin_context_destroy(g_context);

	/* Reset unconditionally; harmless when no context existed */
	g_context = NULL;
	plugin_inited = PLUGIN_NOT_INITED;

	slurm_rwlock_unlock(&context_lock);

	return rc;
}
/*
 * Call the plugin's tls_p_load_ca_cert() entry point.
 * IN cert_file - handed to the plugin unchanged; conn_g_init() passes NULL
 * RET the plugin's return value (conn_g_init() treats non-zero as failure)
 */
extern int conn_g_load_ca_cert(char *cert_file)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.load_ca_cert(cert_file);
	return rc;
}
/*
 * Call the plugin's tls_p_get_own_public_cert() entry point.
 * RET the pointer returned by the plugin, passed through unchanged
 */
extern char *conn_g_get_own_public_cert(void)
{
	char *cert;

	xassert(plugin_inited == PLUGIN_INITED);

	cert = ops.get_own_public_cert();
	return cert;
}
/*
 * Call the plugin's tls_p_load_own_cert() entry point.
 * IN cert, cert_len, key, key_len - handed to the plugin unchanged;
 *    conn_g_init() passes (NULL, 0, NULL, 0) to load from file
 * RET the plugin's return value (conn_g_init() treats non-zero as failure)
 */
extern int conn_g_load_own_cert(char *cert, uint32_t cert_len, char *key,
				uint32_t key_len)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.load_own_cert(cert, cert_len, key, key_len);
	return rc;
}
/*
 * Call the plugin's tls_p_load_self_signed_cert() entry point.
 * RET the plugin's return value (conn_g_init() treats non-zero as failure)
 */
extern int conn_g_load_self_signed_cert(void)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.load_self_signed_cert();
	return rc;
}
/*
 * Call the plugin's tls_p_own_cert_loaded() entry point.
 * RET the plugin's answer, passed through unchanged
 */
extern bool conn_g_own_cert_loaded(void)
{
	bool loaded;

	xassert(plugin_inited == PLUGIN_INITED);

	loaded = ops.own_cert_loaded();
	return loaded;
}
extern conn_t *conn_g_create(const conn_args_t *conn_args)
{
xassert(plugin_inited == PLUGIN_INITED);
log_flag(NET, "%s: fd:%d->%d mode:%d",
__func__, conn_args->input_fd, conn_args->output_fd,
conn_args->mode);
return (*(ops.create_conn))(conn_args);
}
/*
 * Destroy a connection through the plugin's tls_p_destroy_conn() entry point.
 * A NULL conn is ignored, and is ignored before the plugin state is
 * asserted, so this is safe to call with NULL even when no plugin is loaded.
 * IN conn - connection to destroy, may be NULL
 * IN close_fds - handed to the plugin unchanged
 */
extern void conn_g_destroy(conn_t *conn, bool close_fds)
{
	if (conn) {
		xassert(plugin_inited == PLUGIN_INITED);
		ops.destroy_conn(conn, close_fds);
	}
}
/*
 * Call the plugin's tls_p_send() entry point.
 * IN conn, buf, n - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern ssize_t conn_g_send(conn_t *conn, const void *buf, size_t n)
{
	ssize_t sent;

	xassert(plugin_inited == PLUGIN_INITED);

	sent = ops.send(conn, buf, n);
	return sent;
}
/*
 * Call the plugin's tls_p_sendv() entry point.
 * IN conn, bufs, count - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern ssize_t conn_g_sendv(conn_t *conn, const struct iovec *bufs, int count)
{
	ssize_t sent;

	xassert(plugin_inited == PLUGIN_INITED);

	sent = ops.sendv(conn, bufs, count);
	return sent;
}
/*
 * Call the plugin's tls_p_peek() entry point.
 * IN conn - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern uint32_t conn_g_peek(conn_t *conn)
{
	uint32_t peeked;

	xassert(plugin_inited == PLUGIN_INITED);

	peeked = ops.peek(conn);
	return peeked;
}
/*
 * Call the plugin's tls_p_recv() entry point.
 * IN conn, buf, n - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern ssize_t conn_g_recv(conn_t *conn, void *buf, size_t n)
{
	ssize_t received;

	xassert(plugin_inited == PLUGIN_INITED);

	received = ops.recv(conn, buf, n);
	return received;
}
/*
 * Call the plugin's tls_p_get_delay() entry point.
 * IN conn - handed to the plugin unchanged
 * RET the timespec_t returned by the plugin
 */
extern timespec_t conn_g_get_delay(conn_t *conn)
{
	timespec_t delay;

	xassert(plugin_inited == PLUGIN_INITED);

	delay = ops.get_delay(conn);
	return delay;
}
/*
 * Call the plugin's tls_p_negotiate_conn() entry point.
 * IN conn - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern int conn_g_negotiate_tls(conn_t *conn)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.negotiate(conn);
	return rc;
}
/*
 * Call the plugin's tls_p_is_client_authenticated() entry point.
 * IN conn - handed to the plugin unchanged
 * RET the plugin's answer, passed through unchanged
 */
extern bool conn_g_is_client_authenticated(conn_t *conn)
{
	bool authenticated;

	xassert(plugin_inited == PLUGIN_INITED);

	authenticated = ops.is_client_authenticated(conn);
	return authenticated;
}
/*
 * Call the plugin's tls_p_get_conn_fd() entry point.
 * IN conn - handed to the plugin unchanged
 * RET the file descriptor value returned by the plugin
 */
extern int conn_g_get_fd(conn_t *conn)
{
	int fd;

	xassert(plugin_inited == PLUGIN_INITED);

	fd = ops.get_conn_fd(conn);
	return fd;
}
/*
 * Call the plugin's tls_p_set_conn_fds() entry point.
 * IN conn, input_fd, output_fd - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern int conn_g_set_fds(conn_t *conn, int input_fd, int output_fd)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.set_conn_fds(conn, input_fd, output_fd);
	return rc;
}
/*
 * Call the plugin's tls_p_set_conn_callbacks() entry point.
 * IN conn, callbacks - handed to the plugin unchanged
 * RET the plugin's return value
 */
extern int conn_g_set_callbacks(conn_t *conn, conn_callbacks_t *callbacks)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.set_conn_callbacks(conn, callbacks);
	return rc;
}
/*
 * Call the plugin's tls_p_shutdown_conn() entry point once, without waiting.
 * IN conn - handed to the plugin unchanged
 * RET the plugin's return value; conn_blocking_g_shutdown() shows that this
 *     may be SLURM_BLOCKED_ON_READ or SLURM_BLOCKED_ON_WRITE
 */
extern int conn_g_shutdown(conn_t *conn)
{
	int rc;

	xassert(plugin_inited == PLUGIN_INITED);

	rc = ops.shutdown_conn(conn);
	return rc;
}
/*
 * Shut a connection down, retrying until the plugin no longer reports that
 * it is blocked.
 *
 * The plugin's tls_p_shutdown_conn() entry point is called repeatedly. While
 * it returns SLURM_BLOCKED_ON_READ or SLURM_BLOCKED_ON_WRITE, this function
 * waits on the connection's file descriptor for POLLIN or POLLOUT
 * respectively (via wait_fd() with slurm_conf.msg_timeout) and tries again.
 *
 * A NULL conn is accepted and is checked before the plugin state is
 * asserted, matching conn_g_destroy(), so calling this with NULL is safe
 * even when no plugin is loaded.
 *
 * IN conn - connection to shut down, may be NULL
 * RET SLURM_SUCCESS (zero) when conn is NULL or shutdown completed, the
 *     plugin's return value for any other shutdown result, or the non-zero
 *     value returned by wait_fd() if waiting failed
 */
extern int conn_blocking_g_shutdown(conn_t *conn)
{
	int rc;
	int fd;

	if (!conn)
		return SLURM_SUCCESS;

	xassert(plugin_inited == PLUGIN_INITED);

	fd = conn_g_get_fd(conn);

	while ((rc = (*(ops.shutdown_conn))(conn))) {
		short int events;

		if (rc == SLURM_BLOCKED_ON_READ)
			events = POLLIN;
		else if (rc == SLURM_BLOCKED_ON_WRITE)
			events = POLLOUT;
		else
			return rc; /* real failure, not a would-block */

		/* Wait until it's possible to read/write, then try again */
		if ((rc = wait_fd(fd, slurm_conf.msg_timeout, events))) {
			error("%s: Failed to wait on fd %d",
			      __func__, fd);
			return rc;
		}
	}

	return rc;
}