# -*- coding: utf-8 -*-
#
# This file is part of Python-ASN1. Python-ASN1 is free software that is
# made available under the MIT license. Consult the file "LICENSE" that is
# distributed together with this file for the exact licensing terms.
#
# Python-ASN1 is copyright (c) 2007-2025 by the Python-ASN1 authors. See the
# file "AUTHORS" for a complete overview.

"""
This module provides ASN.1 encoder and decoder.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
import re
from builtins import bytes
from builtins import int
from builtins import range
from builtins import str
from contextlib import contextmanager
from enum import IntEnum

__version__ = "2.8.0"


class HexEnum(IntEnum):
    """
    ``IntEnum`` base class whose ``repr()`` shows the member value in
    hexadecimal (at least two digits) instead of decimal.
    """

    def __repr__(self):
        template = '<{cls}.{name}: 0x{value:02x}>'
        fields = {
            'cls': type(self).__name__,
            'name': self.name,
            'value': self.value,
        }
        return template.format(**fields)


class Numbers(HexEnum):
    """
    Tag numbers of the ASN.1 types this module knows about.

    These are the values expected for the ``nr`` argument of the encoder
    methods. The decoder converts a tag number to a member of this
    enumeration only when the tag's class is ``Classes.Universal``.
    """
    Boolean = 0x01
    Integer = 0x02
    BitString = 0x03
    OctetString = 0x04
    Null = 0x05
    ObjectIdentifier = 0x06
    Enumerated = 0x0a  # encoded and decoded exactly like Integer
    UTF8String = 0x0c
    Sequence = 0x10
    Set = 0x11
    PrintableString = 0x13
    IA5String = 0x16
    UTCTime = 0x17
    GeneralizedTime = 0x18
    UnicodeString = 0x1e  # encoded as UTF-8 by the encoder; returned as raw bytes by the decoder


class Types(HexEnum):
    """
    Encoding form of a tag: the single bit ``0x20`` of the first tag octet.
    """
    Constructed = 0x20  # bit set: the content is itself a series of encoded tags
    Primitive = 0x00  # bit clear: the content is the value itself


class Classes(HexEnum):
    """
    Tag class: the two most significant bits (mask ``0xc0``) of the first
    tag octet.

    Only values of class ``Universal`` are interpreted by the encoder and
    decoder; values of every other class are passed through as raw bytes.
    """
    Universal = 0x00
    Application = 0x40
    Context = 0x80
    Private = 0xc0


# Fields, in order:
#   nr  - the tag number (a ``Numbers`` member when the decoder recognises it
#         as a universal type, a plain int otherwise)
#   typ - the encoding form, see ``Types``
#   cls - the tag class, see ``Classes``
Tag = collections.namedtuple('Tag', 'nr typ cls')
"""
A named tuple to represent ASN.1 tags as returned by `Decoder.peek()` and
`Decoder.read()`.
"""


class Error(Exception):
    """
    ASN.1 encoding or decoding error.

    Raised by `Encoder` and `Decoder` for misuse of the API (for example
    calling a method before ``start()``) and for malformed input.
    """


class Encoder(object):
    """
    ASN.1 encoder. Uses DER encoding.

    Encoded fragments are collected in ``m_stack``, a list of lists of
    ``bytes``. The first entry holds the top-level output; every
    `Encoder.enter()` pushes a new entry that collects the content of the
    constructed type until the matching `Encoder.leave()` pops it.
    """

    def __init__(self):  # type: () -> None
        """
        Constructor. The encoder is unusable until `Encoder.start()` has
        been called.
        """
        self.m_stack = None

    def start(self):  # type: () -> None
        """
        This method instructs the encoder to start encoding a new ASN.1
        output. This method may be called at any time to reset the encoder,
        and resets the current output (if any).
        """
        self.m_stack = [[]]

    def enter(self, nr, cls=None):  # type: (int, int) -> None
        """
        This method starts the construction of a constructed type.

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`: If `Encoder.start()` has not been called.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if cls is None:
            cls = Classes.Universal
        # The tag goes to the enclosing level; the length is only known (and
        # emitted) in leave(), once the content has been collected.
        self._emit_tag(nr, Types.Constructed, cls)
        self.m_stack.append([])

    def leave(self):  # type: () -> None
        """
        This method completes the construction of a constructed type and
        writes the encoded representation to the output buffer.

        Raises:
            `Error`: If `Encoder.start()` has not been called, or if there
                is no constructed type to leave.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        value = b''.join(self.m_stack[-1])
        del self.m_stack[-1]
        self._emit_length(len(value))
        self._emit(value)

    @contextmanager
    def construct(self, nr, cls=None):  # type: (int, int) -> None
        """
        This method - context manager calls enter and leave methods,
        for better code mapping.

        Usage:
        ```
        with encoder.construct(asn1.Numbers.Sequence):
            encoder.write(1)
            with encoder.construct(asn1.Numbers.Sequence):
                encoder.write('foo')
                encoder.write('bar')
            encoder.write(2)
        ```
        encoder.output() will result following structure:
        SEQUENCE:
            INTEGER: 1
            SEQUENCE:
                STRING: foo
                STRING: bar
            INTEGER: 2

        Note:
            `Encoder.leave()` is only called when the ``with`` body finishes
            without raising. If the body raises, the constructed type is left
            open.

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`

        """
        self.enter(nr, cls)
        yield
        self.leave()

    def write(self, value, nr=None, typ=None, cls=None):  # type: (object, int, int, int) -> None
        """
        This method encodes one ASN.1 tag and writes it to the output buffer.

        Note:
            Normally, ``value`` will be the only parameter to this method.
            In this case Python-ASN1 will autodetect the correct ASN.1 type from
            the type of ``value``, and will output the encoded value based on this
            type.

        Args:
            value (any): The value of the ASN.1 tag to write. Python-ASN1 will
                try to autodetect the correct ASN.1 type from the type of
                ``value``.

            nr (int): If the desired ASN.1 type cannot be autodetected or is
                autodetected wrongly, the ``nr`` parameter can be provided to
                specify the ASN.1 type to be used. Use ``Numbers`` enumeration.

            typ (int): This optional parameter can be used to write constructed
                types to the output by setting it to indicate the constructed
                encoding type. In this case, ``value`` must already be valid ASN.1
                encoded data as plain Python bytes. This is not normally how
                constructed types should be encoded though, see `Encoder.enter()`
                and `Encoder.leave()` for the recommended way of doing this.
                Use ``Types`` enumeration.

            cls (int): This parameter can be used to override the class of the
                ``value``. The default class is the universal class.
                Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`: If `Encoder.start()` has not been called, if a
                non-universal class is used without ``nr``, or if ``nr`` is
                omitted and the type of ``value`` cannot be autodetected.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')

        if typ is None:
            typ = Types.Primitive
        if cls is None:
            cls = Classes.Universal

        if cls != Classes.Universal and nr is None:
            raise Error('Please specify a tag number (nr) when using classes Application, Context or Private')

        if nr is None:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(value, bool):
                nr = Numbers.Boolean
            elif isinstance(value, int):
                nr = Numbers.Integer
            elif isinstance(value, str):
                nr = Numbers.PrintableString
            elif isinstance(value, bytes):
                nr = Numbers.OctetString
            elif value is None:
                nr = Numbers.Null
            else:
                # Without this, nr stays None and the tag emission below
                # fails with an unrelated TypeError.
                raise Error('Cannot determine the ASN.1 type for a value of type {}. '
                            'Please specify a tag number (nr).'.format(type(value).__name__))

        value = self._encode_value(cls, nr, value)
        self._emit_tag(nr, typ, cls)
        self._emit_length(len(value))
        self._emit(value)

    def output(self):  # type: () -> bytes
        """
        This method returns the encoded ASN.1 data as plain Python ``bytes``.
        This method can be called multiple times, also during encoding.
        In the latter case the data that has been encoded so far is
        returned.

        Note:
            It is an error to call this method if the encoder is still
            constructing a constructed type, i.e. if `Encoder.enter()` has been
            called more times than `Encoder.leave()`.

        Returns:
            bytes: The DER encoded ASN.1 data.

        Raises:
            `Error`: If `Encoder.start()` has not been called or a
                constructed type is still open.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) != 1:
            raise Error('Stack is not empty.')
        output = b''.join(self.m_stack[0])
        return output

    def _emit_tag(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a tag, choosing the short or long form from the tag number."""
        if nr < 31:
            self._emit_tag_short(nr, typ, cls)
        else:
            self._emit_tag_long(nr, typ, cls)

    def _emit_tag_short(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a tag in the short, single-octet form (tag number < 31)."""
        assert nr < 31
        self._emit(bytes([nr | typ | cls]))

    def _emit_tag_long(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a tag in the long, multi-octet form (tag number >= 31)."""
        # The first octet carries class and type, with all five number bits set.
        head = bytes([typ | cls | 0x1f])
        self._emit(head)
        # The number follows in base 128, most significant group first; every
        # octet except the last has its high bit set.
        values = [(nr & 0x7f)]
        nr >>= 7
        while nr:
            values.append((nr & 0x7f) | 0x80)
            nr >>= 7
        values.reverse()
        for val in values:
            self._emit(bytes([val]))

    def _emit_length(self, length):  # type: (int) -> None
        """Emit length octets, choosing the short or long form."""
        if length < 128:
            self._emit_length_short(length)
        else:
            self._emit_length_long(length)

    def _emit_length_short(self, length):  # type: (int) -> None
        """Emit the short length form (< 128 octets)."""
        assert length < 128
        self._emit(bytes([length]))

    def _emit_length_long(self, length):  # type: (int) -> None
        """Emit the long length form (>= 128 octets)."""
        # Big-endian octets of the length, preceded by 0x80 | number of octets.
        values = []
        while length:
            values.append(length & 0xff)
            length >>= 8
        values.reverse()
        # really for correctness as this should not happen anytime soon
        assert len(values) < 127
        head = bytes([0x80 | len(values)])
        self._emit(head)
        for val in values:
            self._emit(bytes([val]))

    def _emit(self, s):  # type: (bytes) -> None
        """Append raw bytes to the innermost open output level."""
        assert isinstance(s, bytes)
        self.m_stack[-1].append(s)

    def _encode_value(self, cls, nr, value):  # type: (int, int, any) -> bytes
        """
        Encode a value according to its tag number.

        Values of a non-universal class, and universal values with a tag
        number that is not handled here, are returned unchanged.
        """
        if cls != Classes.Universal:
            return value
        if nr in (Numbers.Integer, Numbers.Enumerated):
            return self._encode_integer(value)
        if nr in (Numbers.OctetString, Numbers.PrintableString,
                  Numbers.UTF8String, Numbers.IA5String,
                  Numbers.UnicodeString, Numbers.UTCTime,
                  Numbers.GeneralizedTime):
            return self._encode_octet_string(value)
        if nr == Numbers.BitString:
            return self._encode_bit_string(value)
        if nr == Numbers.Boolean:
            return self._encode_boolean(value)
        if nr == Numbers.Null:
            return self._encode_null()
        if nr == Numbers.ObjectIdentifier:
            return self._encode_object_identifier(value)
        return value

    @staticmethod
    def _encode_boolean(value):  # type: (bool) -> bytes
        """Encode a boolean: ``0xff`` for a truthy value, ``0x00`` otherwise."""
        return bytes(b'\xff') if value else bytes(b'\x00')

    @staticmethod
    def _encode_integer(value):  # type: (int) -> bytes
        """Encode an integer as a minimal-length two's complement value."""
        if value < 0:
            value = -value
            negative = True
            limit = 0x80
        else:
            negative = False
            limit = 0x7f
        # Collect the octets of the magnitude, least significant first.
        values = []
        while value > limit:
            values.append(value & 0xff)
            value >>= 8
        values.append(value & 0xff)
        if negative:
            # create two's complement
            for i in range(len(values)):  # Invert bits
                values[i] = 0xff - values[i]
            for i in range(len(values)):  # Add 1
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i != len(values) - 1
                values[i] = 0x00
        if negative and values[len(values) - 1] == 0x7f:  # Two's complement corner case
            # The leading octet lost its sign bit; prepend a sign octet.
            values.append(0xff)
        values.reverse()
        return bytes(values)

    @staticmethod
    def _encode_octet_string(value):  # type: (object) -> bytes
        """
        Encode an octet string or character string using the primitive
        encoding. ``str`` values are encoded as UTF-8; ``bytes`` values are
        used as they are.

        Raises:
            `Error`: If ``value`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(value, str):
            return value.encode('utf-8')
        if isinstance(value, bytes):
            return value
        # Raise explicitly rather than assert so the check survives ``python -O``.
        raise Error('Expected a str or bytes value, got {}.'.format(type(value).__name__))

    @staticmethod
    def _encode_bit_string(value):  # type: (object) -> bytes
        """
        Encode a bitstring using the primitive encoding. Assumes no unused
        bits: the leading "unused bits" octet is always zero.

        Raises:
            `Error`: If ``value`` is not ``bytes``.
        """
        if not isinstance(value, bytes):
            # Raise explicitly rather than assert so the check survives ``python -O``.
            raise Error('Expected a bytes value, got {}.'.format(type(value).__name__))
        return b'\x00' + value

    @staticmethod
    def _encode_null():  # type: () -> bytes
        """Encode a Null value (empty content)."""
        return bytes(b'')

    _re_oid = re.compile(r'^[0-9]+(\.[0-9]+)+$')

    def _encode_object_identifier(self, oid):  # type: (str) -> bytes
        """
        Encode an object identifier given in dotted-decimal notation.

        Raises:
            `Error`: If ``oid`` is not dotted-decimal with at least two
                components, or its first two components are out of range.
        """
        if not self._re_oid.match(oid):
            raise Error('Illegal object identifier')
        cmps = list(map(int, oid.split('.')))
        if (cmps[0] <= 1 and cmps[1] > 39) or cmps[0] > 2:
            raise Error('Illegal object identifier')
        # The first two components share a single sub-identifier.
        cmps = [40 * cmps[0] + cmps[1]] + cmps[2:]
        # Build the output back to front: each sub-identifier is base 128,
        # with the high bit set on every octet except its last.
        cmps.reverse()
        result = []
        for cmp_data in cmps:
            result.append(cmp_data & 0x7f)
            while cmp_data > 0x7f:
                cmp_data >>= 7
                result.append(0x80 | (cmp_data & 0x7f))
        result.reverse()
        return bytes(result)


