/*
 * Copyright (c) 2019 The Fuchsia Authors
 *
 * SPDX-License-Identifier:	BSD-3-Clause
 */

#include <common.h>

#include <asm/arch/efuse.h>
#include <asm/arch/secure_apb.h>
#include <asm/io.h>
#include <libavb/libavb.h>
#include <libavb_atx/libavb_atx.h>
#include <tee/ta_vx_helper.h>
#include <zbi/zbi.h>
#include <zircon_uboot/partition.h>
#include <zircon_uboot/vboot.h>
#include <zircon_uboot/zircon.h>

/* Capacity of vboot_context_t.key_versions: the maximum number of key versions
 * set_key_version() can record for one verification; a further call aborts. */
#define AVB_ATX_NUM_KEY_VERSIONS 2

/* By convention, when a rollback index is not used, the value remains zero.
 * Such locations are skipped when rollback indexes are written back. */
#define ROLLBACK_INDEX_NOT_USED (0)

/* Defined outside this file. Installed as the libavb_atx permanent-attribute
 * callbacks in both the normal and the RAM-boot AvbAtxOps tables. */
extern AvbIOResult
avb_read_permanent_attributes(AvbAtxOps *atx_ops,
			      AvbAtxPermanentAttributes *attributes);

extern AvbIOResult
avb_read_permanent_attributes_hash(AvbAtxOps *atx_ops,
				   uint8_t hash[AVB_SHA256_DIGEST_SIZE]);

/* Per-call state for zircon_vboot_slot_verify(), handed to the libavb
 * callbacks through AvbOps.user_data. */
typedef struct {
	/* Key versions reported through set_key_version(): each entry is a
	 * rollback index location and the value to store there once the slot
	 * has successfully booted. */
	struct {
		size_t location;
		uint64_t value;
	} key_versions[AVB_ATX_NUM_KEY_VERSIONS];
	/* Number of key_versions entries recorded so far. */
	size_t next_key_version_index;

	/* Image already in memory; get_preloaded_partition() returns it for
	 * zircon partitions instead of letting libavb read storage. */
	uint8_t *preloaded_img_addr;
	uint64_t preloaded_img_size;
} vboot_context_t;

/* Turns a possibly negative |offset| into a position within an area of |size|
 * bytes: non-negative offsets pass through, negative ones count back from the
 * end (e.g. -64 on a 4096-byte area yields 4032). */
static inline int64_t calc_offset(uint64_t size, int64_t offset)
{
	if (offset >= 0) {
		return offset;
	}
	/* Same unsigned arithmetic as the implicit conversion in size + offset. */
	return (int64_t)(size + (uint64_t)offset);
}

/* AvbOps.read_from_partition: reads up to |num_bytes| bytes of |partition|
 * starting at |offset| into |buffer|. A negative |offset| counts back from the
 * end of the partition. Reads extending past the end are truncated; the number
 * of bytes actually read is stored in |out_num_read| on success.
 *
 * Returns AVB_IO_RESULT_ERROR_NO_SUCH_PARTITION for an unknown partition,
 * AVB_IO_RESULT_ERROR_RANGE_OUTSIDE_PARTITION when |offset| lies outside it,
 * and AVB_IO_RESULT_ERROR_IO when the underlying read fails. */
static AvbIOResult read_from_partition(AvbOps *ops, const char *partition,
				       int64_t offset, size_t num_bytes,
				       void *buffer, size_t *out_num_read)
{
	AvbIOResult ret = AVB_IO_RESULT_OK;
	uint64_t remaining;
	int64_t abs_offset;
	zircon_partition *part = zircon_get_partition(partition);

	if (!part) {
		return AVB_IO_RESULT_ERROR_NO_SUCH_PARTITION;
	}

	abs_offset = calc_offset(part->size, offset);
	if (abs_offset < 0 || (uint64_t)abs_offset > part->size) {
		ret = AVB_IO_RESULT_ERROR_RANGE_OUTSIDE_PARTITION;
		goto out;
	}

	/* Clamp against the bytes left after abs_offset. Comparing this way
	 * (instead of abs_offset + num_bytes > size) cannot overflow. */
	remaining = part->size - (uint64_t)abs_offset;
	if (num_bytes > remaining) {
		num_bytes = (size_t)remaining;
	}

	if (part->read(part, abs_offset, buffer, num_bytes)) {
		ret = AVB_IO_RESULT_ERROR_IO;
		goto out;
	}
	*out_num_read = num_bytes;

out:
	zircon_free_partition(part);
	return ret;
}

