import math
import os
import struct
from dataclasses import dataclass, field
from struct import pack, unpack, unpack_from
from typing import Any, Optional, Union

from av import AudioFrame

from .rtcrtpparameters import RTCRtpParameters

# used for NACK and retransmission
RTP_HISTORY_SIZE = 128

# reserved to avoid confusion with RTCP
FORBIDDEN_PAYLOAD_TYPES = range(72, 77)
DYNAMIC_PAYLOAD_TYPES = range(96, 128)

RTP_HEADER_LENGTH = 12
RTCP_HEADER_LENGTH = 4

PACKETS_LOST_MIN = -(1 << 23)
PACKETS_LOST_MAX = (1 << 23) - 1

RTCP_SR = 200
RTCP_RR = 201
RTCP_SDES = 202
RTCP_BYE = 203
RTCP_RTPFB = 205
RTCP_PSFB = 206

RTCP_RTPFB_NACK = 1

RTCP_PSFB_PLI = 1
RTCP_PSFB_SLI = 2
RTCP_PSFB_RPSI = 3
RTCP_PSFB_APP = 15


@dataclass
class HeaderExtensions:
    abs_send_time: Optional[int] = None
    audio_level: Any = None
    mid: Any = None
    repaired_rtp_stream_id: Any = None
    rtp_stream_id: Any = None
    transmission_offset: Optional[int] = None
    transport_sequence_number: Optional[int] = None


class HeaderExtensionsMap:
    def __init__(self) -> None:
        self.__ids = HeaderExtensions()

    def configure(self, parameters: RTCRtpParameters) -> None:
        for ext in parameters.headerExtensions:
            if ext.uri == "urn:ietf:params:rtp-hdrext:sdes:mid":
                self.__ids.mid = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id":
                self.__ids.repaired_rtp_stream_id = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id":
                self.__ids.rtp_stream_id = ext.id
            elif (
                ext.uri == "http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time"
            ):
                self.__ids.abs_send_time = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:toffset":
                self.__ids.transmission_offset = ext.id
            elif ext.uri == "urn:ietf:params:rtp-hdrext:ssrc-audio-level":
                self.__ids.audio_level = ext.id
            elif (
                ext.uri
                == "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01"
            ):
                self.__ids.transport_sequence_number = ext.id

    def get(self, extension_profile: int, extension_value: bytes) -> HeaderExtensions:
        values = HeaderExtensions()
        for x_id, x_value in unpack_header_extensions(
            extension_profile, extension_value
        ):
            if x_id == self.__ids.mid:
                values.mid = x_value.decode("utf8")
            elif x_id == self.__ids.repaired_rtp_stream_id:
                values.repaired_rtp_stream_id = x_value.decode("ascii")
            elif x_id == self.__ids.rtp_stream_id:
                values.rtp_stream_id = x_value.decode("ascii")
            elif x_id == self.__ids.abs_send_time:
                values.abs_send_time = unpack("!L", b"\00" + x_value)[0]
            elif x_id == self.__ids.transmission_offset:
                values.transmission_offset = unpack("!l", x_value + b"\00")[0] >> 8
            elif x_id == self.__ids.audio_level:
                vad_level = unpack("!B", x_value)[0]
                values.audio_level = (vad_level & 0x80 == 0x80, vad_level & 0x7F)
            elif x_id == self.__ids.transport_sequence_number:
                values.transport_sequence_number = unpack("!H", x_value)[0]
        return values

    def set(self, values: HeaderExtensions) -> tuple[int, bytes]:
        extensions = []
        if values.mid is not None and self.__ids.mid:
            extensions.append((self.__ids.mid, values.mid.encode("utf8")))
        if (
            values.repaired_rtp_stream_id is not None
            and self.__ids.repaired_rtp_stream_id
        ):
            extensions.append(
                (
                    self.__ids.repaired_rtp_stream_id,
                    values.repaired_rtp_stream_id.encode("ascii"),
                )
            )
        if values.rtp_stream_id is not None and self.__ids.rtp_stream_id:
            extensions.append(
                (self.__ids.rtp_stream_id, values.rtp_stream_id.encode("ascii"))
            )
        if values.abs_send_time is not None and self.__ids.abs_send_time:
            extensions.append(
                (self.__ids.abs_send_time, pack("!L", values.abs_send_time)[1:])
            )
        if values.transmission_offset is not None and self.__ids.transmission_offset:
            extensions.append(
                (
                    self.__ids.transmission_offset,
                    pack("!l", values.transmission_offset << 8)[0:2],
                )
            )
        if values.audio_level is not None and self.__ids.audio_level:
            extensions.append(
                (
                    self.__ids.audio_level,
                    pack(
                        "!B",
                        (0x80 if values.audio_level[0] else 0)
                        | (values.audio_level[1] & 0x7F),
                    ),
                )
            )
        if (
            values.transport_sequence_number is not None
            and self.__ids.transport_sequence_number
        ):
            extensions.append(
                (
                    self.__ids.transport_sequence_number,
                    pack("!H", values.transport_sequence_number),
                )
            )
        return pack_header_extensions(extensions)


