import math
from enum import Enum
from typing import Any, Optional

from aiortc.utils import uint32_add, uint32_gt

BURST_DELTA_THRESHOLD_MS = 5

# overuse detector
MAX_ADAPT_OFFSET_MS = 15
MIN_NUM_DELTAS = 60

# overuse estimator
DELTA_COUNTER_MAX = 1000
MIN_FRAME_PERIOD_HISTORY_LENGTH = 60

# abs-send-time estimator
INTER_ARRIVAL_SHIFT = 26
TIMESTAMP_GROUP_LENGTH_MS = 5
TIMESTAMP_TO_MS = 1000.0 / (1 << INTER_ARRIVAL_SHIFT)


class BandwidthUsage(Enum):
    NORMAL = 0
    UNDERUSING = 1
    OVERUSING = 2


class RateControlState(Enum):
    HOLD = 0
    INCREASE = 1
    DECREASE = 2


class AimdRateControl:
    def __init__(self) -> None:
        self.avg_max_bitrate_kbps: Optional[float] = None
        self.var_max_bitrate_kbps = 0.4
        self.current_bitrate = 30000000
        self.current_bitrate_initialized = False
        self.first_estimated_throughput_time: Optional[int] = None
        self.last_change_ms: Optional[int] = None
        self.near_max = False
        self.latest_estimated_throughput = 30000000
        self.rtt = 200
        self.state = RateControlState.HOLD

    def feedback_interval(self) -> int:
        return 500

    def set_estimate(self, bitrate: int, now_ms: int) -> None:
        """
        For testing purposes.
        """
        self.current_bitrate = self._clamp_bitrate(bitrate, bitrate)
        self.current_bitrate_initialized = True
        self.last_change_ms = now_ms

    def update(
        self,
        bandwidth_usage: BandwidthUsage,
        estimated_throughput: Optional[int],
        now_ms: int,
    ) -> Optional[int]:
        if not self.current_bitrate_initialized and estimated_throughput is not None:
            if self.first_estimated_throughput_time is None:
                self.first_estimated_throughput_time = now_ms
            elif now_ms - self.first_estimated_throughput_time > 3000:
                self.current_bitrate = estimated_throughput
                self.current_bitrate_initialized = True

        # wait for initialisation or overuse
        if (
            not self.current_bitrate_initialized
            and bandwidth_usage != BandwidthUsage.OVERUSING
        ):
            return None

        # update state
        if (
            bandwidth_usage == BandwidthUsage.NORMAL
            and self.state == RateControlState.HOLD
        ):
            self.last_change_ms = now_ms
            self.state = RateControlState.INCREASE
        elif bandwidth_usage == BandwidthUsage.OVERUSING:
            self.state = RateControlState.DECREASE
        elif bandwidth_usage == BandwidthUsage.UNDERUSING:
            self.state = RateControlState.HOLD

        # helper variables
        new_bitrate = self.current_bitrate
        if estimated_throughput is not None:
            self.latest_estimated_throughput = estimated_throughput
        else:
            estimated_throughput = self.latest_estimated_throughput
        estimated_throughput_kbps = estimated_throughput / 1000

        # update bitrate
        if self.state == RateControlState.INCREASE:
            # if the estimated throughput increases significantly,
            # clear estimated max throughput
            if self.avg_max_bitrate_kbps is not None:
                sigma_kbps = math.sqrt(
                    self.var_max_bitrate_kbps * self.avg_max_bitrate_kbps
                )
                if (
                    estimated_throughput_kbps
                    >= self.avg_max_bitrate_kbps + 3 * sigma_kbps
                ):
                    self.near_max = False
                    self.avg_max_bitrate_kbps = None

            # we use additive or multiplicative rate increase depending on whether
            # we are close to the maximum throughput
            if self.near_max:
                new_bitrate += self._additive_rate_increase(self.last_change_ms, now_ms)
            else:
                new_bitrate += self._multiplicative_rate_increase(
                    new_bitrate, self.last_change_ms, now_ms
                )
            self.last_change_ms = now_ms
        elif self.state == RateControlState.DECREASE:
            # if the estimated throughput drops significantly,
            # clear estimated max throughput
            if self.avg_max_bitrate_kbps is not None:
                sigma_kbps = math.sqrt(
                    self.var_max_bitrate_kbps * self.avg_max_bitrate_kbps
                )
                if (
                    estimated_throughput_kbps
                    < self.avg_max_bitrate_kbps - 3 * sigma_kbps
                ):
                    self.avg_max_bitrate_kbps = None
            self._update_max_throughput_estimate(estimated_throughput_kbps)

            self.near_max = True
            new_bitrate = round(0.85 * estimated_throughput)
            self.last_change_ms = now_ms
            self.state = RateControlState.HOLD

        self.current_bitrate = self._clamp_bitrate(new_bitrate, estimated_throughput)
        return self.current_bitrate

    def _additive_rate_increase(self, last_ms: int, now_ms: int) -> int:
        return int((now_ms - last_ms) * self._near_max_rate_increase() / 1000)

    def _clamp_bitrate(self, new_bitrate: int, estimated_throughput: int) -> int:
        max_bitrate = max(int(1.5 * estimated_throughput) + 10000, self.current_bitrate)
        return min(new_bitrate, max_bitrate)

    def _multiplicative_rate_increase(
        self, new_bitrate: int, last_ms: int, now_ms: int
    ) -> int:
        alpha = 1.08
        if last_ms is not None:
            elapsed_ms = min(now_ms - last_ms, 1000)
            alpha = pow(alpha, elapsed_ms / 1000)
        return int(max((alpha - 1) * new_bitrate, 1000))

    def _near_max_rate_increase(self) -> int:
        bits_per_frame = self.current_bitrate / 30
        packets_per_frame = math.ceil(bits_per_frame / (8 * 1200))
        avg_packet_size_bits = bits_per_frame / packets_per_frame

        response_time = self.rtt + 100
        return max(4000, int((avg_packet_size_bits * 1000) / response_time))

    def _update_max_throughput_estimate(self, estimated_throughput_kbps: float) -> None:
        alpha = 0.05
        if self.avg_max_bitrate_kbps is None:
            self.avg_max_bitrate_kbps = estimated_throughput_kbps
        else:
            self.avg_max_bitrate_kbps = (
                1 - alpha
            ) * self.avg_max_bitrate_kbps + alpha * estimated_throughput_kbps

        norm = max(1, self.avg_max_bitrate_kbps)
        self.var_max_bitrate_kbps = (1 - alpha) * self.var_max_bitrate_kbps + alpha * (
            (self.avg_max_bitrate_kbps - estimated_throughput_kbps) ** 2
        ) / norm
        self.var_max_bitrate_kbps = max(0.4, min(self.var_max_bitrate_kbps, 2.5))