/* AvbOps.get_preloaded_partition: any partition whose name starts with
 * ZIRCON_PARTITION_PREFIX is served from the in-memory image recorded in the
 * vboot_context_t behind ops->user_data, capped at that image's size. Every
 * other partition is reported as not preloaded (NULL pointer, zero bytes). */
static AvbIOResult get_preloaded_partition(AvbOps *ops, const char *partition,
					   size_t num_bytes,
					   uint8_t **out_pointer,
					   size_t *out_num_bytes_preloaded)
{
	const vboot_context_t *ctx = (const vboot_context_t *)ops->user_data;
	const size_t prefix_len = strlen(ZIRCON_PARTITION_PREFIX);

	if (strncmp(partition, ZIRCON_PARTITION_PREFIX, prefix_len) != 0) {
		*out_pointer = NULL;
		*out_num_bytes_preloaded = 0;
		return AVB_IO_RESULT_OK;
	}

	*out_pointer = ctx->preloaded_img_addr;
	*out_num_bytes_preloaded = (num_bytes > ctx->preloaded_img_size) ?
					   ctx->preloaded_img_size :
					   num_bytes;
	return AVB_IO_RESULT_OK;
}

static AvbIOResult write_to_partition(AvbOps *ops, const char *partition,
				      int64_t offset, size_t num_bytes,
				      const void *buffer)
{
	/* libavb only writes partitions from the (deprecated) libavb_ab
	 * extension, which is not used here, so any call is refused. */
	const AvbIOResult unsupported = AVB_IO_RESULT_ERROR_IO;

	printf("Error: libavb write_to_partition() unimplemented\n");
	return unsupported;
}

/* AvbOps.get_unique_guid_for_partition: avb_slot_verify uses this call to
 * check that |partition| exists. Only existence is checked; the GUID itself is
 * unused, so |guid_buf| receives an empty string. */
static AvbIOResult get_unique_guid_for_partition(AvbOps *ops,
						 const char *partition,
						 char *guid_buf,
						 size_t guid_buf_size)
{
	zircon_partition *found = zircon_get_partition(partition);

	if (found == NULL) {
		return AVB_IO_RESULT_ERROR_NO_SUCH_PARTITION;
	}
	zircon_free_partition(found);

	*guid_buf = '\0';
	return AVB_IO_RESULT_OK;
}

/* AvbOps.get_size_of_partition: stores the size in bytes of |partition| in
 * |out_size_num_bytes|, or reports that the partition does not exist. */
static AvbIOResult get_size_of_partition(AvbOps *ops, const char *partition,
					 uint64_t *out_size_num_bytes)
{
	zircon_partition *found = zircon_get_partition(partition);

	if (found == NULL) {
		return AVB_IO_RESULT_ERROR_NO_SUCH_PARTITION;
	}

	const uint64_t size = found->size;
	zircon_free_partition(found);
	*out_size_num_bytes = size;
	return AVB_IO_RESULT_OK;
}

/* AvbAtxOps.get_random: fills |output| with |num_bytes| bytes obtained from
 * ta_vx_cprng_draw(); any failure is reported as an I/O error. */
static AvbIOResult get_random(AvbAtxOps *atx_ops, size_t num_bytes,
			      uint8_t *output)
{
	return ta_vx_cprng_draw(output, num_bytes) ? AVB_IO_RESULT_ERROR_IO :
						     AVB_IO_RESULT_OK;
}

/* AvbOps.read_is_device_unlocked: reads the lock state through
 * ta_vx_is_unlocked(); any failure is reported as an I/O error. */
static AvbIOResult read_is_device_unlocked(AvbOps *ops, bool *out_is_unlocked)
{
	return ta_vx_is_unlocked(out_is_unlocked) ? AVB_IO_RESULT_ERROR_IO :
						    AVB_IO_RESULT_OK;
}

/* AvbOps.read_persistent_value: reads the persistent value |name| into
 * |out_buffer| (at most |buffer_size| bytes) via the TA. */
