// Copyright 2014-2023 Google LLC
//
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <zfs_fletcher.h>
#include <sys/zfs_ioctl.h>

#include "zst.h"
#include "zfstlog.h"



/* private type definitions */

/*
 * Traversal state behind the opaque zst_handle_t: the input descriptor, the
 * buffer each stream section is read into, the running stream checksum and
 * the callbacks registered per record type.
 */
struct zst_handle {
    int err;            /* sticky error; zst_traverse() stops once non-zero */
    int fd;             /* descriptor the stream is read from (not closed by zst_fini()) */
    char *buf;          /* holds the most recently read section; allocated by zst_init() */
    uint64_t off;       /* NOTE(review): not referenced by any function visible here — confirm still used */
    size_t bufsize;     /* amount of memory allocated for buf, in bytes */
    uint64_t objlen;    /* max valid offset+len; set to ~0ULL by zst_init(), not otherwise read here */
    zio_cksum_t cksum;  /* running fletcher-4 checksum of every byte read via get_next_section() */
    zst_callback_descr_t zc_array[DRR_NUMTYPES]; /* callback descriptors, indexed by DRR_* record type */
};

/* callback registration */
/*
 * Install the callback descriptor *d in handle h.
 *
 * The descriptor is copied by value into the slot selected by d->rtype and
 * replaces any descriptor previously registered for that record type; the
 * caller keeps ownership of *d.
 *
 * Returns 0 on success, or EINVAL if h or d is NULL or d->rtype is outside
 * the range [0, DRR_NUMTYPES).
 */
int
zst_register_callback(zst_handle_t *h, zst_callback_descr_t *d)
{
    if (h != NULL && d != NULL &&
        d->rtype >= 0 && d->rtype < DRR_NUMTYPES) {
        h->zc_array[d->rtype] = *d;
        return (0);
    }

    return (EINVAL);
}

/* fini()/init(), resource accounting */
/*
 * Release a handle obtained from zst_init(): its section buffer and the
 * handle itself. The file descriptor given to zst_init() is left open.
 *
 * Returns 0 on success, or EINVAL if h is NULL.
 */
int
zst_fini(zst_handle_t *h)
{
    if (h != NULL) {
        free(h->buf);   /* free(NULL) is a no-op, no guard needed */
        free(h);
        return (0);
    }

    return (EINVAL);
}

/*
 * Allocate a traversal handle that reads a send stream from fd.
 *
 * The handle starts zero-filled (no callbacks, no pending error, zeroed
 * running checksum) and owns a 1MB section buffer. fd is borrowed: the
 * handle never closes it.
 *
 * Returns the new handle, or NULL if memory could not be allocated.
 */
zst_handle_t *
zst_init(int fd)
{
    const size_t section_bufsize = (size_t)1 << 20;   /* 1MB */
    zst_handle_t *h = calloc(1, sizeof (*h));         /* zero-filled */

    if (h == NULL)
        return (NULL);

    h->buf = malloc(section_bufsize);
    if (h->buf == NULL) {
        free(h);
        return (NULL);
    }

    h->bufsize = section_bufsize;
    h->objlen = ~0ULL;
    h->fd = fd;

    return (h);
}

/*
 * Support functions for record traversal
 */
/*
 * Read exactly `length` bytes from fd into buffer.
 *
 * Partial reads are continued where they stopped, appending after the
 * bytes already received. Reads interrupted by a signal (EINTR) are
 * restarted. At most four zero-byte reads are tolerated per call: each of
 * the first three is followed by a 1 ms sleep and a retry, and the fourth
 * ends the call early.
 *
 * Returns the number of bytes read, which is less than `length` only when
 * the zero-byte retries ran out, or -1 on a read error (errno set by read).
 */
