// SPDX-License-Identifier: GPL-2.0+

#include <common.h>
#include <malloc.h>
#include "memory_pool.h"

/**
 * An implementation of the simple and efficient memory pool allocator
 * described in the following article.
 *
 * http://www.thinkmind.org/download.php?articleid=computation_tools_2012_1_10_80006
 *
 * The pool is divided into num_total blocks of block_size bytes each. Free
 * blocks form an implicit singly linked list: the first size_t of every free
 * block holds the index of the next free block, and the value num_total marks
 * the end of the list. Blocks are linked into the list lazily, at most one
 * additional block per allocation call.
 */
struct memory_pool {
	uintptr_t base;		// address of block 0 (pool start rounded up to MIN_ALIGNMENT).
	size_t block_size;	// fixed size in bytes of each block.
	size_t num_total;	// total number of blocks in pool.
	size_t num_free;	// number of blocks currently available for allocation.
	size_t num_initialized;	// number of blocks whose next-index link has been
						// written so far (grows lazily up to num_total).
	size_t head;		// index of the next allocatable block. i.e. HEAD node of the
						// implicit free-blocks list.
};

/* Returns true only when exactly one bit of n is set. Zero is rejected
   explicitly because (0 & (0 - 1)) == 0 would otherwise report it as a power
   of two. */
static bool is_power_of_2(size_t n) { return n != 0 && (n & (n - 1)) == 0; }

/* Every free block stores, in its first bytes, the size_t index of the next
   free block, so a block must be at least as large as a size_t. */
#define MIN_BLOCK_SIZE (sizeof(size_t))

/* Smallest alignment the pool works with: the pool base is rounded up to it
   and allocation requests asking for less are raised to it. */
#define MIN_ALIGNMENT (sizeof(void *))

/**
 * Fills in a pool descriptor for the memory region [pool_start, pool_end).
 *
 * @pool -       descriptor to initialize.
 * @pool_start - first address of the backing region; rounded up to
 *               MIN_ALIGNMENT before use.
 * @pool_end -   address one past the end of the backing region.
 * @block_size - size in bytes of each block, at least MIN_BLOCK_SIZE.
 *
 * Returns 0 on success, -EINVAL on a NULL descriptor, a zero start address or
 * a block size that is too small, and -ENOSPC when the aligned region cannot
 * hold even one block.
 */
static int init_pool(struct memory_pool *pool, uintptr_t pool_start,
		uintptr_t pool_end, size_t block_size)
{
	if (!pool || !pool_start)
		return -EINVAL;

	if (block_size < MIN_BLOCK_SIZE)
		return -EINVAL;

	uintptr_t aligned_start = round_up(pool_start, MIN_ALIGNMENT);

	// Rounding up an address near the top of the address space wraps around;
	// such a region has no usable room.
	if (aligned_start < pool_start)
		return -ENOSPC;

	// We need space for at least 1 block. Compare via subtraction so the
	// check cannot be defeated by aligned_start + block_size wrapping.
	if (aligned_start > pool_end || pool_end - aligned_start < block_size)
		return -ENOSPC;

	pool->base = aligned_start;
	pool->block_size = block_size;
	// Any tail shorter than one block is left unused.
	pool->num_total = (pool_end - aligned_start) / block_size;
	pool->num_free = pool->num_total;
	// Free-list links are written lazily by the allocator, none exist yet.
	pool->num_initialized = 0;
	pool->head = 0;

	return 0;
}

/**
 * Creates a pool descriptor managing a caller-provided memory region.
 *
 * @pool_start - first address of the backing region.
 * @pool_size -  size of the backing region in bytes.
 * @block_size - size in bytes of each block handed out by the pool.
 *
 * Only the descriptor is allocated here (with malloc); the backing region
 * stays owned by the caller.
 *
 * Returns NULL when the region wraps around the address space, when the
 * descriptor cannot be allocated, or when init_pool() rejects the arguments;
 * otherwise the new pool, to be released with memory_pool_destroy().
 */