class TimestampGroup:
    def __init__(self, timestamp: Optional[int] = None) -> None:
        self.arrival_time: Optional[int] = None
        self.first_timestamp = timestamp
        self.last_timestamp = timestamp
        self.size = 0


class InterArrivalDelta:
    def __init__(self, timestamp: int, arrival_time: int, size: int) -> None:
        self.timestamp = timestamp
        self.arrival_time = arrival_time
        self.size = size


class InterArrival:
    """
    Inter-arrival time and size filter.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self, group_length: int, timestamp_to_ms: float) -> None:
        self.group_length = group_length
        self.timestamp_to_ms = timestamp_to_ms
        self.current_group: Optional[TimestampGroup] = None
        self.previous_group: Optional[TimestampGroup] = None

    def compute_deltas(
        self, timestamp: int, arrival_time: int, packet_size: int
    ) -> Optional[InterArrivalDelta]:
        deltas = None
        if self.current_group is None:
            self.current_group = TimestampGroup(timestamp)
        elif self.packet_out_of_order(timestamp):
            return deltas
        elif self.new_timestamp_group(timestamp, arrival_time):
            if self.previous_group is not None:
                deltas = InterArrivalDelta(
                    timestamp=uint32_add(
                        self.current_group.last_timestamp,
                        -self.previous_group.last_timestamp,
                    ),
                    arrival_time=(
                        self.current_group.arrival_time
                        - self.previous_group.arrival_time
                    ),
                    size=self.current_group.size - self.previous_group.size,
                )

            # shift groups
            self.previous_group = self.current_group
            self.current_group = TimestampGroup(timestamp=timestamp)
        elif uint32_gt(timestamp, self.current_group.last_timestamp):
            self.current_group.last_timestamp = timestamp

        self.current_group.size += packet_size
        self.current_group.arrival_time = arrival_time

        return deltas

    def belongs_to_burst(self, timestamp: int, arrival_time: int) -> bool:
        timestamp_delta = uint32_add(timestamp, -self.current_group.last_timestamp)
        timestamp_delta_ms = round(self.timestamp_to_ms * timestamp_delta)
        arrival_time_delta = arrival_time - self.current_group.arrival_time
        return timestamp_delta_ms == 0 or (
            (arrival_time_delta - timestamp_delta_ms) < 0
            and arrival_time_delta <= BURST_DELTA_THRESHOLD_MS
        )

    def new_timestamp_group(self, timestamp: int, arrival_time: int) -> bool:
        if self.belongs_to_burst(timestamp, arrival_time):
            return False
        else:
            timestamp_delta = uint32_add(timestamp, -self.current_group.first_timestamp)
            return timestamp_delta > self.group_length

    def packet_out_of_order(self, timestamp: int) -> bool:
        timestamp_delta = uint32_add(timestamp, -self.current_group.first_timestamp)
        return timestamp_delta >= 0x80000000


class OveruseDetector:
    """
    Bandwidth overuse detector.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self) -> None:
        self.hypothesis = BandwidthUsage.NORMAL
        self.last_update_ms: Optional[int] = None
        self.k_up = 0.0087
        self.k_down = 0.039
        self.overuse_counter = 0
        self.overuse_time: Optional[float] = None
        self.overuse_time_threshold = 10
        self.previous_offset = 0.0
        self.threshold = 12.5

    def detect(
        self, offset: float, timestamp_delta_ms: float, num_of_deltas: int, now_ms: int
    ) -> BandwidthUsage:
        if num_of_deltas < 2:
            return BandwidthUsage.NORMAL

        T = min(num_of_deltas, MIN_NUM_DELTAS) * offset
        if T > self.threshold:
            if self.overuse_time is None:
                self.overuse_time = timestamp_delta_ms / 2
            else:
                self.overuse_time += timestamp_delta_ms
            self.overuse_counter += 1

            if (
                self.overuse_time > self.overuse_time_threshold
                and self.overuse_counter > 1
                and offset >= self.previous_offset
            ):
                self.overuse_counter = 0
                self.overuse_time = 0
                self.hypothesis = BandwidthUsage.OVERUSING
        elif T < -self.threshold:
            self.overuse_counter = 0
            self.overuse_time = None
            self.hypothesis = BandwidthUsage.UNDERUSING
        else:
            self.overuse_counter = 0
            self.overuse_time = None
            self.hypothesis = BandwidthUsage.NORMAL

        self.previous_offset = offset
        self.update_threshold(T, now_ms)
        return self.hypothesis

    def state(self) -> BandwidthUsage:
        return self.hypothesis

    def update_threshold(self, modified_offset: float, now_ms: int) -> None:
        if self.last_update_ms is None:
            self.last_update_ms = now_ms

        if abs(modified_offset) > self.threshold + MAX_ADAPT_OFFSET_MS:
            self.last_update_ms = now_ms
            return

        k = self.k_down if abs(modified_offset) < self.threshold else self.k_up
        time_delta_ms = min(now_ms - self.last_update_ms, 100)
        self.threshold += k * (abs(modified_offset) - self.threshold) * time_delta_ms
        self.threshold = max(6, min(self.threshold, 600))
        self.last_update_ms = now_ms