class Decoder(object):
    """
    ASN.1 decoder. Understands BER (and DER which is a subset).

    Decoding state is kept in ``m_stack``, a list of ``[offset, data]``
    pairs: the first entry is the input passed to `Decoder.start()`, and
    every `Decoder.enter()` pushes the content of the entered constructed
    type. ``m_tag`` caches the tag read by `Decoder.peek()` until it is
    consumed.
    """

    def __init__(self):  # type: () -> None
        """
        Constructor. The decoder is unusable until `Decoder.start()` has
        been called.
        """
        self.m_stack = None
        self.m_tag = None

    def start(self, data):  # type: (bytes) -> None
        """
        This method instructs the decoder to start decoding the ASN.1 input
        ``data``, which must be passed in as plain Python bytes.
        This method may be called at any time to start a new decoding job.
        If this method is called while currently decoding another input, that
        decoding context is discarded.

        Note:
            It is not necessary to specify the encoding because the decoder
            assumes the input is in BER or DER format.

        Args:
            data (bytes): ASN.1 input, in BER or DER format, to be decoded.

        Returns:
            None

        Raises:
            `Error`: If ``data`` is not a ``bytes`` instance.
        """
        if not isinstance(data, bytes):
            raise Error('Expecting bytes instance.')
        self.m_stack = [[0, bytes(data)]]
        self.m_tag = None

    def peek(self):  # type: () -> Tag
        """
        This method returns the current ASN.1 tag (i.e. the tag that a
        subsequent `Decoder.read()` call would return) without updating the
        decoding offset. In case no more data is available from the input,
        this method returns ``None`` to signal end-of-file.

        This method is useful if you don't know whether the next tag will be a
        primitive or a constructed tag. Depending on the return value of `peek`,
        you would decide to either issue a `Decoder.read()` in case of a primitive
        type, or an `Decoder.enter()` in case of a constructed type.

        Note:
            Because this method does not advance the current offset in the input,
            calling it multiple times in a row will return the same value for all
            calls.

        Returns:
            `Tag`: The current ASN.1 tag.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self._end_of_input():
            return None
        if self.m_tag is None:
            self.m_tag = self._read_tag()
        return self.m_tag

    def read(self, tagnr=None):  # type: (int) -> (Tag, any)
        """
        This method decodes one ASN.1 tag from the input and returns it as a
        ``(tag, value)`` tuple. ``tag`` is a 3-tuple ``(nr, typ, cls)``,
        while ``value`` is a Python object representing the ASN.1 value.
        The offset in the input is increased so that the next `Decoder.read()`
        call will return the next tag. In case no more data is available from
        the input, this method returns ``None`` to signal end-of-file.

        Args:
            tagnr (int): Optional tag number to use for decoding the value
                instead of the number found in the input. The returned tag
                is not affected. Use ``Numbers`` enumeration.

        Returns:
            `Tag`, value: The current ASN.1 tag and its value.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self._end_of_input():
            return None
        tag = self.peek()
        length = self._read_length()
        if tagnr is None:
            tagnr = tag.nr
        value = self._read_value(tag.cls, tagnr, length)
        self.m_tag = None
        return tag, value

    def eof(self):  # type: () -> bool
        """
        Return True if we are at the end of input.

        Returns:
            bool: True if all input has been decoded, and False otherwise.

        Raises:
            `Error`: If `Decoder.start()` has not been called.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        return self._end_of_input()

    def enter(self):  # type: () -> None
        """
        This method enters the constructed type that is at the current
        decoding offset.

        Note:
            It is an error to call `Decoder.enter()` if the to be decoded ASN.1 tag
            is not of a constructed type.

        Returns:
            None

        Raises:
            `Error`: If `Decoder.start()` has not been called, the input is
                exhausted, or the current tag is not constructed.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        tag = self.peek()
        if tag is None:
            # peek() signals end of input with None; there is nothing to enter.
            raise Error('Premature end of input.')
        if tag.typ != Types.Constructed:
            raise Error('Cannot enter a non-constructed tag.')
        length = self._read_length()
        bytes_data = self._read_bytes(length)
        self.m_stack.append([0, bytes_data])
        self.m_tag = None

    def leave(self):  # type: () -> None
        """
        This method leaves the last constructed type that was
        `Decoder.enter()`-ed. Any content of that type that has not been
        read yet is skipped.

        Returns:
            None

        Raises:
            `Error`: If `Decoder.start()` has not been called, or if no
                constructed type has been entered.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        del self.m_stack[-1]
        self.m_tag = None

    def _read_tag(self):  # type: () -> Tag
        """
        Read a tag from the input.
        """
        byte = self._read_byte()
        cls = byte & 0xc0
        typ = byte & 0x20
        nr = byte & 0x1f
        if nr == 0x1f:  # Long form of tag encoding
            # Base-128 number, most significant group first; the high bit of
            # each octet tells whether another octet follows.
            nr = 0
            while True:
                byte = self._read_byte()
                nr = (nr << 7) | (byte & 0x7f)
                if not byte & 0x80:
                    break
        # Convert to the enumerations where possible; unknown values are kept
        # as plain integers.
        try:
            typ = Types(typ)
        except ValueError:
            pass
        try:
            cls = Classes(cls)
        except ValueError:
            pass
        if cls == Classes.Universal:
            try:
                nr = Numbers(nr)
            except ValueError:
                pass
        return Tag(nr=nr, typ=typ, cls=cls)

    def _read_length(self):  # type: () -> int
        """
        Read a length from the input.
        """
        byte = self._read_byte()
        if byte & 0x80:
            # Long form: the low seven bits give the number of length octets.
            # NOTE(review): a first octet of exactly 0x80 yields count == 0
            # and therefore a length of 0; the BER indefinite-length form is
            # not handled specially here.
            count = byte & 0x7f
            if count == 0x7f:
                raise Error('ASN1 syntax error')
            bytes_data = self._read_bytes(count)
            length = 0
            for byte in bytes_data:
                length = (length << 8) | int(byte)
            # Presumably a Python 2 compatibility step (normalising to the
            # ``builtins`` int type); kept as a best-effort conversion.
            try:
                length = int(length)
            except OverflowError:
                pass
        else:
            length = byte
        return length

    def _read_value(self, cls, nr, length):  # type: (int, int, int) -> any
        """
        Read ``length`` octets from the input and decode them according to
        the tag number. Values of a non-universal class, and universal
        values with an unhandled tag number, are returned as raw bytes.
        """
        bytes_data = self._read_bytes(length)
        if cls != Classes.Universal:
            value = bytes_data
        elif nr == Numbers.Boolean:
            value = self._decode_boolean(bytes_data)
        elif nr in (Numbers.Integer, Numbers.Enumerated):
            value = self._decode_integer(bytes_data)
        elif nr == Numbers.OctetString:
            value = self._decode_octet_string(bytes_data)
        elif nr == Numbers.Null:
            value = self._decode_null(bytes_data)
        elif nr == Numbers.ObjectIdentifier:
            value = self._decode_object_identifier(bytes_data)
        elif nr in (Numbers.PrintableString, Numbers.IA5String,
                    Numbers.UTF8String, Numbers.UTCTime,
                    Numbers.GeneralizedTime):
            value = self._decode_printable_string(bytes_data)
        elif nr == Numbers.BitString:
            value = self._decode_bitstring(bytes_data)
        else:
            value = bytes_data
        return value

    def _read_byte(self):  # type: () -> int
        """
        Return the next input byte, or raise an error on end-of-input.
        """
        index, input_data = self.m_stack[-1]
        try:
            byte = input_data[index]
        except IndexError:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += 1
        return byte

    def _read_bytes(self, count):  # type: (int) -> bytes
        """
        Return the next ``count`` bytes of input. Raise error on
        end-of-input.
        """
        index, input_data = self.m_stack[-1]
        bytes_data = input_data[index:index + count]
        # Slicing never raises; a short slice means the input was truncated.
        if len(bytes_data) != count:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += count
        return bytes_data

    def _end_of_input(self):  # type: () -> bool
        """
        Return True if we are at the end of input of the innermost level.
        """
        index, input_data = self.m_stack[-1]
        assert not index > len(input_data)
        return index == len(input_data)

    @staticmethod
    def _decode_boolean(bytes_data):  # type: (bytes) -> bool
        """
        Decode a boolean value: a single octet, zero meaning False and
        anything else meaning True.
        """
        if len(bytes_data) != 1:
            raise Error('ASN1 syntax error')
        if bytes_data[0] == 0:
            return False
        return True

    @staticmethod
    def _decode_integer(bytes_data):  # type: (bytes) -> int
        """
        Decode an integer value from its two's complement encoding.

        Raises:
            `Error`: If the content is empty or not in minimal form.
        """
        values = [int(b) for b in bytes_data]
        # An integer needs at least one content octet; without this check the
        # sign test below would fail with an IndexError.
        if not values:
            raise Error('ASN1 syntax error')
        # check if the integer is normalized
        if len(values) > 1 and (values[0] == 0xff and values[1] & 0x80 or values[0] == 0x00 and not (values[1] & 0x80)):
            raise Error('ASN1 syntax error')
        negative = values[0] & 0x80
        if negative:
            # make positive by taking two's complement
            for i in range(len(values)):
                values[i] = 0xff - values[i]
            for i in range(len(values) - 1, -1, -1):
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i > 0
                values[i] = 0x00
        value = 0
        for val in values:
            value = (value << 8) | val
        if negative:
            value = -value
        # Presumably a Python 2 compatibility step (normalising to the
        # ``builtins`` int type); kept as a best-effort conversion.
        try:
            value = int(value)
        except OverflowError:
            pass
        return value

    @staticmethod
    def _decode_octet_string(bytes_data):  # type: (bytes) -> bytes
        """
        Decode an octet string (returned unchanged).
        """
        return bytes_data

    @staticmethod
    def _decode_null(bytes_data):  # type: (bytes) -> any
        """
        Decode a Null value. The content must be empty.
        """
        if len(bytes_data) != 0:
            raise Error('ASN1 syntax error')
        return None

    @staticmethod
    def _decode_object_identifier(bytes_data):  # type: (bytes) -> str
        """
        Decode an object identifier into dotted-decimal notation.

        Raises:
            `Error`: If the content is empty, a sub-identifier starts with a
                padding octet (0x80), or the last sub-identifier is
                truncated.
        """
        result = []
        value = 0
        pending = False  # True while in the middle of a sub-identifier
        for raw in bytes_data:
            byte = int(raw)
            if value == 0 and byte == 0x80:
                raise Error('ASN1 syntax error')
            value = (value << 7) | (byte & 0x7f)
            pending = bool(byte & 0x80)
            if not pending:
                result.append(value)
                value = 0
        # A set continuation bit on the final octet means the last
        # sub-identifier is incomplete; reject it instead of dropping it.
        if pending or len(result) == 0:
            raise Error('ASN1 syntax error')
        # The first sub-identifier packs the first two components as
        # 40 * first + second, where first is 0, 1 or 2.
        if result[0] // 40 <= 1:
            result = [result[0] // 40, result[0] % 40] + result[1:]
        else:
            result = [2, result[0] - 80] + result[1:]
        result = list(map(str, result))
        return str('.'.join(result))

    @staticmethod
    def _decode_printable_string(bytes_data):  # type: (bytes) -> str
        """
        Decode a character string by interpreting the content as UTF-8.
        """
        return bytes_data.decode('utf-8')

    @staticmethod
    def _decode_bitstring(bytes_data):  # type: (bytes) -> bytes
        """
        Decode a bitstring.

        The first content octet gives the number of unused bits (0-7) at the
        end of the last octet. The remaining octets are returned with those
        unused bits shifted off, i.e. the whole bit sequence is shifted right
        so that it ends on an octet boundary.

        Raises:
            `Error`: If the content is empty or the unused-bit count is not
                in the range 0-7.
        """
        if len(bytes_data) == 0:
            raise Error('ASN1 syntax error')

        num_unused_bits = bytes_data[0]
        if not (0 <= num_unused_bits <= 7):
            raise Error('ASN1 syntax error')

        if num_unused_bits == 0:
            return bytes_data[1:]

        # Shift off unused bits
        remaining = bytearray(bytes_data[1:])
        bitmask = (1 << num_unused_bits) - 1
        # The low bits shifted out of one octet become the high bits of the
        # next, so they must be moved up by the complementary shift amount.
        carry_shift = 8 - num_unused_bits
        removed_bits = 0

        for i in range(len(remaining)):
            byte = int(remaining[i])
            remaining[i] = (byte >> num_unused_bits) | (removed_bits << carry_shift)
            removed_bits = byte & bitmask

        return bytes(remaining)
