/*****************************************************************************\
 *  rate_limit.c
 *****************************************************************************
 *  Copyright (C) SchedMD LLC.
 *
 *  This file is part of Slurm, a resource management program.
 *  For details, see <https://slurm.schedmd.com/>.
 *  Please also read the included file: DISCLAIMER.
 *
 *  Slurm is free software; you can redistribute it and/or modify it under
 *  the terms of the GNU General Public License as published by the Free
 *  Software Foundation; either version 2 of the License, or (at your option)
 *  any later version.
 *
 *  In addition, as a special exception, the copyright holders give permission
 *  to link the code of portions of this program with the OpenSSL library under
 *  certain conditions as described in each individual source file, and
 *  distribute linked combinations including the two. You must obey the GNU
 *  General Public License in all respects for all of the code used other than
 *  OpenSSL. If you modify file(s) with this exception, you may extend this
 *  exception to your version of the file(s), but you are not obligated to do
 *  so. If you do not wish to do so, delete this exception statement from your
 *  version.  If you delete this exception statement from all source files in
 *  the program, then also delete it here.
 *
 *  Slurm is distributed in the hope that it will be useful, but WITHOUT ANY
 *  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 *  FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 *  details.
 *
 *  You should have received a copy of the GNU General Public License along
 *  with Slurm; if not, write to the Free Software Foundation, Inc.,
 *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301  USA.
\*****************************************************************************/

#include <errno.h>
#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

#include "src/common/macros.h"
#include "src/common/parse_value.h"
#include "src/common/slurm_protocol_defs.h"
#include "src/common/xstring.h"

#include "src/interfaces/conn.h"

#include "src/slurmctld/proc_req.h"
#include "src/slurmctld/slurmctld.h"

/*
 * last_update is scaled by refill_period, and is not the direct unix time
 */
typedef struct {
	/* time of the last refill, counted in refill_period units */
	time_t last_update;
	/* unix time the "rate limit exceeded" message was last logged */
	time_t last_logged;
	/* tokens currently remaining in this user's bucket */
	uint32_t tokens;
	/* user owning this slot; 0 marks the slot as unused */
	uid_t uid;
} user_bucket_t;

/* number of slots in user_buckets (rl_table_size=) */
static uint32_t table_size = 8192;
/*
 * Per-user bucket table, indexed by uid % table_size with linear probing.
 * Allocated by rate_limit_init(), released by rate_limit_shutdown().
 */
static user_bucket_t *user_buckets = NULL;

/* set once rate_limit_init() has found rl_enable and built the table */
static bool rate_limit_enabled = false;
/* taken around every access to user_buckets in rate_limit_exceeded() */
static pthread_mutex_t rate_limit_mutex = PTHREAD_MUTEX_INITIALIZER;

/* 30 tokens max, bucket refills 2 tokens per 1 second */
static uint32_t bucket_size = 30;
/*
 * Minimum number of seconds between "rate limit exceeded" log messages for
 * one user (rl_log_freq=). 0 logs every time, -1 disables the message.
 */
static int log_freq = 0;
/* tokens added per refill period (rl_refill_rate=) */
static uint32_t refill_rate = 2;
/* length of one refill period in seconds (rl_refill_period=) */
static uint32_t refill_period = 1;

/*
 * Parse a rate limit option whose value must be a non-zero unsigned integer.
 *
 * IN/OUT key - option name as searched for in SlurmctldParameters, normally
 *	ending in '=' (e.g. "rl_table_size="). A trailing '=' is stripped in
 *	place so the bare name can be handed to the parser and used in the
 *	error message.
 * IN value_str - text of the value to parse
 * RET the parsed value. Calls fatal() when s_p_handle_uint32() reports a
 *	non-zero result or the value is zero.
 */
static uint32_t _set_positive_rl_param(char *key, char *value_str)
{
	uint32_t value = 0;
	size_t len = strlen(key);

	/*
	 * take off '=' char; never index before the buffer on an empty key
	 * and never chop a real character off a key given without '='
	 */
	if (len && (key[len - 1] == '='))
		key[len - 1] = '\0';

	if (s_p_handle_uint32(&value, key, value_str) || !value) {
		fatal("%s=%s is invalid, must be a positive non-zero integer",
		      key, value_str);
	}
	return value;
}