class OveruseEstimator:
    """
    Bandwidth overuse estimator.

    Adapted from the webrtc.org codebase.
    """

    def __init__(self) -> None:
        self.E = [[100.0, 0.0], [0.0, 0.1]]
        self._num_of_deltas = 0
        self._offset = 0.0
        self.previous_offset = 0.0
        self.slope = 1 / 64
        self.ts_delta_hist: list[float] = []

        self.avg_noise = 0.0
        self.var_noise = 50.0
        self.process_noise = [1e-13, 1e-3]

    def num_of_deltas(self) -> int:
        return self._num_of_deltas

    def offset(self) -> float:
        return self._offset

    def update(
        self,
        time_delta_ms: int,
        timestamp_delta_ms: float,
        size_delta: int,
        current_hypothesis: BandwidthUsage,
        now_ms: int,
    ) -> None:
        min_frame_period = self.update_min_frame_period(timestamp_delta_ms)
        t_ts_delta = time_delta_ms - timestamp_delta_ms
        fs_delta = size_delta

        self._num_of_deltas = min(self._num_of_deltas + 1, DELTA_COUNTER_MAX)

        # update  Kalman filter
        self.E[0][0] += self.process_noise[0]
        self.E[1][1] += self.process_noise[1]
        if (
            current_hypothesis == BandwidthUsage.OVERUSING
            and self._offset < self.previous_offset
        ) or (
            current_hypothesis == BandwidthUsage.UNDERUSING
            and self._offset > self.previous_offset
        ):
            self.E[1][1] += 10 * self.process_noise[1]

        h = [fs_delta, 1.0]
        Eh = [
            self.E[0][0] * h[0] + self.E[0][1] * h[1],
            self.E[1][0] * h[0] + self.E[1][1] * h[1],
        ]

        # update noise estimate
        residual = t_ts_delta - self.slope * h[0] - self._offset
        if current_hypothesis == BandwidthUsage.NORMAL:
            max_residual = 3.0 * math.sqrt(self.var_noise)
            if abs(residual) < max_residual:
                self.update_noise_estimate(residual, min_frame_period)
            else:
                self.update_noise_estimate(
                    -max_residual if residual < 0 else max_residual, min_frame_period
                )

        denom = self.var_noise + h[0] * Eh[0] + h[1] * Eh[1]
        K = [Eh[0] / denom, Eh[1] / denom]

        IKh = [[1.0 - K[0] * h[0], -K[0] * h[1]], [-K[1] * h[0], 1.0 - K[1] * h[1]]]
        e00 = self.E[0][0]
        e01 = self.E[0][1]

        # update state
        self.E[0][0] = e00 * IKh[0][0] + self.E[1][0] * IKh[0][1]
        self.E[0][1] = e01 * IKh[0][0] + self.E[1][1] * IKh[0][1]
        self.E[1][0] = e00 * IKh[1][0] + self.E[1][0] * IKh[1][1]
        self.E[1][1] = e01 * IKh[1][0] + self.E[1][1] * IKh[1][1]

        self.previous_offset = self._offset
        self.slope += K[0] * residual
        self._offset += K[1] * residual

    def update_min_frame_period(self, ts_delta: float) -> float:
        min_frame_period = ts_delta
        if len(self.ts_delta_hist) >= MIN_FRAME_PERIOD_HISTORY_LENGTH:
            self.ts_delta_hist.pop(0)

        for old_ts_delta in self.ts_delta_hist:
            min_frame_period = min(old_ts_delta, min_frame_period)

        self.ts_delta_hist.append(ts_delta)
        return min_frame_period

    def update_noise_estimate(self, residual: float, ts_delta: float) -> None:
        alpha = 0.01
        if self._num_of_deltas > 10 * 30:
            alpha = 0.002

        beta = pow(1 - alpha, ts_delta * 30.0 / 1000.0)
        self.avg_noise = beta * self.avg_noise + (1 - beta) * residual
        self.var_noise = (
            beta * self.var_noise + (1 - beta) * (self.avg_noise - residual) ** 2
        )

        if self.var_noise < 1:
            self.var_noise = 1