static AvbIOResult read_persistent_value(AvbOps *ops, const char *name,
					 size_t buffer_size,
					 uint8_t *out_buffer,
					 size_t *out_num_bytes_read)
{
	/* The AvbOps contract expects AVB_IO_RESULT_ERROR_NO_SUCH_VALUE when a
	 * value is missing, unsupported or unpopulated. Instead of propagating
	 * the TA's individual error codes, every failure maps to that result. */
	if (!ta_vx_read_persistent_value(name, out_buffer, buffer_size,
					 out_num_bytes_read)) {
		return AVB_IO_RESULT_OK;
	}
	return AVB_IO_RESULT_ERROR_NO_SUCH_VALUE;
}

/* AvbOps.write_persistent_value: stores |value_size| bytes of |value| under
 * |name| via the TA, or deletes |name| when |value_size| is zero. */
static AvbIOResult write_persistent_value(AvbOps *ops, const char *name,
					  size_t value_size,
					  const uint8_t *value)
{
	if (value_size == 0) {
		/* The AvbOps contract says later reads must then return
		 * AVB_IO_RESULT_ERROR_NO_SUCH_VALUE, so the value is deleted. */
		return ta_vx_delete_persistent_value(name) ?
			       AVB_IO_RESULT_ERROR_IO :
			       AVB_IO_RESULT_OK;
	}

	return ta_vx_write_persistent_value(name, value, value_size) ?
		       AVB_IO_RESULT_ERROR_IO :
		       AVB_IO_RESULT_OK;
}

/* AvbOps.read_rollback_index: reads the rollback index stored at
 * |rollback_index_location| via the TA; failures become I/O errors. */
static AvbIOResult avb_read_rollback_index(AvbOps *ops,
					   size_t rollback_index_location,
					   uint64_t *out_rollback_index)
{
	return ta_vx_read_rollback_index(rollback_index_location,
					 out_rollback_index) ?
		       AVB_IO_RESULT_ERROR_IO :
		       AVB_IO_RESULT_OK;
}

/* AvbOps.write_rollback_index: stores |rollback_index| at
 * |rollback_index_location| via the TA; failures become I/O errors. */
static AvbIOResult avb_write_rollback_index(AvbOps *ops,
					    size_t rollback_index_location,
					    uint64_t rollback_index)
{
	return ta_vx_write_rollback_index(rollback_index_location,
					  rollback_index) ?
		       AVB_IO_RESULT_ERROR_IO :
		       AVB_IO_RESULT_OK;
}

/* AvbAtxOps.set_key_version: records a key version reported by libavb_atx in
 * the vboot_context_t behind atx_ops->ops->user_data, so that
 * process_verify_data() can later write it to |rollback_index_location|.
 * At most AVB_ATX_NUM_KEY_VERSIONS entries fit; one more aborts. */
static void set_key_version(AvbAtxOps *atx_ops, size_t rollback_index_location,
			    uint64_t key_version)
{
	vboot_context_t *context = (vboot_context_t *)atx_ops->ops->user_data;
	size_t index = context->next_key_version_index++;

	if (index < AVB_ATX_NUM_KEY_VERSIONS) {
		context->key_versions[index].location = rollback_index_location;
		context->key_versions[index].value = key_version;
	} else {
		/* %zu matches size_t on every target (%lu does not where
		 * size_t is unsigned int, e.g. 32-bit ARM). */
		printf("ERROR: set_key_version index out of bounds: %zu\n",
		       index);
		avb_abort();
	}
}

static AvbOps ops;

static AvbAtxOps atx_ops = {
	.ops = &ops,
	.read_permanent_attributes = avb_read_permanent_attributes,
	.read_permanent_attributes_hash = avb_read_permanent_attributes_hash,
	.set_key_version = set_key_version,
	.get_random = get_random,
};

static AvbOps ops = {
	.atx_ops = &atx_ops,
	.read_from_partition = read_from_partition,
	.get_preloaded_partition = get_preloaded_partition,
	.write_to_partition = write_to_partition,
	.validate_vbmeta_public_key = avb_atx_validate_vbmeta_public_key,
	.read_rollback_index = avb_read_rollback_index,
	.write_rollback_index = avb_write_rollback_index,
	.read_is_device_unlocked = read_is_device_unlocked,
	.get_unique_guid_for_partition = get_unique_guid_for_partition,
	.get_size_of_partition = get_size_of_partition,
	.read_persistent_value = read_persistent_value,
	.write_persistent_value = write_persistent_value,
};