extern void rate_limit_init(void)
{
	char rl_table_size_str[] = "rl_table_size=";
	char rl_bucket_size_str[] = "rl_bucket_size=";
	char rl_log_freq_str[] = "rl_log_freq=";
	char rl_refill_rate_str[] = "rl_refill_rate=";
	char rl_refill_period_str[] = "rl_refill_period=";
	char *tmp_ptr;

	if (!xstrcasestr(slurm_conf.slurmctld_params, "rl_enable"))
		return;

	if ((tmp_ptr = conf_get_opt_str(slurm_conf.slurmctld_params,
					rl_table_size_str))) {
		table_size = _set_positive_rl_param(rl_table_size_str, tmp_ptr);
		xfree(tmp_ptr);
	}

	if ((tmp_ptr = conf_get_opt_str(slurm_conf.slurmctld_params,
					rl_bucket_size_str))) {
		bucket_size =
			_set_positive_rl_param(rl_bucket_size_str, tmp_ptr);
		xfree(tmp_ptr);
	}

	if ((tmp_ptr = conf_get_opt_str(slurm_conf.slurmctld_params,
					rl_log_freq_str))) {
		log_freq = atoi(tmp_ptr);
		if (log_freq < -1) {
			fatal("%s%s is invalid, must be -1, 0, or a positive integer",
			      rl_log_freq_str, tmp_ptr);
		}
		xfree(tmp_ptr);
	}

	if ((tmp_ptr = conf_get_opt_str(slurm_conf.slurmctld_params,
					rl_refill_rate_str))) {
		refill_rate =
			_set_positive_rl_param(rl_refill_rate_str, tmp_ptr);
		xfree(tmp_ptr);
	}

	if ((tmp_ptr = conf_get_opt_str(slurm_conf.slurmctld_params,
					rl_refill_period_str))) {
		refill_period =
			_set_positive_rl_param(rl_refill_period_str, tmp_ptr);
		xfree(tmp_ptr);
	}

	rate_limit_enabled = true;
	user_buckets = xcalloc(table_size, sizeof(user_bucket_t));

	info("RPC rate limiting enabled");
	info("%s: rl_table_size=%u,rl_bucket_size=%u,rl_refill_rate=%u,rl_refill_period=%u",
	     __func__, table_size, bucket_size, refill_rate, refill_period);
}

/*
 * Turn rate limiting off and release the per-user bucket table.
 *
 * Both steps happen under rate_limit_mutex, the lock rate_limit_exceeded()
 * holds while it works on the table.
 */
extern void rate_limit_shutdown(void)
{
	slurm_mutex_lock(&rate_limit_mutex);
	/* disable first so new callers return early without taking the lock */
	rate_limit_enabled = false;
	/*
	 * rate_limit_exceeded() tests user_buckets for NULL after locking, so
	 * this presumably leaves the pointer NULL - confirm against xfree().
	 */
	xfree(user_buckets);
	slurm_mutex_unlock(&rate_limit_mutex);
}

/*
 * Charge one token to the sending user's bucket and report whether that user
 * is over the RPC rate limit.
 *
 * IN msg - incoming RPC. msg_type and auth_uid select the exemption and the
 *	bucket. When the "limit exceeded" message is logged and msg->address
 *	is still AF_UNSPEC, it is filled in from the connection's peer.
 * RET true if the limit's been exceeded, or if the table has already been
 *	released by rate_limit_shutdown(). False otherwise, including when
 *	rate limiting is disabled, the RPC or user is exempt, or the table is
 *	full.
 */