class RateBucket:
    def __init__(self, count: int = 0, value: int = 0) -> None:
        self.count = count
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return self.count == other.count and self.value == other.value


class RateCounter:
    """
    Rate counter, which stores the amount received in 1ms buckets.
    """

    def __init__(self, window_size: int, scale: int = 8000) -> None:
        self._origin_index = 0
        self._origin_ms: Optional[int] = None
        self._scale = scale
        self._window_size = window_size
        self.reset()

    def add(self, value: int, now_ms: int) -> None:
        if self._origin_ms is None:
            self._origin_ms = now_ms
        else:
            self._erase_old(now_ms)

        index = (self._origin_index + now_ms - self._origin_ms) % self._window_size
        self._buckets[index].count += 1
        self._buckets[index].value += value
        self._total.count += 1
        self._total.value += value

    def rate(self, now_ms: int) -> Optional[int]:
        if self._origin_ms is not None:
            self._erase_old(now_ms)
            active_window_size = now_ms - self._origin_ms + 1
            if self._total.count > 0 and active_window_size > 1:
                return round(self._scale * self._total.value / active_window_size)
        return None

    def reset(self) -> None:
        self._buckets = [RateBucket() for i in range(self._window_size)]
        self._origin_index = 0
        self._origin_ms = None
        self._total = RateBucket()

    def _erase_old(self, now_ms: int) -> None:
        new_origin_ms = now_ms - self._window_size + 1
        while self._origin_ms < new_origin_ms:
            bucket = self._buckets[self._origin_index]
            self._total.count -= bucket.count
            self._total.value -= bucket.value
            bucket.count = 0
            bucket.value = 0

            self._origin_index = (self._origin_index + 1) % self._window_size
            self._origin_ms += 1


