"""Test the TCP implementation of UCSPI."""

import argparse
import dataclasses
import enum
import pathlib
import socket

from typing import Any

import netifaces
import utf8_locale

import ucspi_test


@dataclasses.dataclass(frozen=True)
class Config(ucspi_test.Config):
    """Runtime configuration for the TCP test runner.

    Extends the generic UCSPI test configuration with the local address,
    the accepted address-tuple lengths, and the address family that the
    TCP runner uses when creating its test sockets.
    """

    # The local (loopback) IP address that listening sockets are bound to.
    listen_addr: str
    # The lengths an address list/tuple handled by the runner may have;
    # parse_args() fills this in from IPVersion.addr_len().
    listen_addr_len: set[int]
    # The socket address family passed to socket.socket() for every test socket.
    # pylint: disable-next=no-member
    listen_family: socket.AddressFamily


class TcpRunner(ucspi_test.Runner):
    """Run ucspi-tcp tests."""

    def _tcp_cfg(self) -> Config:
        """Return the runner's configuration, checked to be the TCP-specific kind."""
        assert isinstance(self.cfg, Config), repr(self.cfg)
        return self.cfg

    def _parse_addr(self, addr: list[str], func: str) -> tuple[str, int]:
        """Validate an [address, port, ...] list and split it into host and integer port.

        Args:
            addr: the address list; its length must be one of the configured
                `listen_addr_len` values.
            func: the public method name to mention in error messages.

        Returns:
            The (address, port) pair with the port converted to an integer.

        Raises:
            ucspi_test.RunnerError: the list has an unexpected length or the
                port is not a valid integer.
        """
        if len(addr) not in self._tcp_cfg().listen_addr_len:
            raise ucspi_test.RunnerError(
                f"{self.proto}.{func}(): unexpected address length for {addr!r}"
            )
        try:
            port = int(addr[1])
        except ValueError as err:
            raise ucspi_test.RunnerError(
                f"{self.proto}.{func}(): could not convert {addr[1]!r} to a number: {err}"
            ) from err
        return addr[0], port

    def _create_socket(self) -> socket.socket:
        """Create an unbound TCP socket of the configured address family.

        Raises:
            ucspi_test.RunnerError: the socket could not be created.
        """
        family = self._tcp_cfg().listen_family
        try:
            return socket.socket(family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        except OSError as err:
            raise ucspi_test.RunnerError(f"Could not create a TCP socket: {err}") from err

    @staticmethod
    def _bind_and_listen(sock: socket.socket, laddr: str, lport: int) -> None:
        """Set SO_REUSEADDR on `sock`, bind it to (laddr, lport), and start listening.

        Raises:
            ucspi_test.RunnerError: any of the three socket operations failed;
                the caller is responsible for closing `sock`.
        """
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except OSError as err:
            raise ucspi_test.RunnerError(
                f"Could not set the 'reuse address' option on a TCP socket: {err}"
            ) from err
        try:
            sock.bind((laddr, lport))
        except OSError as err:
            raise ucspi_test.RunnerError(
                f"Could not get a TCP socket to bind to {laddr}:{lport}: {err}"
            ) from err
        try:
            sock.listen(5)
        except OSError as err:
            raise ucspi_test.RunnerError(
                f"Could not get a TCP socket to listen on {laddr}:{lport}: {err}"
            ) from err

    def find_listening_address(self) -> list[str]:
        """Find a local address/port combination.

        Probes ports 6502 through 8085 on the configured listen address and
        returns the first one a TCP socket could bind to, as [address, port].
        The probe socket is closed before returning, so the port is only known
        to have been free at probe time.

        Raises:
            ucspi_test.RunnerError: no port in the range could be bound.
        """
        print(f"{self.proto}.find_listening_address() starting")
        cfg = self._tcp_cfg()
        addr = cfg.listen_addr
        for port in range(6502, 8086):
            # The context manager closes the probe socket on every path, so a
            # failed bind() no longer leaks a file descriptor per port tried.
            with socket.socket(cfg.listen_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                try:
                    sock.bind((addr, port))
                except OSError:
                    # Port unavailable; try the next one.
                    continue
            print(f"- got {port}")
            return [addr, str(port)]

        raise ucspi_test.RunnerError(f"Could not find a suitable port on {addr}")

    def get_listening_socket(self, addr: list[str]) -> socket.socket:
        """Create a TCP socket bound to and listening on the given [address, port] list.

        Raises:
            ucspi_test.RunnerError: the address is malformed or any socket
                operation fails; the socket is closed before the error propagates.
        """
        laddr, lport = self._parse_addr(addr, "get_listening_socket")
        sock = self._create_socket()
        try:
            self._bind_and_listen(sock, laddr, lport)
        except BaseException:
            # Do not leak the file descriptor when setup fails part-way.
            sock.close()
            raise
        return sock

    def get_connected_socket(self, addr: list[str]) -> socket.socket:
        """Create a TCP socket connected to the given [address, port] list.

        Raises:
            ucspi_test.RunnerError: the address is malformed, or the socket could
                not be created or connected; a created socket is closed first.
        """
        laddr, lport = self._parse_addr(addr, "get_connected_socket")
        sock = self._create_socket()
        try:
            sock.connect((laddr, lport))
        except OSError as err:
            sock.close()
            raise ucspi_test.RunnerError(
                f"Could not connect a TCP socket to {laddr}:{lport}: {err}"
            ) from err
        return sock

    def format_local_addr(self, addr: list[str]) -> str:
        """Format a local [address, port, ...] list as "address:port"."""
        assert len(addr) in self._tcp_cfg().listen_addr_len, repr(addr)
        return f"{addr[0]}:{addr[1]}"

    def format_remote_addr(self, addr: Any) -> str:
        """Format a peer address tuple (host str, port int, ...) as "address:port"."""
        cfg = self._tcp_cfg()
        assert (
            isinstance(addr, tuple)
            and len(addr) in cfg.listen_addr_len
            and isinstance(addr[0], str)
            and isinstance(addr[1], int)
        ), repr(addr)
        # NOTE(review): IPv6 hosts are not bracketed (e.g. "::1:6502"); presumably
        # this matches the format the tests compare against — confirm.
        return f"{addr[0]}:{addr[1]}"


class IPVersion(str, enum.Enum):
    """Which IP protocol version the listening socket should use."""

    IPV4 = "4"
    IPV6 = "6"

    def __str__(self) -> str:
        """Render the member as its bare value, e.g. "4" instead of "IPVersion.IPV4"."""
        return str.__str__(self)

    def addr_len(self) -> set[int]:
        """Return the set of lengths a socket address tuple of this version may have.

        IPv4 addresses are (host, port) pairs; IPv6 addresses may also be the
        socket module's 4-tuple form (host, port, flowinfo, scope_id).
        """
        lengths_by_version = {
            IPVersion.IPV4: (2,),
            IPVersion.IPV6: (2, 4),
        }
        return set(lengths_by_version[self])

    # pylint: disable-next=no-member
    def family(self) -> socket.AddressFamily:
        """Return the socket address family matching this IP version."""
        return socket.AF_INET6 if self is IPVersion.IPV6 else socket.AF_INET


def get_listen_address(
    ip_version: IPVersion, *, iface: str = "lo"
) -> tuple[str, set[int], socket.AddressFamily] | None:  # pylint: disable=no-member
    """Get a loopback address for the specified address family, if any are configured.

    Args:
        ip_version: which IP version (and hence socket address family) to look up.
        iface: the name of the loopback interface to query; defaults to "lo",
            the Linux name (other systems, e.g. the BSDs, may call it "lo0").

    Returns:
        An (address, accepted address-tuple lengths, address family) triple
        built from the first address netifaces reports for that family on the
        interface, or None (after printing a diagnostic) if the interface does
        not exist or has no address of that family.
    """
    # pylint: disable=c-extension-no-member

    if iface not in netifaces.interfaces():
        print(f"No {iface!r} interface at all?!")
        return None

    family = ip_version.family()
    candidates = netifaces.ifaddresses(iface).get(family)
    if not candidates:
        print(f"No addresses for the specified family on the {iface!r} interface")
        return None

    return candidates[0]["addr"], ip_version.addr_len(), family


def parse_args() -> Config | None:
    """Build the runtime configuration from the command-line arguments.

    Returns None when get_listen_address() finds no loopback address for the
    requested IP version.
    """
    parser = argparse.ArgumentParser(prog="uctest")
    parser.add_argument(
        "-d",
        "--bindir",
        type=pathlib.Path,
        required=True,
        help="the path to the UCSPI utilities",
    )
    parser.add_argument(
        "-i",
        "--ip-version",
        type=IPVersion,
        default=IPVersion.IPV4,
        choices=["4", "6"],
        help="the address family to listen on ('4' for IPv4, '6' for IPv6)",
    )
    parser.add_argument(
        "-p",
        "--proto",
        type=str,
        required=True,
        help="the UCSPI protocol ('tcp', 'unix', etc)",
    )
    args = parser.parse_args()

    if (listen_data := get_listen_address(args.ip_version)) is None:
        return None
    listen_addr, listen_addr_len, listen_family = listen_data

    return Config(
        bindir=args.bindir.absolute(),
        listen_addr=listen_addr,
        listen_addr_len=listen_addr_len,
        listen_family=listen_family,
        proto=args.proto,
        utf8_env=utf8_locale.UTF8Detect().detect().env,
    )


def main() -> None:
    """Entry point: build the configuration and run the requested UCSPI test.

    If no loopback address is available for the requested IP version, print
    a notice and return without running anything.
    """
    cfg = parse_args()
    if cfg is not None:
        # Make TcpRunner available under the "tcp" name, then hand off to the
        # generic test driver (presumably it selects a runner by cfg.proto).
        ucspi_test.add_handler("tcp", TcpRunner)
        ucspi_test.run_test_handler(cfg)
        return

    print("No loopback interface addresses for the requested family")


if __name__ == "__main__":
    main()