/* Destination for ZBI items found in vbmeta property descriptors: items are
 * appended to the |zbi| container by zbi_extend(), which is given |capacity|
 * (presumably the size of the buffer backing |zbi| — confirm with callers). */
struct property_lookup_user_data {
	zbi_header_t *zbi;
	size_t capacity;
};

/* avb_descriptor_foreach() callback; defined at the end of this file. */
static bool property_lookup_desc_foreach(const AvbDescriptor *header,
					 void *user_data);

/* Appends the ZBI items carried in the property descriptors of every verified
 * vbmeta image in |verify_data| to |zbi| (|capacity| bytes).
 * Returns 0 on success, -1 if descriptor iteration fails. */
static int append_vbmeta_zbi_items(const AvbSlotVerifyData *verify_data,
				   zbi_header_t *zbi, size_t capacity)
{
	struct property_lookup_user_data lookup_data = {
		.zbi = zbi,
		.capacity = capacity,
	};

	for (size_t i = 0; i < verify_data->num_vbmeta_images; ++i) {
		const AvbVBMetaData *vb = &verify_data->vbmeta_images[i];

		if (!avb_descriptor_foreach(vb->vbmeta_data, vb->vbmeta_size,
					    property_lookup_desc_foreach,
					    &lookup_data)) {
			fprintf(stderr, "Fail to parse VBMETA properties\n");
			return -1;
		}
	}
	return 0;
}

/* Raises the stored rollback indexes to those of the verified slot. It writes
 * every used rollback index location in |verify_data|, then every key version
 * recorded in |context| by set_key_version().
 * Returns 0 on success, -1 on any failure (including a NULL |context|). */
static int commit_rollback_indexes(AvbOps *ops,
				   const AvbSlotVerifyData *verify_data,
				   const vboot_context_t *context)
{
	AvbIOResult io_ret;
	size_t num_key_versions;
	size_t i;

	for (i = 0; i < ARRAY_SIZE(verify_data->rollback_indexes); i++) {
		uint64_t rollback_index_value =
			verify_data->rollback_indexes[i];

		if (rollback_index_value == ROLLBACK_INDEX_NOT_USED) {
			continue;
		}

		/* write_rollback_index returns an AvbIOResult, not an
		 * AvbSlotVerifyResult; keep the two apart. */
		io_ret = ops->write_rollback_index(ops, i,
						   rollback_index_value);
		if (io_ret != AVB_IO_RESULT_OK) {
			fprintf(stderr, "Failed to write rollback index: %zu\n",
				i);
			return -1;
		}
	}

	/* Also increase rollback index values for Fuchsia key version
	 * locations. */
	if (context == NULL) {
		fprintf(stderr, "key version context not found\n");
		return -1;
	}

	/* Only commit entries set_key_version() actually filled in. Unfilled
	 * entries of a zero-initialized context would otherwise write 0 to
	 * rollback index location 0. */
	num_key_versions = context->next_key_version_index;
	if (num_key_versions > AVB_ATX_NUM_KEY_VERSIONS) {
		num_key_versions = AVB_ATX_NUM_KEY_VERSIONS;
	}

	for (i = 0; i < num_key_versions; i++) {
		io_ret = ops->write_rollback_index(
			ops, context->key_versions[i].location,
			context->key_versions[i].value);
		if (io_ret != AVB_IO_RESULT_OK) {
			fprintf(stderr, "Failed to write rollback index: %zu\n",
				context->key_versions[i].location);
			return -1;
		}
	}
	return 0;
}

/* Post-processes the outcome of avb_slot_verify().
 *
 * On successful verification, ZBI items from vbmeta properties are appended to
 * |zbi| (if non-NULL), regardless of lock state. An unlocked device is then
 * allowed to boot whatever the verification result. A locked device requires
 * |result| == AVB_SLOT_VERIFY_RESULT_OK. When |has_successfully_booted| is set,
 * the stored rollback indexes and key versions are raised to match the slot;
 * this needs a non-NULL |context|.
 *
 * Returns 0 if boot may proceed, -1 otherwise. */