def clamp_packets_lost(count: int) -> int:
    return max(PACKETS_LOST_MIN, min(count, PACKETS_LOST_MAX))


def pack_packets_lost(count: int) -> bytes:
    return pack("!l", count)[1:]


def unpack_packets_lost(d: bytes) -> int:
    if d[0] & 0x80:
        d = b"\xff" + d
    else:
        d = b"\x00" + d
    return unpack("!l", d)[0]


def pack_rtcp_packet(packet_type: int, count: int, payload: bytes) -> bytes:
    assert len(payload) % 4 == 0
    return pack("!BBH", (2 << 6) | count, packet_type, len(payload) // 4) + payload


def pack_remb_fci(bitrate: int, ssrcs: list[int]) -> bytes:
    """
    Pack the FCI for a Receiver Estimated Maximum Bitrate report.

    https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03
    """
    data = b"REMB"
    exponent = 0
    mantissa = bitrate
    while mantissa > 0x3FFFF:
        mantissa >>= 1
        exponent += 1
    data += pack(
        "!BBH", len(ssrcs), (exponent << 2) | (mantissa >> 16), (mantissa & 0xFFFF)
    )
    for ssrc in ssrcs:
        data += pack("!L", ssrc)
    return data


def unpack_remb_fci(data: bytes) -> tuple[int, list[int]]:
    """
    Unpack the FCI for a Receiver Estimated Maximum Bitrate report.

    https://tools.ietf.org/html/draft-alvestrand-rmcat-remb-03
    """
    if len(data) < 8 or data[0:4] != b"REMB":
        raise ValueError("Invalid REMB prefix")

    exponent = (data[5] & 0xFC) >> 2
    mantissa = ((data[5] & 0x03) << 16) | (data[6] << 8) | data[7]
    bitrate = mantissa << exponent

    pos = 8
    ssrcs = []
    for r in range(data[4]):
        ssrcs.append(unpack_from("!L", data, pos)[0])
        pos += 4

    return (bitrate, ssrcs)


def is_rtcp(msg: bytes) -> bool:
    return len(msg) >= 2 and msg[1] >= 192 and msg[1] <= 208


def padl(length: int) -> int:
    """
    Return amount of padding needed for a 4-byte multiple.
    """
    return 4 * ((length + 3) // 4) - length


def unpack_header_extensions(
    extension_profile: int, extension_value: bytes
) -> list[tuple[int, bytes]]:
    """
    Parse header extensions according to RFC 5285.
    """
    extensions = []
    pos = 0

    if extension_profile == 0xBEDE:
        # One-Byte Header
        while pos < len(extension_value):
            # skip padding byte
            if extension_value[pos] == 0:
                pos += 1
                continue

            x_id = (extension_value[pos] & 0xF0) >> 4
            x_length = (extension_value[pos] & 0x0F) + 1
            pos += 1

            if len(extension_value) < pos + x_length:
                raise ValueError("RTP one-byte header extension value is truncated")
            x_value = extension_value[pos : pos + x_length]
            extensions.append((x_id, x_value))
            pos += x_length
    elif extension_profile == 0x1000:
        # Two-Byte Header
        while pos < len(extension_value):
            # skip padding byte
            if extension_value[pos] == 0:
                pos += 1
                continue

            if len(extension_value) < pos + 2:
                raise ValueError("RTP two-byte header extension is truncated")
            x_id, x_length = unpack_from("!BB", extension_value, pos)
            pos += 2

            if len(extension_value) < pos + x_length:
                raise ValueError("RTP two-byte header extension value is truncated")
            x_value = extension_value[pos : pos + x_length]
            extensions.append((x_id, x_value))
            pos += x_length

    return extensions


def pack_header_extensions(extensions: list[tuple[int, bytes]]) -> tuple[int, bytes]:
    """
    Serialize header extensions according to RFC 5285.
    """
    extension_profile = 0
    extension_value = b""

    if not extensions:
        return extension_profile, extension_value

    one_byte = True
    for x_id, x_value in extensions:
        x_length = len(x_value)
        assert x_id > 0 and x_id < 256
        assert x_length >= 0 and x_length < 256
        if x_id > 14 or x_length == 0 or x_length > 16:
            one_byte = False

    if one_byte:
        # One-Byte Header
        extension_profile = 0xBEDE
        extension_value = b""
        for x_id, x_value in extensions:
            x_length = len(x_value)
            extension_value += pack("!B", (x_id << 4) | (x_length - 1))
            extension_value += x_value
    else:
        # Two-Byte Header
        extension_profile = 0x1000
        extension_value = b""
        for x_id, x_value in extensions:
            x_length = len(x_value)
            extension_value += pack("!BB", x_id, x_length)
            extension_value += x_value

    extension_value += b"\x00" * padl(len(extension_value))
    return extension_profile, extension_value


def compute_audio_level_dbov(frame: AudioFrame) -> int:
    """
    Compute the energy level as spelled out in RFC 6465, Appendix A.
    """
    MAX_SAMPLE_VALUE = 32767
    MAX_AUDIO_LEVEL = 0
    MIN_AUDIO_LEVEL = -127
    rms = 0.0
    buf = bytes(frame.planes[0])
    s = struct.Struct("h")
    for unpacked in s.iter_unpack(buf):
        sample = unpacked[0]
        rms += sample * sample
    rms = math.sqrt(rms / (frame.samples * MAX_SAMPLE_VALUE * MAX_SAMPLE_VALUE))
    if rms > 0:
        db = 20 * math.log10(rms)
        db = max(db, MIN_AUDIO_LEVEL)
        db = min(db, MAX_AUDIO_LEVEL)
    else:
        db = MIN_AUDIO_LEVEL
    return round(db)


@dataclass
class RtcpReceiverInfo:
    ssrc: int
    fraction_lost: int
    packets_lost: int
    highest_sequence: int
    jitter: int
    lsr: int
    dlsr: int

    def __bytes__(self) -> bytes:
        data = pack("!LB", self.ssrc, self.fraction_lost)
        data += pack_packets_lost(self.packets_lost)
        data += pack("!LLLL", self.highest_sequence, self.jitter, self.lsr, self.dlsr)
        return data

    @classmethod
    def parse(cls, data: bytes) -> "RtcpReceiverInfo":
        ssrc, fraction_lost = unpack("!LB", data[0:5])
        packets_lost = unpack_packets_lost(data[5:8])
        highest_sequence, jitter, lsr, dlsr = unpack("!LLLL", data[8:])
        return cls(
            ssrc=ssrc,
            fraction_lost=fraction_lost,
            packets_lost=packets_lost,
            highest_sequence=highest_sequence,
            jitter=jitter,
            lsr=lsr,
            dlsr=dlsr,
        )


@dataclass
class RtcpSenderInfo:
    ntp_timestamp: int
    rtp_timestamp: int
    packet_count: int
    octet_count: int

    def __bytes__(self) -> bytes:
        return pack(
            "!QLLL",
            self.ntp_timestamp,
            self.rtp_timestamp,
            self.packet_count,
            self.octet_count,
        )

    @classmethod
    def parse(cls, data: bytes) -> "RtcpSenderInfo":
        ntp_timestamp, rtp_timestamp, packet_count, octet_count = unpack("!QLLL", data)
        return cls(
            ntp_timestamp=ntp_timestamp,
            rtp_timestamp=rtp_timestamp,
            packet_count=packet_count,
            octet_count=octet_count,
        )


@dataclass
class RtcpSourceInfo:
    ssrc: int
    items: list[tuple[Any, bytes]]


@dataclass
class RtcpByePacket:
    sources: list[int]

    def __bytes__(self) -> bytes:
        payload = b"".join([pack("!L", ssrc) for ssrc in self.sources])
        return pack_rtcp_packet(RTCP_BYE, len(self.sources), payload)

    @classmethod
    def parse(cls, data: bytes, count: int) -> "RtcpByePacket":
        if len(data) < 4 * count:
            raise ValueError("RTCP bye length is invalid")
        if count > 0:
            sources = list(unpack_from("!" + ("L" * count), data, 0))
        else:
            sources = []
        return cls(sources=sources)


@dataclass
class RtcpPsfbPacket:
    """
    Payload-Specific Feedback Message (RFC 4585).
    """

    fmt: int
    ssrc: int
    media_ssrc: int
    fci: bytes = b""

    def __bytes__(self) -> bytes:
        payload = pack("!LL", self.ssrc, self.media_ssrc) + self.fci
        return pack_rtcp_packet(RTCP_PSFB, self.fmt, payload)

    @classmethod
    def parse(cls, data: bytes, fmt: int) -> "RtcpPsfbPacket":
        if len(data) < 8:
            raise ValueError("RTCP payload-specific feedback length is invalid")

        ssrc, media_ssrc = unpack("!LL", data[0:8])
        fci = data[8:]
        return cls(fmt=fmt, ssrc=ssrc, media_ssrc=media_ssrc, fci=fci)


@dataclass
class RtcpRrPacket:
    ssrc: int
    reports: list[RtcpReceiverInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!L", self.ssrc)
        for report in self.reports:
            payload += bytes(report)
        return pack_rtcp_packet(RTCP_RR, len(self.reports), payload)

    @classmethod
    def parse(cls, data: bytes, count: int) -> "RtcpRrPacket":
        if len(data) != 4 + 24 * count:
            raise ValueError("RTCP receiver report length is invalid")

        ssrc = unpack("!L", data[0:4])[0]
        pos = 4
        reports = []
        for r in range(count):
            reports.append(RtcpReceiverInfo.parse(data[pos : pos + 24]))
            pos += 24
        return cls(ssrc=ssrc, reports=reports)


@dataclass
class RtcpRtpfbPacket:
    """
    Generic RTP Feedback Message (RFC 4585).
    """

    fmt: int
    ssrc: int
    media_ssrc: int

    # generick NACK
    lost: list[int] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!LL", self.ssrc, self.media_ssrc)
        if self.lost:
            pid = self.lost[0]
            blp = 0
            for p in self.lost[1:]:
                d = p - pid - 1
                if d < 16:
                    blp |= 1 << d
                else:
                    payload += pack("!HH", pid, blp)
                    pid = p
                    blp = 0
            payload += pack("!HH", pid, blp)
        return pack_rtcp_packet(RTCP_RTPFB, self.fmt, payload)

    @classmethod
    def parse(cls, data: bytes, fmt: int) -> "RtcpRtpfbPacket":
        if len(data) < 8 or len(data) % 4:
            raise ValueError("RTCP RTP feedback length is invalid")

        ssrc, media_ssrc = unpack("!LL", data[0:8])
        lost = []
        for pos in range(8, len(data), 4):
            pid, blp = unpack("!HH", data[pos : pos + 4])
            lost.append(pid)
            for d in range(0, 16):
                if (blp >> d) & 1:
                    lost.append(pid + d + 1)
        return cls(fmt=fmt, ssrc=ssrc, media_ssrc=media_ssrc, lost=lost)


@dataclass
class RtcpSdesPacket:
    chunks: list[RtcpSourceInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = b""
        for chunk in self.chunks:
            payload += pack("!L", chunk.ssrc)
            for d_type, d_value in chunk.items:
                payload += pack("!BB", d_type, len(d_value)) + d_value
            payload += b"\x00\x00"
        while len(payload) % 4:
            payload += b"\x00"
        return pack_rtcp_packet(RTCP_SDES, len(self.chunks), payload)

    @classmethod
    def parse(cls, data: bytes, count: int) -> "RtcpSdesPacket":
        pos = 0
        chunks = []
        for r in range(count):
            if len(data) < pos + 4:
                raise ValueError("RTCP SDES source is truncated")
            ssrc = unpack_from("!L", data, pos)[0]
            pos += 4

            items = []
            while pos < len(data) - 1:
                d_type, d_length = unpack_from("!BB", data, pos)
                pos += 2

                if len(data) < pos + d_length:
                    raise ValueError("RTCP SDES item is truncated")
                d_value = data[pos : pos + d_length]
                pos += d_length
                if d_type == 0:
                    break
                else:
                    items.append((d_type, d_value))
            chunks.append(RtcpSourceInfo(ssrc=ssrc, items=items))
        return cls(chunks=chunks)


@dataclass
class RtcpSrPacket:
    ssrc: int
    sender_info: RtcpSenderInfo
    reports: list[RtcpReceiverInfo] = field(default_factory=list)

    def __bytes__(self) -> bytes:
        payload = pack("!L", self.ssrc)
        payload += bytes(self.sender_info)
        for report in self.reports:
            payload += bytes(report)
        return pack_rtcp_packet(RTCP_SR, len(self.reports), payload)

    @classmethod
    def parse(cls, data: bytes, count: int) -> "RtcpSrPacket":
        if len(data) != 24 + 24 * count:
            raise ValueError("RTCP sender report length is invalid")

        ssrc = unpack_from("!L", data)[0]
        sender_info = RtcpSenderInfo.parse(data[4:24])
        pos = 24
        reports = []
        for r in range(count):
            reports.append(RtcpReceiverInfo.parse(data[pos : pos + 24]))
            pos += 24
        return RtcpSrPacket(ssrc=ssrc, sender_info=sender_info, reports=reports)


AnyRtcpPacket = Union[
    RtcpByePacket,
    RtcpPsfbPacket,
    RtcpRrPacket,
    RtcpRtpfbPacket,
    RtcpSdesPacket,
    RtcpSrPacket,
]


class RtcpPacket:
    @classmethod
    def parse(cls, data: bytes) -> list[AnyRtcpPacket]:
        pos = 0
        packets: list[AnyRtcpPacket] = []

        while pos < len(data):
            if len(data) < pos + RTCP_HEADER_LENGTH:
                raise ValueError(
                    f"RTCP packet length is less than {RTCP_HEADER_LENGTH} bytes"
                )

            v_p_count, packet_type, length = unpack("!BBH", data[pos : pos + 4])
            version = v_p_count >> 6
            padding = (v_p_count >> 5) & 1
            count = v_p_count & 0x1F
            if version != 2:
                raise ValueError("RTCP packet has invalid version")
            pos += 4

            end = pos + length * 4
            if len(data) < end:
                raise ValueError("RTCP packet is truncated")
            payload = data[pos:end]
            pos = end

            if padding:
                if not payload or not payload[-1] or payload[-1] > len(payload):
                    raise ValueError("RTCP packet padding length is invalid")
                payload = payload[0 : -payload[-1]]

            if packet_type == RTCP_BYE:
                packets.append(RtcpByePacket.parse(payload, count))
            elif packet_type == RTCP_SDES:
                packets.append(RtcpSdesPacket.parse(payload, count))
            elif packet_type == RTCP_SR:
                packets.append(RtcpSrPacket.parse(payload, count))
            elif packet_type == RTCP_RR:
                packets.append(RtcpRrPacket.parse(payload, count))
            elif packet_type == RTCP_RTPFB:
                packets.append(RtcpRtpfbPacket.parse(payload, count))
            elif packet_type == RTCP_PSFB:
                packets.append(RtcpPsfbPacket.parse(payload, count))

        return packets


class RtpPacket:
    def __init__(
        self,
        payload_type: int = 0,
        marker: int = 0,
        sequence_number: int = 0,
        timestamp: int = 0,
        ssrc: int = 0,
        payload: bytes = b"",
    ) -> None:
        self.version = 2
        self.marker = marker
        self.payload_type = payload_type
        self.sequence_number = sequence_number
        self.timestamp = timestamp
        self.ssrc = ssrc
        self.csrc: list[int] = []
        self.extensions = HeaderExtensions()
        self.payload = payload
        self.padding_size = 0

    def __repr__(self) -> str:
        return (
            f"RtpPacket(seq={self.sequence_number}, ts={self.timestamp}, "
            f"marker={self.marker}, payload={self.payload_type}, "
            f"{len(self.payload)} bytes)"
        )

    @classmethod
    def parse(
        cls, data: bytes, extensions_map: HeaderExtensionsMap = HeaderExtensionsMap()
    ) -> "RtpPacket":
        if len(data) < RTP_HEADER_LENGTH:
            raise ValueError(
                f"RTP packet length is less than {RTP_HEADER_LENGTH} bytes"
            )

        v_p_x_cc, m_pt, sequence_number, timestamp, ssrc = unpack("!BBHLL", data[0:12])
        version = v_p_x_cc >> 6
        padding = (v_p_x_cc >> 5) & 1
        extension = (v_p_x_cc >> 4) & 1
        cc = v_p_x_cc & 0x0F
        if version != 2:
            raise ValueError("RTP packet has invalid version")
        if len(data) < RTP_HEADER_LENGTH + 4 * cc:
            raise ValueError("RTP packet has truncated CSRC")

        packet = cls(
            marker=(m_pt >> 7),
            payload_type=(m_pt & 0x7F),
            sequence_number=sequence_number,
            timestamp=timestamp,
            ssrc=ssrc,
        )

        pos = RTP_HEADER_LENGTH
        for i in range(0, cc):
            packet.csrc.append(unpack_from("!L", data, pos)[0])
            pos += 4

        if extension:
            if len(data) < pos + 4:
                raise ValueError("RTP packet has truncated extension profile / length")
            extension_profile, extension_length = unpack_from("!HH", data, pos)
            extension_length *= 4
            pos += 4

            if len(data) < pos + extension_length:
                raise ValueError("RTP packet has truncated extension value")
            extension_value = data[pos : pos + extension_length]
            pos += extension_length
            packet.extensions = extensions_map.get(extension_profile, extension_value)

        if padding:
            padding_len = data[-1]
            if not padding_len or padding_len > len(data) - pos:
                raise ValueError("RTP packet padding length is invalid")
            packet.padding_size = padding_len
            packet.payload = data[pos:-padding_len]
        else:
            packet.payload = data[pos:]

        return packet

    def serialize(
        self, extensions_map: HeaderExtensionsMap = HeaderExtensionsMap()
    ) -> bytes:
        extension_profile, extension_value = extensions_map.set(self.extensions)
        has_extension = bool(extension_value)

        padding = self.padding_size > 0
        data = pack(
            "!BBHLL",
            (self.version << 6)
            | (padding << 5)
            | (has_extension << 4)
            | len(self.csrc),
            (self.marker << 7) | self.payload_type,
            self.sequence_number,
            self.timestamp,
            self.ssrc,
        )
        for csrc in self.csrc:
            data += pack("!L", csrc)
        if has_extension:
            data += pack("!HH", extension_profile, len(extension_value) >> 2)
            data += extension_value
        data += self.payload
        if padding:
            data += os.urandom(self.padding_size - 1)
            data += bytes([self.padding_size])
        return data


def unwrap_rtx(rtx: RtpPacket, payload_type: int, ssrc: int) -> RtpPacket:
    """
    Recover initial packet from a retransmission packet.
    """
    packet = RtpPacket(
        payload_type=payload_type,
        marker=rtx.marker,
        sequence_number=unpack("!H", rtx.payload[0:2])[0],
        timestamp=rtx.timestamp,
        ssrc=ssrc,
        payload=rtx.payload[2:],
    )
    packet.csrc = rtx.csrc
    packet.extensions = rtx.extensions
    return packet


def wrap_rtx(
    packet: RtpPacket, payload_type: int, sequence_number: int, ssrc: int
) -> RtpPacket:
    """
    Create a retransmission packet from a lost packet.
    """
    rtx = RtpPacket(
        payload_type=payload_type,
        marker=packet.marker,
        sequence_number=sequence_number,
        timestamp=packet.timestamp,
        ssrc=ssrc,
        payload=pack("!H", packet.sequence_number) + packet.payload,
    )
    rtx.csrc = packet.csrc
    rtx.extensions = packet.extensions
    return rtx
