#include <common.h>
#include <command.h>
#include <dm.h>
#include <mapmem.h>
#include <tee/ta_helper.h>
#include <tee/ta_hello_world.h>
#include <linux/kernel.h>

/*
 * Probe and list all compliant TEE devices.
 *
 * Walks every device of UCLASS_TEE and prints one "$name@$instance" line per
 * device, where $instance is what map_to_sysmem() yields for the device
 * pointer.
 *
 * Return: the result of the last uclass_first_device() /
 * uclass_next_device() call.
 */
static int
do_tee_devices(cmd_tbl_t *cmdtp, int flag, int argc, char *const argv[])
{
	struct udevice *dev = NULL;
	int rc;

	for (rc = uclass_first_device(UCLASS_TEE, &dev); dev;
	     rc = uclass_next_device(&dev)) {
		/*
		 * $name@$instance
		 *
		 * map_to_sysmem() returns a phys_addr_t, whose width differs
		 * between platforms; widen it explicitly so the argument
		 * always matches the %llx conversion.
		 */
		printf("%s@%llx\n", dev->name,
		       (unsigned long long)map_to_sysmem(dev));
	}

	return rc;
}

/*
 * Ask the hello world TA to increment the value 16 and check the answer.
 *
 * A single VALUE_INOUT parameter carries the number to the TA and the result
 * back.
 *
 * Return: 0 if the TA hands back 17, the ta_call() error code if the call
 * fails, or -EINVAL if any other value comes back.
 */
static int do_tee_hello(cmd_tbl_t *cmdtp, int flag, int argc,
			char *const argv[])
{
	struct tee_optee_ta_uuid ta_uuid = TA_HELLO_WORLD_UUID;
	struct tee_param arg = {0};
	int err;

	arg.attr = TEE_PARAM_ATTR_TYPE_VALUE_INOUT;
	arg.u.value.a = 16;

	printf("Invoking TA to increment value 16.\n");
	err = ta_call(&ta_uuid, TA_HELLO_WORLD_CMD_INC_VALUE, 1, &arg);
	if (err)
		return err;

	printf("TA incremented my value to %lld.\n", arg.u.value.a);

	/* Anything other than 16 + 1 means the TA misbehaved. */
	return arg.u.value.a == 17 ? 0 : -EINVAL;
}

/*
 * Send a message to the hello world TA and print what it answers.
 *
 * The message is argv[1] when given, otherwise a built-in greeting. It is
 * handed to the TA (including its NUL terminator) through a shared-memory
 * input buffer in parameter slot 0; the answer is collected from a 128-byte
 * shared-memory output buffer in parameter slot 3. Slots 1 and 2 stay zeroed.
 *
 * Return: 0 on success, the error code of the failing ta_open() /
 * tee_shm_alloc() / ta_invoke() call, or -ERANGE if the reply size reported
 * in slot 3 is zero or larger than the output buffer.
 */