static int process_verify_data(AvbOps *ops, AvbSlotVerifyResult result,
			       AvbSlotVerifyData *verify_data,
			       vboot_context_t *context, zbi_header_t *zbi,
			       size_t capacity, bool has_successfully_booted)
{
	bool unlocked;

	/* Copy zbi items within vbmeta regardless of lock state. */
	if (zbi && result == AVB_SLOT_VERIFY_RESULT_OK &&
	    append_vbmeta_zbi_items(verify_data, zbi, capacity)) {
		return -1;
	}

	if (ops->read_is_device_unlocked(ops, &unlocked) != AVB_IO_RESULT_OK) {
		fprintf(stderr, "Failed to read lock state.\n");
		return -1;
	}

	if (unlocked) {
		printf("Device unlocked: not checking verification result.\n");
		return 0;
	}

	if (result != AVB_SLOT_VERIFY_RESULT_OK) {
		fprintf(stderr, "Failed to verify, err_code: %s\n",
			avb_slot_verify_result_to_string(result));
		return -1;
	}

	/* Increase rollback index values to match the verified slot only if
	 * it has already successfully booted. */
	if (has_successfully_booted) {
		return commit_rollback_indexes(ops, verify_data, context);
	}

	return 0;
}

/* Verifies the zircon image of slot |ab_suffix| with libavb. The image itself
 * is taken from memory at |loadaddr| (|img_size| bytes) rather than read from
 * storage. ZBI items found in vbmeta properties are appended to |zbi|
 * (|capacity| bytes). When |has_successfully_booted| is set, rollback indexes
 * are raised to those of the verified slot.
 *
 * Returns 0 if boot may proceed. Otherwise it returns the nonzero result of
 * ta_vx_lock_if_ephemerally_unlocked(), or -1. */
int zircon_vboot_slot_verify(unsigned char *loadaddr, uint64_t img_size,
			     const char *ab_suffix,
			     bool has_successfully_booted, zbi_header_t *zbi,
			     size_t capacity)
{
	vboot_context_t context = { 0 };
	AvbSlotVerifyData *verify_data = NULL;
	AvbSlotVerifyResult result;
	int ret;

	const char *const requested_partitions[] = {
		ZIRCON_PARTITION_PREFIX,
#ifdef CONFIG_FACTORY_BOOT_KVS
		CONFIG_FACTORY_BOOT_PARTITION_NAME,
#else
		"factory",
#endif
		NULL
	};

	ret = ta_vx_lock_if_ephemerally_unlocked();
	if (ret) {
		return ret;
	}

	context.preloaded_img_addr = loadaddr;
	context.preloaded_img_size = img_size;
	/* |context| lives in this stack frame: expose it to the callbacks only
	 * while this call is using them. */
	ops.user_data = (void *)&context;

	result = avb_slot_verify(&ops, requested_partitions, ab_suffix,
				 AVB_SLOT_VERIFY_FLAGS_NONE,
				 AVB_HASHTREE_ERROR_MODE_EIO, &verify_data);

	ret = process_verify_data(&ops, result, verify_data, &context, zbi,
				  capacity, has_successfully_booted);

	/* Do not leave the global ops table pointing at a dead stack frame. */
	ops.user_data = NULL;

	if (verify_data) {
		avb_slot_verify_data_free(verify_data);
	}
	return ret;
}

/* Per-call state for zircon_vboot_preloaded_img_verify(), handed to the
 * RAM-boot callbacks through AvbOps.user_data. */
typedef struct {
	/* ZBI image returned for the VBOOT_RAMBOOT_ZBI_PARTITION partition. */
	uint8_t *preloaded_img_addr;
	size_t preloaded_img_size;

	/* vbmeta blob returned when libavb reads the "vbmeta" partition. */
	uint8_t *preloaded_vbmeta_addr;
	size_t preloaded_vbmeta_size;
} vboot_ramboot_context_t;

/* Partition name requested from avb_slot_verify() for the RAM-booted ZBI; it
 * is only ever served from memory, never read from storage. */
#define VBOOT_RAMBOOT_ZBI_PARTITION "ramboot_zbi"