extern bool rate_limit_exceeded(slurm_msg_t *msg)
{
	slurmctld_rpc_t *this_rpc = NULL;
	bool exceeded = false, log_exceeded = false;
	uint32_t start_position = 0, position = 0;
	time_t now;

	if (!rate_limit_enabled)
		return false;

	if ((this_rpc = find_rpc(msg->msg_type)) && this_rpc->rl_exempt)
		return false;

	/*
	 * Exempt SlurmUser / root. Subjecting internal cluster traffic to
	 * the rate limit would break things really quickly. :)
	 * (We're assuming SlurmdUser is root here.)
	 */
	if (validate_slurm_user(msg->auth_uid))
		return false;

	slurm_mutex_lock(&rate_limit_mutex);
	/* This can happen if we already executed rate_limit_shutdown */
	if (!user_buckets) {
		slurm_mutex_unlock(&rate_limit_mutex);
		return true;
	}
	now = time(NULL);

	/*
	 * Scan for position. Note that uid 0 indicates an unused slot,
	 * since root is never subjected to the rate limit.
	 * Naively hash the uid into the table. If that's not a match, keep
	 * scanning for the next vacant spot. Wrap around to the front if
	 * necessary once we hit the end.
	 */
	start_position = position = msg->auth_uid % table_size;
	while ((user_buckets[position].uid) &&
	       (user_buckets[position].uid != msg->auth_uid)) {
		position++;
		if (position == table_size)
			position = 0;
		if (position == start_position) {
			/* full circle: table_size doubles as "not found" */
			position = table_size;
			break;
		}
	}

	if (position == table_size) {
		/*
		 * Avoid the temptation to resize the table... you'd need to
		 * rehash all the contents which would be annoying and slow.
		 */
		error("RPC Rate Limiting: ran out of user table space. User will not be limited.");
	} else if (!user_buckets[position].uid) {
		/* first RPC from this user: full bucket minus this request */
		user_buckets[position].uid = msg->auth_uid;
		user_buckets[position].last_update = now / refill_period;
		user_buckets[position].tokens = bucket_size - 1;
		debug3("%s: new entry for uid %u", __func__, msg->auth_uid);
	} else {
		time_t now_periods = now / refill_period;
		time_t delta = now_periods - user_buckets[position].last_update;
		user_buckets[position].last_update = now_periods;

		/*
		 * add tokens
		 *
		 * Done in 64 bits and clamped before storing, so a long idle
		 * time or large refill_rate cannot wrap the 32-bit counter.
		 * A negative delta (clock stepped backwards) refills nothing.
		 */
		if (delta > 0) {
			if ((uint64_t) delta >= bucket_size) {
				/* refill_rate >= 1, so the bucket is full */
				user_buckets[position].tokens = bucket_size;
			} else {
				uint64_t refilled =
					user_buckets[position].tokens +
					((uint64_t) delta * refill_rate);

				if (refilled > bucket_size)
					refilled = bucket_size;
				user_buckets[position].tokens =
					(uint32_t) refilled;
			}
		}

		if (user_buckets[position].tokens)
			user_buckets[position].tokens--;
		else
			exceeded = true;

		/*
		 * Decide on logging while the lock is still held. Once it is
		 * dropped, the table may be freed by rate_limit_shutdown() or
		 * touched by another RPC thread.
		 */
		if (exceeded && (log_freq != -1) &&
		    ((user_buckets[position].last_logged + log_freq) <= now)) {
			user_buckets[position].last_logged = now;
			log_exceeded = true;
		}

		debug3("%s: found uid %u at position %u remaining tokens %u%s",
		       __func__, msg->auth_uid, position,
		       user_buckets[position].tokens,
		       (exceeded ? " rate limit exceeded" : ""));
	}
	slurm_mutex_unlock(&rate_limit_mutex);

	/* the peer lookup and log write only use msg, so no lock is needed */
	if (log_exceeded) {
		slurm_addr_t *cli_addr = &msg->address;

		if (cli_addr->ss_family == AF_UNSPEC) {
			int fd = conn_g_get_fd(msg->conn);
			(void) slurm_get_peer_addr(fd, cli_addr);
		}

		info("RPC rate limit exceeded by uid %u with %s from %pA, telling to back off",
		     msg->auth_uid, rpc_num2string(msg->msg_type), cli_addr);
	}

	return exceeded;
}
