# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""

import io
import json
import struct
import typing as tp

# The format starts with the `ECDC` magic code, followed by a uint8 giving
# the protocol version (0), then the size of the JSON header as a uint32
# (all packed in network byte order).
# The header is then provided as JSON and should contain all the
# information required for decoding. A raw stream of bytes follows
# and should be interpretable using the JSON header.
# Fixed-size prefix layout: 4-byte magic, uint8 version, uint32 length of the
# JSON metadata. The '!' prefix selects network byte order with no padding.
_encodec_header_struct = struct.Struct('!4sBI')
# Magic bytes expected at the very start of an ECDC stream.
_ENCODEC_MAGIC = b'ECDC'


def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Write the ECDC prefix and the JSON-encoded metadata to `fo`.

    Args:
        fo (IO[bytes]): binary file-object receiving the header.
        metadata (Any): JSON-serializable object stored as the header payload.
    """
    payload = json.dumps(metadata).encode('utf-8')
    # Protocol version is always 0 for now; the prefix records the payload size
    # so that the reader knows how many bytes of JSON to consume.
    prefix = _encodec_header_struct.pack(_ENCODEC_MAGIC, 0, len(payload))
    for part in (prefix, payload):
        fo.write(part)
    fo.flush()


def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
    """Read exactly `size` bytes from `fo`, looping over short reads.

    Args:
        fo (IO[bytes]): binary file-object to read from.
        size (int): number of bytes to read. A value <= 0 yields `b""`.

    Returns:
        bytes: the `size` bytes read from the stream.

    Raises:
        EOFError: if the stream ends before `size` bytes could be read.
    """
    chunks: tp.List[bytes] = []
    # Track the outstanding byte count separately from what was accumulated:
    # comparing the accumulated length against a shrinking `size` would stop
    # early whenever the stream returns fewer bytes than requested.
    remaining = size
    while remaining > 0:
        chunk = fo.read(remaining)
        if not chunk:
            raise EOFError("Impossible to read enough data from the stream, "
                           f"{remaining} bytes remaining.")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def read_ecdc_header(fo: tp.IO[bytes]):
    """Read and validate the ECDC prefix from `fo`, then return the decoded JSON metadata.

    Args:
        fo (IO[bytes]): binary file-object positioned at the start of the stream.

    Returns:
        The object obtained by parsing the JSON header.

    Raises:
        ValueError: if the magic code does not match or the version is not 0.
    """
    prefix = _read_exactly(fo, _encodec_header_struct.size)
    magic, version, meta_size = _encodec_header_struct.unpack(prefix)
    if not magic == _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if not version == 0:
        raise ValueError("Version not supported.")
    # The prefix told us how many bytes of JSON follow.
    raw_meta = _read_exactly(fo, meta_size)
    meta_text = raw_meta.decode('utf-8')
    return json.loads(meta_text)


class BitPacker:
    """Pack fixed-width integers (e.g. 10 bits each) into a byte stream.

    Values are appended above the bits still pending, and the stream is
    emitted one byte at a time starting from the lowest 8 bits.
    Note that for some bandwidth (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        # Pending bits not yet written, and how many of them are meaningful.
        self._current_value = 0
        self._current_bits = 0

    def push(self, value: int):
        """Append `value` to the stream and write out every complete
        byte that becomes available to the underlying file-object."""
        self._current_value = self._current_value + (value << self._current_bits)
        self._current_bits = self._current_bits + self.bits
        while not self._current_bits < 8:
            octet = self._current_value & 0xff
            self._current_value = self._current_value >> 8
            self._current_bits = self._current_bits - 8
            self.fo.write(bytes((octet,)))

    def flush(self):
        """Write the trailing partial byte, if any, then flush the
        file-object. Call this once all values have been pushed."""
        if self._current_bits != 0:
            leftover = bytes((self._current_value,))
            self.fo.write(leftover)
            self._current_bits = 0
            self._current_value = 0
        self.fo.flush()


class BitUnpacker:
    """Inverse of `BitPacker`: decode fixed-width integers from a byte stream.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to read the bytes from.
    """
    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.fo = fo
        self.bits = bits
        # Bits read from the stream but not yet returned, and their count.
        self._current_bits = 0
        self._current_value = 0
        # Selects the lowest `bits` bits of the pending value.
        self._mask = (1 << bits) - 1

    def pull(self) -> tp.Optional[int]:
        """
        Decode the next value, reading bytes one at a time from the
        underlying file-object until enough bits are available.
        Returns `None` when reaching the end of the stream.
        """
        while not self._current_bits >= self.bits:
            chunk = self.fo.read(1)
            if not chunk:
                return None
            # Each new byte sits above the bits that are still pending.
            self._current_value += chunk[0] << self._current_bits
            self._current_bits += 8

        value = self._current_value & self._mask
        self._current_bits -= self.bits
        self._current_value >>= self.bits
        return value


def test():
    """Round-trip random tokens through `BitPacker` and `BitUnpacker`.

    For a few random (length, bit width) pairs, checks that every pushed
    token is recovered in order, allowing for a bounded number of extra
    trailing values produced by the final partial byte.
    """
    import torch
    torch.manual_seed(1234)
    for _ in range(4):
        n_tokens: int = torch.randint(10, 2_000, (1,)).item()
        width: int = torch.randint(1, 16, (1,)).item()
        expected: tp.List[int] = torch.randint(2 ** width, (n_tokens,)).tolist()

        stream = io.BytesIO()
        writer = BitPacker(width, stream)
        for tok in expected:
            writer.push(tok)
        writer.flush()

        stream.seek(0)
        reader = BitUnpacker(width, stream)
        # Keep pulling until the unpacker signals the end of the stream with None.
        decoded: tp.List[int] = list(iter(reader.pull, None))

        assert len(decoded) >= len(expected), (len(decoded), len(expected))
        # The flushing mechanism might lead to "ghost" values at the end of the stream.
        assert len(decoded) <= len(expected) + 8 // width, (len(decoded), len(expected), width)
        for idx, (a, b) in enumerate(zip(expected, decoded)):
            assert a == b, (idx, a, b)


# Run the pack/unpack round-trip self-test when executed as a script.
if __name__ == '__main__':
    test()