/* AvbOps.get_preloaded_partition for RAM-boot: VBOOT_RAMBOOT_ZBI_PARTITION is
 * served from the in-memory ZBI in the vboot_ramboot_context_t behind
 * ops->user_data, capped at its size. Every other partition is reported as not
 * preloaded. */
static AvbIOResult
get_preloaded_partition_ramboot(AvbOps *ops, const char *partition,
				size_t num_bytes, uint8_t **out_pointer,
				size_t *out_num_bytes_preloaded)
{
	vboot_ramboot_context_t *context =
		(vboot_ramboot_context_t *)ops->user_data;

	// Only support preloaded `ramboot_zbi` partition
	if (strcmp(partition, VBOOT_RAMBOOT_ZBI_PARTITION)) {
		/* Clear both outputs, as get_preloaded_partition() does,
		 * instead of leaving the byte count unset. */
		*out_pointer = NULL;
		*out_num_bytes_preloaded = 0;
		return AVB_IO_RESULT_OK;
	}
	*out_num_bytes_preloaded = min(num_bytes, context->preloaded_img_size);
	*out_pointer = context->preloaded_img_addr;

	return AVB_IO_RESULT_OK;
}

/* AvbOps.read_from_partition for RAM-boot. "factory" is read from storage,
 * since it is validated by vbmeta. "vbmeta" is served from the caller-supplied
 * blob, and only from offset 0. Any other request is an I/O error. */
static AvbIOResult read_from_partition_ramboot(AvbOps *ops,
					       const char *partition,
					       int64_t offset, size_t num_bytes,
					       void *buffer,
					       size_t *out_num_read)
{
	// Only read "factory" partition, which is validated by vbmeta
	if (!strcmp(partition, "factory")) {
		return read_from_partition(ops, "factory", offset, num_bytes,
					   buffer, out_num_read);
	}

	/* This op returns AvbIOResult; the previous AVB_SLOT_VERIFY_RESULT_*
	 * constant came from the wrong enum (same numeric value). */
	if (strcmp(partition, "vbmeta") || offset != 0) {
		return AVB_IO_RESULT_ERROR_IO;
	}

	vboot_ramboot_context_t *context =
		(vboot_ramboot_context_t *)ops->user_data;

	*out_num_read = min(num_bytes, context->preloaded_vbmeta_size);

	/* The vbmeta address may be NULL when its size is 0; memcpy from a
	 * NULL pointer is undefined even for zero bytes. */
	if (*out_num_read > 0) {
		memcpy(buffer, context->preloaded_vbmeta_addr, *out_num_read);
	}

	return AVB_IO_RESULT_OK;
}

/* AvbOps.get_size_of_partition for RAM-boot. This is not called, but needs to
 * be non-null; it always fails with an I/O error. The error is expressed as an
 * AvbIOResult, which is the type this op returns. */
static AvbIOResult get_size_of_partition_ramboot(AvbOps *ops,
						 const char *partition,
						 uint64_t *out_size_num_bytes)
{
	return AVB_IO_RESULT_ERROR_IO;
}

/* AvbOps.get_unique_guid_for_partition for RAM-boot: every partition is
 * assumed to exist, and the unused GUID is returned as an empty string. */
static AvbIOResult get_unique_guid_for_partition_ramboot(AvbOps *ops,
							 const char *partition,
							 char *guid_buf,
							 size_t guid_buf_size)
{
	*guid_buf = '\0';
	return AVB_IO_RESULT_OK;
}

/* AvbAtxOps.set_key_version for RAM-boot: reported key versions are discarded
 * on purpose, because RAM-booting never raises rollback indexes. */
static void set_key_version_ramboot(AvbAtxOps *atx_ops,
				    size_t rollback_index_location,
				    uint64_t key_version)
{
	(void)atx_ops;
	(void)rollback_index_location;
	(void)key_version;
}

/* ramboot_atx_ops and ramboot_ops have all disk write functionality removed:
 * they provide neither write_to_partition nor write_rollback_index.
 *
 * Tentative definition so that ramboot_atx_ops.ops can point back at
 * ramboot_ops. Pointing it at the normal |ops| table, as before, exposed that
 * table's partition and rollback-index writers through the RAM-boot path. */
static AvbOps ramboot_ops;

