#!/usr/bin/env python3
"""Converts a NAND+OOB dump into a flashable image.

To extract the NAND+OOB image:
  $ fastboot oem stage-partition-oob <partition>
  $ fastboot get_staged <file>

To flash it back after converting:
  $ fastboot flash <partition> <converted_file>

In general, we can't exactly replicate the source NAND+OOB state since each
device has its own bad block map, so trying to replay a dump from one device
onto a different device may not work as expected.
"""

import argparse
import logging
import struct
import sys
from typing import List

# Geometry of one NAND page: 4096 data bytes followed by 8 out-of-band bytes.
NAND_PAGE_SIZE = 4096
NAND_OOB_SIZE = 8
NANDOOB_PAGE_SIZE = NAND_PAGE_SIZE + NAND_OOB_SIZE

# Size of the Estelle FVM partition in pages, and the erase-block size.
ESTELLE_NUM_FVM_PAGES = 0x1AD00
ESTELLE_PAGES_PER_BLOCK = 64

# Sentinel pages that nandoob_to_pages() recognizes in a dump. Each one is a
# full NAND+OOB page: an 8-byte ASCII tag padded with NULs (or all 0xFF bytes
# for an erased page).
BADBLOCK_PAGE = b"BADBLOCK".ljust(NANDOOB_PAGE_SIZE, b"\x00")
READFAIL_PAGE = b"READFAIL".ljust(NANDOOB_PAGE_SIZE, b"\x00")
EMPTY_PAGE = b"\xff" * NANDOOB_PAGE_SIZE

# Field values written into the FVM FTL image header.
FVM_FTL_IMAGE_MAGIC = 0x12A17178711A711D
FVM_FTL_IMAGE_MAJOR, FVM_FTL_IMAGE_MINOR = 1, 1
FVM_FTL_IMAGE_FLAG_REQUIRE_WIPE = 1
FVM_FTL_IMAGE_FORMAT_RAW = 0

# Values expected in the leading ZBI container header. The container type is
# the ASCII bytes b"BOOT" read as a little-endian uint32.
ZBI_TYPE_CONTAINER = 0x544F4F42
ZBI_CONTAINER_MAGIC = 0x868CF7E6
ZBI_ITEM_MAGIC = 0xB5781729


def nandoob_to_pages(nandoob: bytes) -> List[bytes]:
    """Splits a NAND+OOB dump into pages.

    Args:
        nandoob: NAND+OOB dump from `oem stage-partition-oob`. Its length must
            be a whole multiple of NANDOOB_PAGE_SIZE (4096 data + 8 OOB bytes).

    Returns:
        A list of 4096+8 byte pages, does not include any BADBLOCK pages or
        trailing empty pages. READFAIL pages and empty pages that come before
        a data page are emitted as EMPTY_PAGE.

    Raises:
        ValueError: If the dump length is not a multiple of NANDOOB_PAGE_SIZE.
    """
    # struct.iter_unpack would otherwise fail with an opaque struct.error on a
    # partial trailing page; report the problem in terms of the dump instead.
    if len(nandoob) % NANDOOB_PAGE_SIZE:
        raise ValueError(
            f"NAND+OOB dump is {len(nandoob)} bytes, which is not a multiple "
            f"of the {NANDOOB_PAGE_SIZE}-byte NAND+OOB page size")

    pages = []
    # Number of empty/READFAIL pages seen since the last data page. They are
    # only written out once a later data page proves they are not trailing.
    empty_count = 0
    readfails = []
    page_iter = struct.iter_unpack(f"{NANDOOB_PAGE_SIZE}s", nandoob)
    for index, (page,) in enumerate(page_iter):
        if page == BADBLOCK_PAGE:
            # I *think* the NDM driver should transparently swap out bad blocks
            # on each device so that the FTL doesn't have to know or care about
            # them, which should mean that we can just ignore them on both the
            # original dump and the target device.
            logging.debug("Skipping BADBLOCK page %s", index)
        elif page == READFAIL_PAGE:
            # The best we can do with read errors is leave it empty and hope it
            # wasn't mapped by the FTL. We'll print out a warning at the end.
            readfails.append(index)
            empty_count += 1
        elif page == EMPTY_PAGE:
            # No point writing trailing empty pages, just buffer it for now.
            empty_count += 1
        else:
            # Flush any pending empty pages now that we've encountered data.
            pages.extend([EMPTY_PAGE] * empty_count)
            pages.append(page)
            empty_count = 0

    if readfails:
        logging.warning("Found READFAIL page(s): %s", readfails)
        logging.warning("Leaving contents at 0xFF on-device")
        logging.warning("If this page was in use, expect errors")

    return pages