static int do_tee_echo(cmd_tbl_t *cmdtp, int flag, int argc,
			char *const argv[])
{
	struct tee_optee_ta_uuid ta_uuid = TA_HELLO_WORLD_UUID;
	struct ta_context ctx = {0};
	struct tee_param args[4] = {0};
	struct tee_shm *in_shm = NULL;
	struct tee_shm *out_shm = NULL;
	const size_t out_cap = 128;
	char *text = argc > 1 ? argv[1] : "Hello from U-Boot!";
	size_t text_len = strlen(text) + 1;
	char *reply;
	int err;

	/* Open a session with the TA; nothing to clean up if this fails. */
	err = ta_open(&ta_uuid, &ctx);
	if (err)
		return err;

	/* Slot 0: the outgoing message, copied into shared memory. */
	err = tee_shm_alloc(ctx.tee, text_len, TEE_SHM_ALLOC, &in_shm);
	if (err)
		goto cleanup;

	memcpy(in_shm->addr, text, text_len);

	args[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	args[0].u.memref.shm = in_shm;
	args[0].u.memref.size = text_len;

	/* Slot 3: room for the TA's reply. */
	err = tee_shm_alloc(ctx.tee, out_cap, TEE_SHM_ALLOC, &out_shm);
	if (err)
		goto cleanup;

	args[3].attr = TEE_PARAM_ATTR_TYPE_MEMREF_OUTPUT;
	args[3].u.memref.shm = out_shm;
	args[3].u.memref.size = out_cap;

	/* Run TA_HELLO_WORLD_CMD_ECHO. */
	printf("Sending \"%s\" to the hello world TA.\n", text);
	err = ta_invoke(&ctx, TA_HELLO_WORLD_CMD_ECHO, ARRAY_SIZE(args), args);
	if (err)
		goto cleanup;

	if (args[3].u.memref.size != 0 && args[3].u.memref.size <= out_cap) {
		/* Force termination before treating the reply as a string. */
		reply = (char *)out_shm->addr;
		reply[args[3].u.memref.size - 1] = 0;
		printf("TA said: \"%s\"\n", reply);
	} else {
		err = -ERANGE;
	}

cleanup:
	tee_shm_free(in_shm);
	tee_shm_free(out_shm);
	ta_close(&ctx);
	return err;
}

static int write_keyval(const char *key, const void *buf, size_t bytes)
{
	int rc;
	struct ta_context context = {0};
	struct tee_optee_ta_uuid uuid = TA_HELLO_WORLD_UUID;
	struct tee_param params[2] = {0};

	struct tee_shm *shm_key = NULL;
	size_t key_size = strlen(key) + 1;
	struct tee_shm *shm_buf = NULL;

	/* Connect to TA. */
	rc = ta_open(&uuid, &context);
	if (rc)
		return rc;

	/* Pass in @key via shared memory. */
	rc = tee_shm_alloc(context.tee, key_size, TEE_SHM_ALLOC, &shm_key);
	if (rc)
		goto out;

	memcpy(shm_key->addr, key, key_size);
	params[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	params[0].u.memref.shm = shm_key;
	params[0].u.memref.size = key_size;

	/* Pass in buf via shared memory. */
	rc = tee_shm_alloc(context.tee, bytes, TEE_SHM_ALLOC, &shm_buf);
	if (rc)
		goto out;

	memcpy(shm_buf->addr, buf, bytes);
	params[1].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	params[1].u.memref.shm = shm_buf;
	params[1].u.memref.size = bytes;

	/* Invoke the TA. */
	rc = ta_invoke(&context, TA_HELLO_WORLD_CMD_WRITE_PERSISTENT_VALUE,
			ARRAY_SIZE(params), params);
	if (rc)
		goto out;

out:
	tee_shm_free(shm_key);
	tee_shm_free(shm_buf);
	ta_close(&context);
	return rc;
}

/*
 * Fetch the value stored under @key via the hello world TA.
 *
 * The key (with its NUL terminator) goes to the TA as a MEMREF_INPUT
 * parameter; a @bytes-sized shared-memory buffer is passed as MEMREF_INOUT to
 * receive the value. After the invoke, the size found in that parameter tells
 * how many bytes are copied into @buf.
 *
 * @key:   NUL-terminated key string
 * @buf:   destination for the value
 * @bytes: capacity of @buf
 *
 * Return: 0 on success, the error code of the first failing ta_open() /
 * tee_shm_alloc() / ta_invoke() call, or -ERANGE if the reported size is zero
 * or exceeds @bytes. Bytes of @buf beyond the reported size are left
 * untouched.
 */
static int read_keyval(const char *key, void *buf, size_t bytes)
{
	int rc;
	struct ta_context context = {0};
	struct tee_optee_ta_uuid uuid = TA_HELLO_WORLD_UUID;
	struct tee_param params[2] = {0};

	struct tee_shm *shm_key = NULL;
	size_t key_size = strlen(key) + 1;
	struct tee_shm *shm_buf = NULL;

	/* Connect to TA. */
	rc = ta_open(&uuid, &context);
	if (rc)
		return rc;

	/* Pass in @key via shared memory. */
	rc = tee_shm_alloc(context.tee, key_size, TEE_SHM_ALLOC, &shm_key);
	if (rc)
		goto out;

	memcpy(shm_key->addr, key, key_size);
	params[0].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INPUT;
	params[0].u.memref.shm = shm_key;
	params[0].u.memref.size = key_size;

	/* Pass in buf via shared memory. */
	rc = tee_shm_alloc(context.tee, bytes, TEE_SHM_ALLOC, &shm_buf);
	if (rc)
		goto out;

	params[1].attr = TEE_PARAM_ATTR_TYPE_MEMREF_INOUT;
	params[1].u.memref.shm = shm_buf;
	params[1].u.memref.size = bytes;

	/* Invoke the TA. */
	rc = ta_invoke(&context, TA_HELLO_WORLD_CMD_READ_PERSISTENT_VALUE,
			ARRAY_SIZE(params), params);
	if (rc)
		goto out;

	/*
	 * Reject a size we cannot safely copy. This must go through the
	 * common exit path: returning directly would leak both shared-memory
	 * buffers and leave the TA session open.
	 */
	if (params[1].u.memref.size == 0 || params[1].u.memref.size > bytes) {
		rc = -ERANGE;
		goto out;
	}

	memcpy(buf, params[1].u.memref.shm->addr, params[1].u.memref.size);

out:
	tee_shm_free(shm_key);
	tee_shm_free(shm_buf);
	ta_close(&context);
	return rc;
}

/*
 * "tee set <key> <value>": store a string under a key through the TA.
 *
 * The value is written together with its NUL terminator.
 *
 * Return: -EINVAL if the key or value argument is missing or empty,
 * otherwise whatever write_keyval() returns.
 */
static int do_tee_set(cmd_tbl_t *cmdtp, int flag, int argc,
			char *const argv[])
{
	const char *name;
	const char *text;

	if (argc < 3)
		return -EINVAL;

	name = argv[1];
	text = argv[2];
	if (name[0] == '\0' || text[0] == '\0')
		return -EINVAL;

	return write_keyval(name, text, strlen(text) + 1);
}

/*
 * "tee get <key>": read the value stored under a key and print it in quotes.
 *
 * At most 63 characters of the value are shown; the local buffer holds 64
 * bytes and its last byte is always forced to NUL.
 *
 * Return: -EINVAL if the key argument is missing or empty, the read_keyval()
 * error code if the read fails, 0 otherwise.
 */
static int do_tee_get(cmd_tbl_t *cmdtp, int flag, int argc,
			char *const argv[])
{
	if (argc < 2 || !argv[1][0])
		return -EINVAL;

	const char *key = argv[1];
	/*
	 * Start from an all-zero buffer: read_keyval() may fill only part of
	 * it, and a stored value without a NUL terminator would otherwise make
	 * printf() run into uninitialised stack bytes.
	 */
	char value[64] = {0};
	int rc = read_keyval(key, value, sizeof(value));
	if (rc)
		return rc;

	/* Covers the case where the value fills the whole buffer. */
	value[sizeof(value) - 1] = 0;
	printf("\"%s\"\n", value);

	return 0;
}

/*
 * Sub-commands understood by "tee".
 *
 * Entry layout: name, maxargs, repeatable, handler, usage, help. maxargs
 * counts the sub-command name itself, since do_tee() strips only the leading
 * "tee" before comparing. No entry is repeatable, and the per-entry usage and
 * help strings are left empty.
 */
static cmd_tbl_t tee_commands[] = {
	U_BOOT_CMD_MKENT(devices, 1, 0, do_tee_devices, "", ""),
	U_BOOT_CMD_MKENT(hello,   1, 0, do_tee_hello,   "", ""),
	U_BOOT_CMD_MKENT(echo,    2, 0, do_tee_echo,    "", ""),
	U_BOOT_CMD_MKENT(set,     3, 0, do_tee_set,     "", ""),
	U_BOOT_CMD_MKENT(get,     2, 0, do_tee_get,     "", ""),
};

/*
 * Entry point of the "tee" command: dispatch to the matching sub-command.
 *
 * The leading "tee" is dropped so the sub-command sees its own name as
 * argv[0]. The handler's result is passed through cmd_process_error().
 *
 * Return: CMD_RET_USAGE if no sub-command is given, it is not found in
 * tee_commands, or it has more arguments than its maxargs; otherwise the
 * value produced by cmd_process_error().
 */
static int do_tee(cmd_tbl_t *cmdtp, int flag, int argc, char *const argv[])
{
	cmd_tbl_t *sub;
	char *const *sub_argv;
	int sub_argc;
	int result;

	if (argc < 2)
		return CMD_RET_USAGE;

	/* Skip "tee" itself. */
	sub_argc = argc - 1;
	sub_argv = argv + 1;

	sub = find_cmd_tbl(sub_argv[0], tee_commands,
			   ARRAY_SIZE(tee_commands));
	if (!sub)
		return CMD_RET_USAGE;
	if (sub_argc > sub->maxargs)
		return CMD_RET_USAGE;

	result = sub->cmd(sub, flag, sub_argc, sub_argv);

	return cmd_process_error(cmdtp, result);
}

/*
 * Register the top-level "tee" command, handled by do_tee(). It is not
 * repeatable, and maxargs is 4 to fit the longest form,
 * "tee set <key> <value>".
 */
U_BOOT_CMD(tee, 4, 0, do_tee,
	   "Trusted Execution Environment operations.",
	   "devices - Probes TEE devices.\n"
	   "tee hello - Pings the hello world TA.\n"
	   "tee echo [msg] - Sends a message to the hello world TA.\n"
	   "tee set <key> <value> - Sets a string value to a key in the hello "
	   "world TA's secure storage.\n"
	   "tee get <key> - Gets the value to a key in the hello world TA's "
	   "secure storage.\n"
	   );