static AvbAtxOps ramboot_atx_ops = {
	.ops = &ramboot_ops,
	.read_permanent_attributes = avb_read_permanent_attributes,
	.read_permanent_attributes_hash = avb_read_permanent_attributes_hash,
	.set_key_version = set_key_version_ramboot,
	.get_random = get_random,
};

static AvbOps ramboot_ops = {
	.atx_ops = &ramboot_atx_ops,
	.read_from_partition = read_from_partition_ramboot,
	.get_preloaded_partition = get_preloaded_partition_ramboot,
	.validate_vbmeta_public_key = avb_atx_validate_vbmeta_public_key,
	.read_rollback_index = avb_read_rollback_index,
	.read_is_device_unlocked = read_is_device_unlocked,
	.get_unique_guid_for_partition = get_unique_guid_for_partition_ramboot,
	.get_size_of_partition = get_size_of_partition_ramboot,
	.read_persistent_value = read_persistent_value,
	// This is required to initialize factory digest if the rpmb is empty.
	.write_persistent_value = write_persistent_value,
};

/* Verifies a RAM-booted ZBI (|zbi|, |zbi_size| bytes) against the
 * caller-supplied vbmeta blob (|vbmeta|, |vbmeta_size| bytes). ZBI items from
 * vbmeta properties are appended to |zbi| (|capacity| bytes). An unlocked
 * device with no vbmeta is accepted without verification. Rollback indexes are
 * never raised on this path.
 *
 * Returns 0 if boot may proceed. Otherwise it returns the nonzero result of
 * ta_vx_lock_if_ephemerally_unlocked(), or -1. */
int zircon_vboot_preloaded_img_verify(zbi_header_t *zbi, size_t zbi_size,
				      size_t capacity, unsigned char *vbmeta,
				      size_t vbmeta_size)
{
	vboot_ramboot_context_t context = { 0 };
	AvbSlotVerifyData *verify_data = NULL;
	AvbSlotVerifyResult result;
	bool unlocked;
	int ret;

	const char *const requested_partitions[] = {
		VBOOT_RAMBOOT_ZBI_PARTITION, "factory", NULL
	};

	if (ops.read_is_device_unlocked(&ops, &unlocked)) {
		fprintf(stderr, "Failed to read lock state.\n");
		return -1;
	}

	if (unlocked && vbmeta_size == 0) {
		return 0;
	}

	ret = ta_vx_lock_if_ephemerally_unlocked();
	if (ret) {
		return ret;
	}

	context.preloaded_img_addr = (uint8_t *)zbi;
	context.preloaded_img_size = zbi_size;
	context.preloaded_vbmeta_addr = vbmeta;
	context.preloaded_vbmeta_size = vbmeta_size;
	/* |context| lives in this stack frame: expose it to the callbacks only
	 * while this call is using them. */
	ramboot_ops.user_data = (void *)&context;

	result = avb_slot_verify(&ramboot_ops, requested_partitions, "",
				 AVB_SLOT_VERIFY_FLAGS_NONE,
				 AVB_HASHTREE_ERROR_MODE_EIO, &verify_data);

	/* has_successfully_booted is false, so no key-version context is
	 * needed. */
	ret = process_verify_data(&ramboot_ops, result, verify_data, NULL, zbi,
				  capacity, false);

	/* Do not leave the global ops table pointing at a dead stack frame. */
	ramboot_ops.user_data = NULL;

	if (verify_data) {
		avb_slot_verify_data_free(verify_data);
	}
	return ret;
}

/* Generates an unlock challenge through libavb_atx using |atx_ops|.
 * Returns 0 on success, -1 on failure. */
int zircon_vboot_generate_unlock_challenge(
	AvbAtxUnlockChallenge *out_unlock_challenge)
{
	if (avb_atx_generate_unlock_challenge(&atx_ops, out_unlock_challenge) ==
	    AVB_IO_RESULT_OK) {
		return 0;
	}

	fprintf(stderr, "Failed to generate unlock challenge\n");
	return -1;
}

/* Validates |unlock_credential| through libavb_atx using |atx_ops| and stores
 * in |out_is_trusted| whether it is trusted. Returns 0 when validation ran
 * (whatever the trust verdict), -1 when it failed with an I/O error. */