static ssize_t
sread(int fd, char *buffer, size_t length)
{
    size_t n = 0;
    int retries = 4;

    while (n < length) {
        /* continue after the bytes already received, never past length */
        ssize_t io = read(fd, buffer + n, length - n);

        if (io < 0) {
            if (errno == EINTR)
                continue;
            return (-1);
        }
        if (io == 0) {
            if (--retries <= 0)
                return ((ssize_t)n);
            usleep(1000);
            /* read() returning 0 does not set errno, so log progress instead */
            LOG_DEBUG("Retrying read (%zu of %zu bytes so far)", n, length);
            continue;
        }
        n += (size_t)io;
    }

    return ((ssize_t)n);
}

/*
 * Read the next `length` bytes of the stream into hdl->buf and fold them
 * into the running checksum hdl->cksum.
 *
 * Returns hdl->buf viewed as a record header, or NULL if the bytes could
 * not be read in full. On failure hdl->err is set unless it already holds
 * an error: EINVAL if `length` exceeds hdl->bufsize, the read errno on an
 * I/O error, or EIO on a short read (truncated stream). The checksum is
 * only updated for complete reads.
 *
 * The returned pointer aliases hdl->buf and is invalidated by the next
 * call, so callers must copy anything they still need.
 */
static dmu_replay_record_t *
get_next_section(zst_handle_t *hdl, size_t length)
{
    ssize_t io;
    int read_errno;

    if (length > hdl->bufsize) {
        /* never let a requested length overrun the section buffer */
        if (hdl->err == 0)
            hdl->err = EINVAL;
        return (NULL);
    }

    io = sread(hdl->fd, hdl->buf, length);
    read_errno = errno;     /* logging below may clobber errno */

    LOG_DEBUG("Read %zd bytes should have read %zu", io, length);

    if (io < 0 || (size_t)io != length) {
        if (hdl->err == 0)
            hdl->err = (io < 0 && read_errno != 0) ? read_errno : EIO;
        return (NULL);
    }

    fletcher_4_incremental_native(hdl->buf, length, &hdl->cksum);

    return ((dmu_replay_record_t *)hdl->buf);
}

/* Processing of individual records */
/*
 * Validate a DRR_BEGIN record and pass it to the registered DRR_BEGIN
 * callback, if any.
 *
 * Only native-endian streams are accepted. A byte-swapped magic or any
 * other magic logs an error and sets hdl->err to EINVAL.
 *
 * Returns hdl->err on a bad magic. Otherwise returns the callback's return
 * value, or 0 when no callback is registered.
 */
static int
get_begin_section(zst_handle_t *hdl, struct drr_begin *drrb)
{
    zst_callback_descr_t *d = &hdl->zc_array[DRR_BEGIN];

    if (drrb->drr_magic == BSWAP_64(DMU_BACKUP_MAGIC)) {
        LOG_ERR("Non-native stream format");
        hdl->err = EINVAL;
        return (hdl->err);
    }
    if (drrb->drr_magic != DMU_BACKUP_MAGIC) {
        LOG_ERR("Invalid stream format");
        hdl->err = EINVAL;
        return (hdl->err);
    }

    /* %d needs an int: log whether a callback is set, not the pointer */
    LOG_DEBUG("hdl->zc_array[DRR_BEGIN].cb=%d", d->cb != NULL);
    if (d->cb != NULL)
        return ((*d->cb)(DRR_BEGIN, (void *)drrb, d->arg));

    return (0);
}

/*
 * Validate a DRR_OBJECT record, consume its bonus buffer from the stream,
 * and pass the record to the registered DRR_OBJECT callback, if any.
 *
 * The bonus data (drr_bonuslen rounded up to 8 bytes) is read right after
 * the record header and is checksummed by get_next_section(). It is not
 * handed to the callback; reading it only keeps the stream positioned at
 * the next record header.
 *
 * Returns EINVAL for an out-of-range field, and a non-zero hdl->err if the
 * bonus data cannot be read in full. Otherwise returns the callback's return
 * value, or 0 when no callback is registered.
 */