def pages_to_fvm_image(pages: List[bytes]) -> bytes:
    """Converts a list of NAND+OOB pages into a fastboot flashable FVM image.

    Args:
        pages: NAND+OOB pages of NANDOOB_PAGE_SIZE bytes each, as returned by
            nandoob_to_pages().

    Returns:
        A 32-byte little-endian header followed by the pages concatenated
        unchanged (data and OOB bytes together).
    """
    # Log how much of the partition is left over; that surplus is what the
    # target device can afford to lose to bad blocks.
    unused_pages = ESTELLE_NUM_FVM_PAGES - len(pages)
    if unused_pages < 0:
        # Without this branch a negative count would be logged as
        # "at most -N bad blocks"; call out the overflow explicitly instead.
        logging.warning(
            "Image has %s pages, %s more than the %s pages in the FVM "
            "partition", len(pages), -unused_pages, ESTELLE_NUM_FVM_PAGES)
    else:
        logging.info("Unused pages: %s", unused_pages)
        logging.info("Target device can have at most %s bad blocks",
                     unused_pages // ESTELLE_PAGES_PER_BLOCK)

    # Header format (little-endian, 32 bytes total):
    #  uint64_t magic
    #  uint32_t version_major
    #  uint32_t version_minor
    #  uint32_t flags
    #  uint32_t format
    #  uint32_t page_size
    #  uint8_t oob_size
    #  uint8_t reserved[3] (0xFF)
    header = struct.pack("<QIIIII4B", FVM_FTL_IMAGE_MAGIC, FVM_FTL_IMAGE_MAJOR,
                         FVM_FTL_IMAGE_MINOR, FVM_FTL_IMAGE_FLAG_REQUIRE_WIPE,
                         FVM_FTL_IMAGE_FORMAT_RAW, NAND_PAGE_SIZE,
                         NAND_OOB_SIZE, 0xFF, 0xFF, 0xFF)
    return header + b"".join(pages)


def pages_to_skip_block_image(pages: List[bytes]) -> bytes:
    """Converts a list of NAND+OOB pages into a skip-block image.

    Only the first NAND_PAGE_SIZE data bytes of every page are kept; the OOB
    tail is discarded. Skip-block images apply no wear-leveling and carry no
    bad-block information, so bad blocks are simply skipped.
    """
    data_portions = (page[:NAND_PAGE_SIZE] for page in pages)
    return b"".join(data_portions)


def pages_to_zbi_image(pages: List[bytes]) -> bytes:
    """Converts NAND+OOB pages into a ZBI image.

    A ZBI image is just a skip-block image truncated to the ZBI size, which is
    read from the ZBI container header at the start of the first page.

    Args:
        pages: NAND+OOB pages, as returned by nandoob_to_pages().

    Returns:
        The skip-block data truncated to the container header plus the
        payload `length` declared in that header.

    Raises:
        ValueError: If there are no pages, or the first page does not start
            with a ZBI container header.
    """
    # ZBI header format:
    #   uint32_t type
    #   uint32_t length
    #   uint32_t extra
    #   uint32_t flags
    #   uint32_t reserved0
    #   uint32_t reserved1
    #   uint32_t magic
    #   uint32_t crc32
    header_format = "<IIIIIIII"
    header_size = struct.calcsize(header_format)
    if not pages:
        # Previously this surfaced as a bare IndexError on pages[0].
        raise ValueError("Image does not look like a ZBI: no data pages")
    (item_type, length, extra, _, _, _, magic,
     _) = struct.unpack_from(header_format, pages[0])
    if not (item_type == ZBI_TYPE_CONTAINER and extra == ZBI_CONTAINER_MAGIC
            and magic == ZBI_ITEM_MAGIC):
        raise ValueError("Image does not look like a ZBI")

    # Return just the ZBI according to the size indicated in the header.
    zbi_size = header_size + length
    entire_partition = pages_to_skip_block_image(pages)
    if len(entire_partition) < zbi_size:
        # Slicing would silently return less than the header promises.
        logging.warning(
            "ZBI header declares %s bytes but only %s bytes of data are "
            "available; output will be truncated", zbi_size,
            len(entire_partition))
    return entire_partition[:zbi_size]


def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument("type",
                        choices=["fvm", "skip_block", "zbi"],
                        help="Image type. 'fvm' creates a fastboot-flashable"
                        " FVM image. 'skip_block' strips out OOB bytes and bad"
                        " blocks and returns just the data. 'zbi' does"
                        " skip_block processing but also truncates to just the"
                        " ZBI image.")
    parser.add_argument("source", help="Raw stage-partition-oob data")
    parser.add_argument(
        "dest",
        nargs="?",
        help="Output file, <source>.nandoob_to_flashable by default")
    parser.add_argument("-f",
                        "--force",
                        action="store_true",
                        help="Overwrite dest if it exists")

    args = parser.parse_args()

    # Set the default dest if necessary.
    if not args.dest:
        args.dest = f"{args.source}.nandoob_to_{args.type}"

    return args


def _main() -> int:
    """Command-line entry point: converts args.source and writes args.dest.

    Returns:
        The process exit status: 0 on success, 1 if the destination already
        exists and --force was not given.
    """
    logging.basicConfig(level=logging.INFO)
    args = _parse_args()

    with open(args.source, "rb") as file:
        nandoob = file.read()

    # The argparse `choices` for `type` are exactly these keys.
    converters = {
        "fvm": pages_to_fvm_image,
        "skip_block": pages_to_skip_block_image,
        "zbi": pages_to_zbi_image,
    }
    image = converters[args.type](nandoob_to_pages(nandoob))

    # Mode "xb" refuses to clobber an existing file unless --force was given;
    # report that as a clear error instead of a FileExistsError traceback.
    try:
        with open(args.dest, "wb" if args.force else "xb") as file:
            file.write(image)
    except FileExistsError:
        logging.error("%s already exists; pass -f/--force to overwrite it",
                      args.dest)
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(_main())