struct memory_pool *memory_pool_create(uintptr_t pool_start,
		size_t pool_size, size_t block_size)
{
	uintptr_t pool_end = pool_start + pool_size;
	struct memory_pool *new_pool;

	// Unsigned addition wraps silently; a wrapped end address would describe
	// a region that does not exist.
	if (pool_end < pool_start)
		return NULL;

	new_pool = malloc(sizeof(*new_pool));
	if (!new_pool)
		return NULL;

	if (init_pool(new_pool, pool_start, pool_end, block_size)) {
		free(new_pool);
		return NULL;
	}

	return new_pool;
}

/**
 * Releases a pool descriptor obtained from memory_pool_create().
 *
 * @pool - descriptor to release; may be NULL.
 *
 * Only the descriptor itself is freed. The memory region the pool was
 * managing is not touched.
 */
void memory_pool_destroy(struct memory_pool *pool)
{
	/* free() accepts NULL, so no guard is required. */
	free(pool);
}

/* Translates a block index into the address of that block's first byte.
   Returns 0 for indices that do not name a block (index >= num_total). */
static uintptr_t get_block_address(struct memory_pool *pool, size_t index)
{
	if (index >= pool->num_total)
		return 0;

	return pool->base + pool->block_size * index;
}

/* Maps an address to the index of the block containing it. Any address inside
   a block (not just its first byte) yields that block's index. Addresses
   outside the pool yield num_total, which callers treat as "no such block". */
static size_t get_block_index(struct memory_pool *pool, void *addr)
{
	const uintptr_t target = (uintptr_t) addr;
	const uintptr_t first = pool->base;
	const uintptr_t limit = first + pool->block_size * pool->num_total;

	if (target >= first && target <= limit) {
		// target == limit divides out to num_total as well.
		return (target - first) / pool->block_size;
	}

	return pool->num_total;  // tail.
}

/**
 * Allocates a buffer from the memory pool.
 *
 * @pool -  pointer to an initialized pool.
 * @align - desired buffer address alignment. Values below MIN_ALIGNMENT
 *          (including zero) are raised to MIN_ALIGNMENT; larger values must
 *          be a power of 2.
 * @bytes - desired buffer size in bytes, must be non-zero.
 *
 * Each successful call consumes one whole block. The returned address is the
 * block start rounded up to @align, so the alignment padding plus @bytes must
 * fit inside a single block.
 *
 * Returns NULL on error else buffer address.
 */
void *memory_pool_allocate(struct memory_pool *pool, size_t align, size_t bytes)
{
	if (!pool || !bytes)
		return NULL;

	if (align < MIN_ALIGNMENT)
		align = MIN_ALIGNMENT;

	if (!is_power_of_2(align)) {
		printf("%s: error: bad alignment: %zu\n", __func__, align);
		return NULL;
	}

	if (pool->num_free == 0) {
		printf("%s: error: out of free blocks.\n", __func__);
		return NULL;
	}

	/* Initialize one new block (lazy expansion of the free-blocks list) on each
	 * allocation. This also happens when the request is rejected below; the
	 * block simply stays linked for a later call.
	 *
	 * NOTE(review): the link is stored as a size_t at the block start, which is
	 * only suitably aligned when block_size is a multiple of MIN_ALIGNMENT. */
	if (pool->num_initialized < pool->num_total) {
		size_t *next =
		    	(size_t *)get_block_address(pool, pool->num_initialized);
		*next = pool->num_initialized + 1;
		pool->num_initialized++;
	}

	uintptr_t p_block = get_block_address(pool, pool->head);
	uintptr_t ptr = round_up(p_block, align);
	size_t padding = ptr - p_block;

	/* Overflow-safe fit check: a huge align can wrap round_up() and a huge
	 * bytes can wrap ptr + bytes, so compare sizes instead of end addresses. */
	if (ptr < p_block || padding > pool->block_size ||
	    bytes > pool->block_size - padding) {
		printf("%s: error: block_size (%zu) too small.\n", __func__,
				pool->block_size);
		return NULL;
	}

	/* Pop the block off the free list: its first word names the next one. */
	pool->head = *(size_t *)p_block;
	pool->num_free--;

	return (void *)ptr;
}