static int
get_object_section(zst_handle_t *hdl, struct drr_object *drro)
{
    zst_callback_descr_t *d = &hdl->zc_array[DRR_OBJECT];

    if (drro->drr_type == DMU_OT_NONE ||
        !DMU_OT_IS_VALID(drro->drr_type) ||
        !DMU_OT_IS_VALID(drro->drr_bonustype) ||
        drro->drr_checksumtype >= ZIO_CHECKSUM_FUNCTIONS ||
        drro->drr_compress >= ZIO_COMPRESS_FUNCTIONS ||
        P2PHASE(drro->drr_blksz, SPA_MINBLOCKSIZE) ||
        drro->drr_blksz < SPA_MINBLOCKSIZE ||
        drro->drr_blksz > SPA_MAXBLOCKSIZE ||
        drro->drr_bonuslen > DN_OLD_MAX_BONUSLEN) {
        return (EINVAL);
    }

    if (drro->drr_bonuslen != 0) {
        /* a truncated bonus buffer must not be silently accepted */
        if (get_next_section(hdl, P2ROUNDUP(drro->drr_bonuslen, 8)) == NULL &&
            hdl->err == 0)
            hdl->err = EIO;
        if (hdl->err != 0)
            return (hdl->err);
    }

    if (d->cb != NULL)
        return ((*d->cb)(DRR_OBJECT, (void *)drro, d->arg));

    return (0);
}

/*
 * Validate a DRR_FREEOBJECTS record and pass it to the registered
 * DRR_FREEOBJECTS callback, if any.
 *
 * Returns EINVAL if drr_firstobj + drr_numobjs wraps around. Otherwise
 * returns the callback's return value, or 0 when no callback is registered.
 */
static int
get_freeobjects_section(zst_handle_t *hdl, struct drr_freeobjects *drrfo)
{
    zst_callback_descr_t *slot = &hdl->zc_array[DRR_FREEOBJECTS];
    uint64_t last = drrfo->drr_firstobj + drrfo->drr_numobjs;

    if (last < drrfo->drr_firstobj)
        return (EINVAL);

    return (slot->cb != NULL ?
        (*slot->cb)(DRR_FREEOBJECTS, (void *)drrfo, slot->arg) : 0);
}

/*
 * Validate a DRR_WRITE record and pass it to the registered DRR_WRITE
 * callback, if any.
 *
 * Returns EINVAL if drr_offset + drr_logical_size wraps around or drr_type
 * is not a valid object type. Otherwise returns the callback's return value,
 * or 0 when no callback is registered.
 */
static int
get_write_section(zst_handle_t *hdl, struct drr_write *drrw)
{
    zst_callback_descr_t *slot = &hdl->zc_array[DRR_WRITE];
    uint64_t end = drrw->drr_offset + drrw->drr_logical_size;

    /* an unsigned sum wraps iff it is smaller than either operand */
    if (end < drrw->drr_offset || !DMU_OT_IS_VALID(drrw->drr_type))
        return (EINVAL);

    if (slot->cb == NULL)
        return (0);

    return ((*slot->cb)(DRR_WRITE, (void *)drrw, slot->arg));
}

/*
 * Handle a DRR_WRITE_EMBEDDED record. This record can be generated when the
 * embedded data feature is enabled. It is very similar to a DRR_WRITE
 * record but has its own type, so it is dispatched separately.
 *
 * Returns EINVAL if drr_offset + drr_length wraps around. Otherwise returns
 * the DRR_WRITE_EMBEDDED callback's return value, or 0 when none is
 * registered.
 */
static int
get_write_embedded_section(zst_handle_t *hdl, struct drr_write_embedded *drrwe)
{
    zst_callback_descr_t *slot = &hdl->zc_array[DRR_WRITE_EMBEDDED];
    uint64_t end = drrwe->drr_offset + drrwe->drr_length;

    if (end < drrwe->drr_offset)
        return (EINVAL);

    if (slot->cb == NULL)
        return (0);

    return ((*slot->cb)(DRR_WRITE_EMBEDDED, (void *)drrwe, slot->arg));
}