class RemoteBitrateEstimator:
    def __init__(self) -> None:
        self.incoming_bitrate = RateCounter(1000, 8000)
        self.incoming_bitrate_initialized = True
        self.inter_arrival = InterArrival(
            (TIMESTAMP_GROUP_LENGTH_MS << INTER_ARRIVAL_SHIFT) // 1000, TIMESTAMP_TO_MS
        )
        self.estimator = OveruseEstimator()
        self.detector = OveruseDetector()
        self.rate_control = AimdRateControl()
        self.last_update_ms: Optional[int] = None
        self.ssrcs: dict[int, int] = {}

    def add(
        self, arrival_time_ms: int, abs_send_time: int, payload_size: int, ssrc: int
    ) -> Optional[tuple[int, list[int]]]:
        timestamp = abs_send_time << 8
        update_estimate = False

        # make note of SSRC
        self.ssrcs[ssrc] = arrival_time_ms

        # update incoming bitrate
        if self.incoming_bitrate.rate(arrival_time_ms) is not None:
            self.incoming_bitrate_initialized = True
        elif self.incoming_bitrate_initialized:
            self.incoming_bitrate.reset()
            self.incoming_bitrate_initialized = False
        self.incoming_bitrate.add(payload_size, arrival_time_ms)

        # calculate inter-arrival deltas
        deltas = self.inter_arrival.compute_deltas(
            timestamp, arrival_time_ms, payload_size
        )
        if deltas is not None:
            timestamp_delta_ms = deltas.timestamp * TIMESTAMP_TO_MS
            self.estimator.update(
                deltas.arrival_time,
                timestamp_delta_ms,
                deltas.size,
                self.detector.state(),
                arrival_time_ms,
            )
            self.detector.detect(
                self.estimator.offset(),
                timestamp_delta_ms,
                self.estimator.num_of_deltas(),
                arrival_time_ms,
            )

        if not update_estimate:
            if (
                self.last_update_ms is None
                or (arrival_time_ms - self.last_update_ms)
                > self.rate_control.feedback_interval()
            ):
                update_estimate = True
            elif self.detector.state() == BandwidthUsage.OVERUSING:
                update_estimate = True

        if update_estimate:
            target_bitrate = self.rate_control.update(
                self.detector.state(),
                self.incoming_bitrate.rate(arrival_time_ms),
                arrival_time_ms,
            )
            if target_bitrate is not None:
                self.last_update_ms = arrival_time_ms
                return target_bitrate, list(self.ssrcs.keys())

        return None