/**
 * Returns an allocated buffer to the memory pool.
 *
 * @pool - pointer to an initialized memory pool.
 * @ptr -  pointer to a previously allocated buffer. Any address inside the
 *         block is accepted, so buffers returned with extra alignment padding
 *         map back to their block.
 *
 * A NULL @pool or @ptr is ignored. An address outside the pool triggers
 * panic(). Freeing the same block twice is not detected.
 */
void memory_pool_deallocate(struct memory_pool *pool, void *ptr)
{
	/* Mirror memory_pool_allocate(), which also tolerates a NULL pool. */
	if (!pool || !ptr)
		return;

	size_t index = get_block_index(pool, ptr);
	if (index == pool->num_total) {
		panic("index out of range.");
	}

	/* insert new block at head. When the list is empty the old head is stale,
	 * so terminate the list with the num_total marker instead. */
	uintptr_t p_block = get_block_address(pool, index);
	*(size_t *)p_block = (pool->num_free) ? pool->head : pool->num_total;
	pool->head = index;
	pool->num_free++;
}

#ifdef MEMORY_POOL_TEST

/*
 * Self-test for the memory pool: exercises the allocate/deallocate path, pool
 * exhaustion, alignment validation, oversized requests and deallocation
 * through an interior pointer. The command arguments are unused.
 *
 * Returns 0 once every check has been executed.
 */
static int unit_test(cmd_tbl_t * cmdtp, int flag, int argc, char *const argv[])
{
	size_t num_blocks = 4;
	size_t block_size = 70;  // block size does not have to be power of 2.
	size_t pool_size = num_blocks * block_size;

	// static such that if our test fails across multiple invocations we don't
	// waste an unbounded amount of memory.
	static uintptr_t pool_start = 0;
	if (!pool_start) {
		pool_start = (uintptr_t)malloc(pool_size);
		assert(pool_start);
	}

	struct memory_pool *pool = memory_pool_create (pool_start, pool_size,
			block_size);
	assert(pool);

	// typical allocation/deallocation path.
	void *p1 = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	void *p2 = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	void *p3 = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	assert (p1 && p2 && p3);

	memory_pool_deallocate(pool, p1);
	memory_pool_deallocate(pool, p2);
	void *p = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	assert(p && pool->num_free == 2 && pool->num_initialized == 4);
	p = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	assert(p);
	p = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	assert(p);
	p = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 1);
	assert(!p);  // out of blocks

	// RESET.
	memory_pool_destroy(pool);
	pool = memory_pool_create (pool_start, pool_size, block_size);
	assert(pool);

	// bad alignment (not power of 2).
	p = memory_pool_allocate(pool, /* align = */ 10, /* bytes = */ 1);
	assert(!p);

	// valid alignment (best effort check).
	p = memory_pool_allocate(pool, /* align = */ 64, /* bytes = */ 1);
	assert(p && !((uintptr_t)p % 64));

	// bytes > block_size.
	p = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ block_size + 1);
	assert(!p);

	// RESET.
	memory_pool_destroy(pool);
	pool = memory_pool_create (pool_start, pool_size, block_size);
	assert(pool);

	// deallocating a ptr that has a larger alignment than block_start
	// (simulated).
	p = memory_pool_allocate(pool, /* align = */ 0, /* bytes = */ 64);
	assert(p && pool->num_free == (num_blocks - 1));
	memory_pool_deallocate(pool, (void *)((uintptr_t)p + 8));
	// Compare, do not assign: "=" here would overwrite num_free and always
	// pass.
	assert(pool->num_free == num_blocks);

	memory_pool_destroy(pool);

	// If we do end up passing the test, clean up.
	free((void *)pool_start);
	pool_start = 0;

	printf("PASS (ignore any error logs)!\n");

	return 0;
}

/* Exposes unit_test() under the name `test_memory_pool` through U_BOOT_CMD.
 * Compiled only when MEMORY_POOL_TEST is defined. */
U_BOOT_CMD(test_memory_pool, /* maxargs = */ 1, /* repeatable = */ 1,
		unit_test, "Run unit tests", NULL);

#endif  /* MEMORY_POOL_TEST */