/*
 * Handle a DRR_WRITE_BYREF record. Dedup'ed streams use this record to
 * refer to data that already came in earlier in the stream instead of
 * repeating it. This function does not resolve the reference itself. It
 * only range-checks the record and forwards it to the registered
 * DRR_WRITE_BYREF callback, if any.
 *
 * Returns EINVAL if drr_offset + drr_length wraps around. Otherwise returns
 * the callback's return value, or 0 when no callback is registered.
 */
static int
get_write_byref_section(zst_handle_t *hdl, struct drr_write_byref *drrwbr)
{
    zst_callback_descr_t *slot = &hdl->zc_array[DRR_WRITE_BYREF];
    uint64_t end = drrwbr->drr_offset + drrwbr->drr_length;

    if (end < drrwbr->drr_offset)
        return (EINVAL);

    if (slot->cb != NULL)
        return ((*slot->cb)(DRR_WRITE_BYREF, (void *)drrwbr, slot->arg));

    return (0);
}

/*
 * Validate a DRR_SPILL record and pass it to the registered DRR_SPILL
 * callback, if any.
 *
 * Returns EINVAL unless SPA_MINBLOCKSIZE <= drr_length <= SPA_MAXBLOCKSIZE.
 * Otherwise returns the callback's return value, or 0 when no callback is
 * registered.
 */
static int
get_spill_section(zst_handle_t *hdl, struct drr_spill *drrs)
{
    zst_callback_descr_t *slot = &hdl->zc_array[DRR_SPILL];
    int in_range = drrs->drr_length >= SPA_MINBLOCKSIZE &&
        drrs->drr_length <= SPA_MAXBLOCKSIZE;

    if (!in_range)
        return (EINVAL);

    return (slot->cb != NULL ?
        (*slot->cb)(DRR_SPILL, (void *)drrs, slot->arg) : 0);
}

/*
 * Validate a DRR_FREE record and pass it to the registered DRR_FREE
 * callback, if any.
 *
 * A drr_length of ~0ULL is exempt from the wrap-around check (presumably
 * the "free to end of object" sentinel — confirm against the zfs headers).
 * Returns EINVAL if drr_offset + drr_length otherwise wraps around. If the
 * check passes, returns the callback's return value, or 0 when no callback
 * is registered.
 */
static int
get_free_section(zst_handle_t *hdl, struct drr_free *drrf)
{
    zst_callback_descr_t *slot = &hdl->zc_array[DRR_FREE];
    int open_ended = (drrf->drr_length == ~0ULL);

    if (!open_ended &&
        drrf->drr_offset + drrf->drr_length < drrf->drr_offset)
        return (EINVAL);

    if (slot->cb != NULL)
        return ((*slot->cb)(DRR_FREE, (void *)drrf, slot->arg));

    return (0);
}

/*
 * Walk the send stream read from hdl->fd one record header at a time and
 * dispatch each record to its get_*_section() handler. The handler
 * validates the record and invokes the callback registered for its type,
 * if any.
 *
 * Returns:
 *   EINVAL if hdl is NULL;
 *   0      if a DRR_END record was reached and hdl->err is 0 afterwards;
 *   -1     otherwise: a handler or callback reported an error (left in
 *          hdl->err), or a header could not be read before DRR_END.
 */