int zircon_vboot_validate_unlock_credential(
	AvbAtxUnlockCredential *unlock_credential, bool *out_is_trusted)
{
	AvbIOResult ret = avb_atx_validate_unlock_credential(
		&atx_ops, unlock_credential, out_is_trusted);
	if (ret != AVB_IO_RESULT_OK) {
		/* This path validates a credential, not a challenge. */
		fprintf(stderr, "Failed to validate unlock credential\n");
		return -1;
	}

	return 0;
}

/* Reports the device lock state in |unlocked| via the normal ops table.
 * Returns 0 on success, -1 if the state could not be read. */
int zircon_vboot_is_unlocked(bool *unlocked)
{
	if (ops.read_is_device_unlocked(&ops, unlocked) == AVB_IO_RESULT_OK) {
		return 0;
	}

	fprintf(stderr, "Failed to get unlock status\n");
	return -1;
}

#define ZBI_PROPERTY_PREFIX "zbi"

/* If the given property holds a ZBI container, appends its contents to the ZBI
 * container in |lookup_data|. */
static void
process_property(const AvbPropertyDescriptor *prop_desc,
		 const struct property_lookup_user_data *lookup_data)
{
	const char *key =
		(const char *)prop_desc + sizeof(AvbPropertyDescriptor);
	uint64_t offset;
	if (!avb_safe_add(&offset, sizeof(AvbPropertyDescriptor) + 1,
			  prop_desc->key_num_bytes)) {
		fprintf(stderr,
			"Overflow while computing offset for property value."
			"Skipping this property descriptor.\n");
		return;
	}

	const uint8_t *value = (const uint8_t *)prop_desc + offset;
	if (key[prop_desc->key_num_bytes] != 0) {
		fprintf(stderr, "No terminating NUL byte in the property key."
				"Skipping this property descriptor.\n");
		return;
	}
	if (value[prop_desc->value_num_bytes] != 0) {
		fprintf(stderr, "No terminating NUL byte in the property value."
				"Skipping this property descriptor.\n");
		return;
	}

	/* Only look at properties whose keys start with the 'zbi' prefix. */
	if (strncmp(key, ZBI_PROPERTY_PREFIX, strlen(ZBI_PROPERTY_PREFIX))) {
		return;
	}

	const zbi_header_t *vbmeta_zbi = (zbi_header_t *)value;
	printf("Found vbmeta ZBI property '%s' (%llu bytes)\n", key,
	       prop_desc->value_num_bytes);

	const uint64_t zbi_size = sizeof(*vbmeta_zbi) + vbmeta_zbi->length;
	if (zbi_size > prop_desc->value_num_bytes) {
		fprintf(stderr,
			"vbmeta ZBI length exceeds property size (%llu > %llu)\n",
			zbi_size, prop_desc->value_num_bytes);
		return;
	}

	zbi_result_t result = zbi_check(vbmeta_zbi, NULL);
	if (result != ZBI_RESULT_OK) {
		fprintf(stderr, "Mal-formed vbmeta ZBI: error %d\n", result);
		return;
	}

	result =
		zbi_extend(lookup_data->zbi, lookup_data->capacity, vbmeta_zbi);
	if (result != ZBI_RESULT_OK) {
		fprintf(stderr, "Failed to add vbmeta ZBI: error %d\n", result);
		return;
	}
}

/* Callback for vbmeta property iteration. |user_data| must be a pointer to a
 * property_lookup_user_data struct. */
static bool property_lookup_desc_foreach(const AvbDescriptor *header,
					 void *user_data)
{
	if (header->tag != AVB_DESCRIPTOR_TAG_PROPERTY) {
		return true;
	}

	/* recover original bytes order at the end of the function */
	AvbPropertyDescriptor *prop_desc = (AvbPropertyDescriptor *)header;
	if (!avb_property_descriptor_validate_and_byteswap(prop_desc,
							   prop_desc)) {
		return true;
	}

	process_property(prop_desc,
			 (struct property_lookup_user_data *)user_data);

	/* return error if byte order recovering failed */
	if (!avb_property_descriptor_validate_and_byteswap(prop_desc,
							   prop_desc)) {
		fprintf(stderr,
			"failed to recover byte order in a property descriptor.\n");
		return false;
	}
	return true;
}