int
zst_traverse(zst_handle_t *hdl)
{
    dmu_replay_record_t *drr = NULL;
    /* hdl->cksum as it stood before the current record header was read */
    zio_cksum_t pcksum = {{0}};

    if (hdl == NULL)
    {
	    LOG_ERR( "invalid handle");
        return (EINVAL);
    }

    /*
     * Go through the records, invoking the registered callbacks.
     *
     * Each case copies its union member out of hdl->buf before calling its
     * handler. Handlers such as get_object_section() read further sections
     * into the same buffer, which invalidates drr.
     *
     * The loop stops once hdl->err is non-zero or a header cannot be read
     * (drr == NULL).
     *
     * NOTE(review): only record headers (plus DRR_OBJECT bonus data) are
     * read here. Any payload following e.g. DRR_WRITE/DRR_SPILL/
     * DRR_WRITE_EMBEDDED headers is presumably consumed by the callbacks
     * through their arg; if so, those bytes are not part of hdl->cksum.
     * Confirm.
     */
    while ((hdl->err == 0) &&
           ((drr = get_next_section(hdl, sizeof (*drr))) != NULL)) {

        LOG_INFO( "type=%d",drr->drr_type);
	
        switch (drr->drr_type) {
        case DRR_BEGIN:
        {
            struct drr_begin drrb = drr->drr_u.drr_begin;
            LOG_INFO( "DRR_BEGIN");
            hdl->err = get_begin_section(hdl, &drrb);
            break;
        }
        case DRR_END:
        {
            struct drr_end drre = drr->drr_u.drr_end;
            /*
             * We compare against the *previous* checksum
             * value, because the stored checksum is of
             * everything before the DRR_END record.
             *
             * NOTE(review): when a DRR_END callback is registered, its
             * return value overwrites ECKSUM below. A checksum mismatch is
             * therefore only reported when no DRR_END callback exists.
             * Confirm this is intended.
             */
            LOG_INFO( "DRR_END");
            if (!ZIO_CHECKSUM_EQUAL(drre.drr_checksum, pcksum))
                hdl->err = ECKSUM;
            if (hdl->zc_array[DRR_END].cb) {
                zst_callback_descr_t *d = &hdl->zc_array[DRR_END];
                hdl->err = (*d->cb)(DRR_END, &drre, d->arg);
            }
            /* DRR_END terminates the traversal; the break is unreachable */
            goto done;
            break;
        }
        case DRR_OBJECT:
        {
            struct drr_object drro = drr->drr_u.drr_object;
            LOG_INFO( "DRR_OBJECT");
            hdl->err = get_object_section(hdl, &drro);
            break;
        }
        case DRR_FREEOBJECTS:
        {
            struct drr_freeobjects drrfo =
                drr->drr_u.drr_freeobjects;
            LOG_INFO( "DRR_FREEOBJECT");
            hdl->err = get_freeobjects_section(hdl, &drrfo);
            break;
        }
        case DRR_WRITE:
        {
            struct drr_write drrw = drr->drr_u.drr_write;
            LOG_INFO( "DRR_WRITE");
            hdl->err = get_write_section(hdl, &drrw);
            break;
        }
        case DRR_WRITE_EMBEDDED:
        {
            struct drr_write_embedded drrwe =
                drr->drr_u.drr_write_embedded;
            LOG_INFO( "DRR_EMBEDDED");
            hdl->err = get_write_embedded_section(hdl, &drrwe);
            break;
        }
        case DRR_WRITE_BYREF:
        {
            struct drr_write_byref drrwbr =
                drr->drr_u.drr_write_byref;
            LOG_INFO( "DRR_BYREF");
            hdl->err = get_write_byref_section(hdl, &drrwbr);
            break;
        }
        case DRR_FREE:
        {
            struct drr_free drrf = drr->drr_u.drr_free;
            LOG_INFO( "DRR_FREE");
            hdl->err = get_free_section(hdl, &drrf);
            break;
        }
        case DRR_SPILL:
        {
            struct drr_spill drrs = drr->drr_u.drr_spill;
            LOG_INFO( "DRR_SPILL");
            hdl->err = get_spill_section(hdl, &drrs);
            break;
        }
        default:
            /*
             * Unknown record types are skipped without error (the stricter
             * EINVAL is commented out). NOTE(review): no payload is consumed
             * for such records, so if one carries data, the next header read
             * starts inside it. Confirm.
             */
            LOG_INFO( "default");
         //   hdl->err = EINVAL;
              hdl->err = 0;
        }

        /* checksum through this record, for the next DRR_END comparison */
        pcksum = hdl->cksum;
    }

done:
    /*
     * Checked for well-formed stream: drr == NULL means no DRR_END was
     * reached before a header read failed.
     */
    if (drr == NULL || hdl->err) {
        LOG_ERR( "Incomplete/invalid stream format, status %d (%d)", hdl->err, (drr) ? drr->drr_type : 9999999);
        return (-1);
    }

    /* Done */
    return (0);
}
